Find Cointegrating Pairs To Trade

In this research I am going to search for cointegrating security pairs that can potentially be used in a pairs trading strategy. I am going to fetch the data from the official NASDAQ website.

Import the necessary Python modules

%matplotlib inline

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.stattools import coint, adfuller
import requests
import csv
from pathlib import Path
import time
import datetime
from os import listdir
from os.path import isfile, join
# -----------------------------------------------------------------------------
# Set the parameters
# -----------------------------------------------------------------------------
# Current timestamp, used to build the end of the analysis window.
now = datetime.datetime.now()

# Analysis window: from a fixed start date up to today.
start_date = '2010-01-01'
end_date = f'{now.year}-{now.month}-{now.day}'

# Z-score thresholds for a pairs-trading strategy built on top of this research.
entry_condition_long = -1
entry_condition_short = 1
abs_condition_exit = 0.5

# Default pair to inspect (Visa / Mastercard).
s1_symbol = 'v'
s2_symbol = 'ma'
# -----------------------------------------------------------------------------
# Function to fetch ticker symbols from Nasdaq
# -----------------------------------------------------------------------------
def fetch_symbols(exchange: str):
    """
    Fetch the list of ticker symbols for the given exchange from Nasdaq.

    Parameters
    ----------
    exchange : str
        Exchange name as understood by the Nasdaq screener, e.g. 'NASDAQ'.

    Returns
    -------
    dict
        Maps symbol -> {'name': ..., 'sector': ...}, e.g.
        result['AAPL'] -> {'name': 'Apple Inc', 'sector': 'Technology'}.

    Raises
    ------
    requests.HTTPError
        If the download endpoint returns an error status.
    """
    url = "https://www.nasdaq.com/screening/companies-by-industry.aspx?exchange={0}&render=download".format(exchange)
    # Timeout so a stalled connection doesn't hang the notebook; fail fast on
    # an HTTP error instead of silently parsing an error page as CSV.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    content = response.content.decode('utf-8')

    csv_reader = csv.reader(content.splitlines(), delimiter=',')
    symbols = list(csv_reader)

    stocks = {}
    # symbols[0] is the header row; skip it. Guard against blank/short rows so
    # row[6] cannot raise IndexError on a truncated download.
    for row in symbols[1:]:
        if len(row) < 7:
            continue
        symbol = row[0].strip()
        stocks[symbol] = {'name': row[1], 'sector': row[6]}

    return stocks
# -----------------------------------------------------------------------------
# Call fetch symbols and save results to CSV file
# -----------------------------------------------------------------------------
csv_path_symbols_industry = "symbols.csv"

# Only hit the network when no cached copy exists on disk.
symbols_industry_file = Path(csv_path_symbols_industry)
if not symbols_industry_file.is_file():
    symbols_industry = fetch_symbols('NASDAQ')

    # Write CSV. newline='' is required by the csv module (avoids blank rows
    # on Windows); an explicit encoding keeps the file locale-independent.
    with open(csv_path_symbols_industry, 'w', newline='', encoding='utf-8') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['symbol', 'name', 'industry'])
        for symbol, info in symbols_industry.items():
            writer.writerow((symbol, info['name'], info['sector']))

# Reload from the CSV so the rest of the notebook works the same whether the
# data was just fetched or was already cached.
symbols_industry_df = pd.read_csv(csv_path_symbols_industry, index_col='symbol', sep=',', encoding='utf8')
symbols = list(symbols_industry_df.index)
# -----------------------------------------------------------------------------
# Function to fetch 10 year historical prices from Nasdaq and save them as CSV
# -----------------------------------------------------------------------------
def historical_prices(symbol: str):
    """
    Fetch ~10 years of daily price history for `symbol` from Nasdaq.

    Parameters
    ----------
    symbol : str
        Ticker symbol; lower-cased before use, as the Nasdaq URL expects.

    Returns
    -------
    pd.DataFrame
        Indexed by date (datetime), with columns
        ('close', 'volume', 'open', 'high', 'low') holding the raw CSV fields.
    """
    symbol = symbol.lower()

    print("%s " % symbol, end='')
    url = "https://www.nasdaq.com/symbol/{0}/historical".format(symbol)
    headers = {'content-type': 'application/json'}
    # Nasdaq's endpoint expects this pipe-delimited payload: range|flag|symbol.
    data = "10y|true|{0}".format(symbol)

    resp = requests.post(url, data=data, headers=headers)
    reader = csv.reader(resp.text.split('\n'), delimiter=',')

    # Skip the first two rows (header lines in Nasdaq's response)
    next(reader)
    next(reader)

    columns = ('date', 'close', 'volume', 'open', 'high', 'low')

    # Collect rows in a plain list and build the DataFrame once at the end:
    # the previous df.loc[len(df)] append pattern is O(n^2) and very slow for
    # ten years of daily bars. Rows of the wrong width are reported and skipped,
    # matching the old ValueError handling.
    rows = []
    for row in reader:
        if not row:
            continue
        if len(row) == len(columns):
            rows.append(row)
        else:
            print("Error with row: %s for symbol '%s'" % (row, symbol))

    historical_prices_df = pd.DataFrame(rows, columns=columns)

    # Create a date index
    historical_prices_df.set_index('date', inplace=True)
    historical_prices_df.index = pd.to_datetime(historical_prices_df.index, format='%Y/%m/%d')

    return historical_prices_df
# -----------------------------------------------------------------------------
# Function to initiate fetch of historical prices for all symbols 
# and write results to CSV
# -----------------------------------------------------------------------------
def generate_price_csv(max_symbols=20):
    """
    Fetch and cache historical prices for up to `max_symbols` symbols.

    For each symbol in the module-level `symbols` list, writes
    tickers/<symbol>_10y.csv if it doesn't already exist. Symbols that are
    already cached count toward `max_symbols` but trigger no network request.
    """
    for count, symbol in enumerate(symbols, start=1):
        # Fetch prices from Nasdaq and save to CSV if the CSV doesn't exist
        csv_path_symbol = "tickers/{0}_10y.csv".format(symbol)

        if not Path(csv_path_symbol).is_file():
            # Throttle only actual network requests: the old code slept for
            # every symbol, including ones already cached on disk.
            time.sleep(2)
            df = historical_prices(symbol)
            df.to_csv(csv_path_symbol, sep=',', encoding='utf-8')

        if count == max_symbols:
            break
generate_price_csv()
# -----------------------------------------------------------------------------
# Function to read prices CSV and return a pandas DataFrame
# -----------------------------------------------------------------------------
def read_prices_csv_as_df(path):
    """
    Load a cached price CSV into a DataFrame indexed by date.

    Returns an empty DataFrame when no file exists at `path`.
    """
    # Guard clause: missing file -> empty frame, so callers can test .empty.
    if not Path(path).is_file():
        return pd.DataFrame()

    prices = pd.read_csv(path, index_col='date', sep=',', encoding='utf8')
    prices.index = pd.to_datetime(prices.index, format='%Y-%m-%d')
    return prices
# -----------------------------------------------------------------------------
# Check for cointegration and write results to CSV
# -----------------------------------------------------------------------------

# Try only symbols for which we already have a CSV file with price data
symbols_csvs = [f for f in listdir('tickers') if isfile(join('tickers', f)) and f.endswith('.csv')]
available_symbols = [filename.split('_', 1)[0] for filename in symbols_csvs]

# Create CSV file to write the results
csv_file_pairs = open('pairs.csv', 'w', 1)
writer = csv.writer(csv_file_pairs)
writer.writerow(['s1_symbol', 's2_symbol', 'industry', 'p-value'])

for i, s1_symbol in enumerate(available_symbols):
    for j, s2_symbol in enumerate(available_symbols[i+1:]):
        s1_df = read_prices_csv_as_df("tickers/{0}_10y.csv".format(s1_symbol))
        s2_df = read_prices_csv_as_df("tickers/{0}_10y.csv".format(s2_symbol))
        
        if (not s1_df.empty) and (not s2_df.empty):
            # Find the date/index intersection of the two data sets
            s1_s2_index_intersection = s1_df.index.intersection(s2_df.index)
            s1_df = s1_df.reindex(s1_s2_index_intersection)
            s2_df = s2_df.reindex(s1_s2_index_intersection)
            
            calendar_dates = pd.date_range(start=start_date, end=end_date, freq='D', tz=None)
            s1_df = s1_df.reindex(calendar_dates)
            s2_df = s2_df.reindex(calendar_dates)

            s1_df = s1_df[start_date:end_date]
            s2_df = s2_df[start_date:end_date]
            
            s1_df = s1_df.dropna()
            s2_df = s2_df.dropna()

            industry_s1 = symbols_industry_df.loc[s1_symbol]['industry']
            industry_s2 = symbols_industry_df.loc[s2_symbol]['industry']
            
            if industry_s1 == industry_s2 and (len(s1_df['close'])>=250 and len(s2_df['close'])>=250):
                _, p_value,_ = coint(s1_df['close'], s2_df['close'])
                cointegrated = p_value < 0.05
                if cointegrated:
                    #print("Found pair: '%s'/'%s' (Industry=%s) (p-value=%s)" % (s1_symbol, s2_symbol, industry_s1, p_value))
                    writer.writerow([s1_symbol, s2_symbol, industry_s1, p_value])

csv_file_pairs.close()
# -----------------------------------------------------------------------------
# Plot the first 5 cointegrating pairs
# -----------------------------------------------------------------------------
csv_path_pairs = "pairs.csv"
count = 0

# The calendar window is the same for every pair — build it once (it was
# rebuilt on every loop iteration before).
calendar_dates = pd.date_range(start=start_date, end=end_date, freq='D', tz=None)

pairs_file = Path(csv_path_pairs)
if pairs_file.is_file():
    pairs_df = pd.read_csv(csv_path_pairs, sep=',', encoding='utf8')

    for index, row in pairs_df.iterrows():
        s1_symbol = row['s1_symbol']
        s2_symbol = row['s2_symbol']

        s1_df = read_prices_csv_as_df("tickers/{0}_10y.csv".format(s1_symbol))
        s2_df = read_prices_csv_as_df("tickers/{0}_10y.csv".format(s2_symbol))

        if (not s1_df.empty) and (not s2_df.empty):
            # Align the two series exactly as in the cointegration scan:
            # keep shared dates, clip to the analysis window, drop gaps.
            s1_s2_index_intersection = s1_df.index.intersection(s2_df.index)
            s1_df = s1_df.reindex(s1_s2_index_intersection)
            s2_df = s2_df.reindex(s1_s2_index_intersection)

            s1_df = s1_df.reindex(calendar_dates)
            s2_df = s2_df.reindex(calendar_dates)

            s1_df = s1_df[start_date:end_date].dropna()
            s2_df = s2_df[start_date:end_date].dropna()

            # Plot the spread with its mean as a visual reversion reference
            plt.figure(figsize=(14, 4))
            spread = s1_df['close'] - s2_df['close']
            spread.plot()
            plt.axhline(spread.mean(), color='red', linestyle='--')  # Add the mean
            plt.xlabel('Time')
            plt.ylabel('Price Spread')
            plt.title("Spread Between {0} and {1}".format(s1_symbol, s2_symbol))
            plt.legend(['Price Spread', 'Mean'])

            # Only plot the first 5 pairs found
            count += 1
            if count == 5:
                break

As we can see, not all of the pairs found actually cointegrate over the full window. But by plotting their spreads we can easily filter out the ones that do look mean-reverting, in order to analyze them further.

References
Source code