Facebook製時系列予測「Prophet」で株価を予測してみました【Python】

4月 13, 2021

今回は,Facebook社が開発した時系列予測ライブラリ「Prophet」を使ってみます.

「Prophet」によって,機械学習的に株価の予測をすることができ,以下のような結果を得ることができます.

これは,コカ・コーラ(KO)の株価とその予測です.

この記事では,
Prophetの概要と説明
Prophetによる予測コードとその説明
結果の見方
を解説しようと思います.

Prophetの概要

ProphetはFacebook社が公開している時系列予測ライブラリです.

Prophetの特徴として

・計算が正確で高速
・時間がかかる工程がなく簡単に使える
・予測モデルをチューニング可能
・RもしくはPythonで利用可能

であることがProphetのページで述べられています.

使いやすいことが,一番わかりやすい利点ですね.後で紹介するように,数行のコードで予測が可能です.機械学習や統計の知識がなくても,プログラミングの実行の仕方さえ分かれば使えます.

予測モデルをチューニング可能というのは,Prophetではディフォルトでもモデルを自動で作ってくれる一方,ユーザーが手を加えたいときにはモデルを調整できるということです.これによって,株価を予測する際には株価に関するドメイン知識を予測モデルに組み込むなどすることができます(今回はしません).

Prophetのもう少し詳しい説明

Prophetのページによると,Prophetは非線形のデータを年・週・日単位の周期性や変動点でフィッティングします.Prophetはは強い周期性を持つ時系列データやいくつかの周期性を持つ歴史のデータに対する予測を得意としているようです.また,時系列データの間にあるデータの抜けや外れ値に対しても強いようです.

つまり,周期性(年・週・日単位)やイベントを見つけ,その関数でフィッティングするといった感じですかね.予測は基本はあくまで時系列分析なので,人の意志によってランダムに動く株を予測できるかどうかはかなり微妙な気がします.

とはいっても,機械学習的(テクニカル分析)に予測するという範囲のなかでは十分に意味はありそうです.テクニカル分析は一部の観点では有効だと考えられるからです.

また,Prophetでは株のドメイン知識を加えることができるようなので,調整するユーザーの腕次第では予測精度の高いモデルが作れるかもしれません(あまりドメイン知識を入れすぎると,テクニカル要素が薄くなるのでそれはそれで意味がない気がしますが).

注意

今回は本当に「使ってみた」だけなので,現段階ではProphetの中身についてはあまり理解できていません.そのため,予測結果は私自身も信頼できないままでいます.Prophetの予測を信じるか否かはご自身でご判断ください.出力される結果を用いた投資については一切責任を負いません.

※詳細な説明記事は後日別途出稿するつもりです.

中身をきちんと理解したい人は,以下のリンクなどをご覧ください.

・Prophet概要&特徴:https://facebook.github.io/prophet/
・公式ドキュメント:https://facebook.github.io/prophet/docs/quick_start.html
・わかりやすいスライド:Prophet入門【Python編】Facebookの時系列予測ツール

使い方(コード)

Prophetのインストール

Prophetのインストールは少し厄介です.

Windowsでインストールしようとしたらエラーが出たので今回はいったん諦めてGoogle Colabで実行ました.

Google Colabならfbprophetだけでなく今回使うpandas_datareaderなどのライブラリもインストールする必要がありませんでした(過去にインストールしたからかも?).

もしインストールされていなかったら,pipでインストールしてください.また,当然ですがGoogle Colabの使用にはネット接続が必要です.

Prophetで株価を予測

コードは以下の通りとなっています.今回はKO(コカ・コーラ)の2015/1/1~2021/1/1のデータを使って予測してみました.コードの説明は後で行います.

import pandas as pd
import matplotlib.pyplot as plt
from fbprophet import Prophet
import pandas_datareader.data as web
import datetime

start = datetime.date(2015,1,1)
end = datetime.date(2021,1,1)

data = web.DataReader('KO', 'yahoo', start, end) #get data

data['ds'] = data.index
data = data.rename({'Adj Close':'y'},axis=1)

model = Prophet()
model.fit(data)

future_data = model.make_future_dataframe(periods=250, freq = 'd')

forecast_data = model.predict(future_data)

fig = model.plot(forecast_data)
model.plot_components(forecast_data)

実行結果

コード解説

データの取得

株価データの取得には,pandas_datareaderを用いました.

start = datetime.date(2015,1,1)
end = datetime.date(2021,1,1)

data = web.DataReader('KO', 'yahoo', start, end) #get data

このdataは,以下のような構造のDataFarameとなっています.

print(data)
                 High        Low       Open      Close      Volume  Adj Close
Date                                                                         
2015-01-02  42.400002  41.799999  42.259998  42.139999   9921100.0  34.332104
2015-01-05  42.970001  42.080002  42.689999  42.139999  26292600.0  34.332104
2015-01-06  42.939999  42.240002  42.410000  42.459999  16897500.0  34.592823
2015-01-07  43.110001  42.580002  42.799999  42.990002  13412300.0  35.024609
2015-01-08  43.570000  43.099998  43.180000  43.509998  21743600.0  35.448261
...               ...        ...        ...        ...         ...        ...
2020-12-24  53.549999  53.020000  53.020000  53.439999   3265500.0  52.998867
2020-12-28  54.439999  53.730000  53.849998  54.160000   9020500.0  53.712925
2020-12-29  54.490002  54.020000  54.450001  54.130001   8320600.0  53.683174
2020-12-30  54.630001  54.029999  54.049999  54.439999   8142700.0  53.990612
2020-12-31  54.930000  54.270000  54.450001  54.840000   8495000.0  54.387314

Prophetを用いるには,日付カラム名を「ds」,予測するカラム名を「y」にする必要があります.今回は修正後終値「Adj Close」を予測データとして扱うので,以下のように加工します.

data['ds'] = data.index
data = data.rename({'Adj Close':'y'},axis=1)

つまり,indexとなっている日付データを「ds」という名前で新たなカラムとして加え,また「Adj Close」列の名前を「y」に変えています.

これで,以下のような構造になります.

                 High        Low       Open  ...      Volume          y         ds
Date                                         ...                                  
2015-01-02  42.400002  41.799999  42.259998  ...   9921100.0  34.332104 2015-01-02
2015-01-05  42.970001  42.080002  42.689999  ...  26292600.0  34.332104 2015-01-05
2015-01-06  42.939999  42.240002  42.410000  ...  16897500.0  34.592823 2015-01-06
2015-01-07  43.110001  42.580002  42.799999  ...  13412300.0  35.024609 2015-01-07
2015-01-08  43.570000  43.099998  43.180000  ...  21743600.0  35.448261 2015-01-08
...               ...        ...        ...  ...         ...        ...        ...
2020-12-24  53.549999  53.020000  53.020000  ...   3265500.0  52.998867 2020-12-24
2020-12-28  54.439999  53.730000  53.849998  ...   9020500.0  53.712925 2020-12-28
2020-12-29  54.490002  54.020000  54.450001  ...   8320600.0  53.683174 2020-12-29
2020-12-30  54.630001  54.029999  54.049999  ...   8142700.0  53.990612 2020-12-30
2020-12-31  54.930000  54.270000  54.450001  ...   8495000.0  54.387314 2020-12-31

本来はindexのDateは消去した方がきれいですが,特に問題はなさそうなのでこのままにしました.

予測モデル作成

以下だけで予測モデルが生成されます.

model = Prophet()
model.fit(data)

モデルに手を加えたい場合は,Prophet()の()の中に記述します.今回はデフォルトで行うので空です.

予測期間の設定

今回は2021/1/1の後250日間の予測をしようと思うので,periods=250と設定します.また,freq = 'd’で予測単位を日単位に設定します.

future_data = model.make_future_dataframe(periods=250, freq = 'd')

予測結果

forecast_data = model.predict(future_data)

予測結果の描画

データ予測結果の可視化

fig = model.plot(forecast_data)

グラフの右の方(黒い点(実データ)がない部分)が予測部分です.

規則性の可視化

model.plot_components(forecast_data)

上から順にトレンド,週周期,年周期です.

結果の見方

予測結果のデータforecast_dataを見てみると,以下のようになっています.

pd.set_option('display.max_columns', 100) #DataFarameを省略しない
print(forecast_data)

forecast_dataの中身

             ds      trend  yhat_lower  yhat_upper  trend_lower  trend_upper  \
0    2015-01-02  34.288451   33.720347   37.876288    34.288451    34.288451   
1    2015-01-05  34.286100   33.321429   37.566870    34.286100    34.286100   
2    2015-01-06  34.285316   33.498347   37.700032    34.285316    34.285316   
3    2015-01-07  34.284533   33.446557   37.441409    34.284533    34.284533   
4    2015-01-08  34.283749   33.377396   37.406854    34.283749    34.283749   
...         ...        ...         ...         ...          ...          ...   
1756 2021-09-03  46.057433   41.651736   50.647310    41.836265    49.772723   
1757 2021-09-04  46.051406   41.621779   50.462587    41.802755    49.801785   
1758 2021-09-05  46.045378   41.441068   50.710119    41.769605    49.830673   
1759 2021-09-06  46.039351   41.530790   50.832855    41.736339    49.849572   
1760 2021-09-07  46.033324   41.603948   50.845615    41.684599    49.864172   

      additive_terms  additive_terms_lower  additive_terms_upper    weekly  \
0           1.479553              1.479553              1.479553  0.020683   
1           1.296744              1.296744              1.296744  0.023665   
2           1.300049              1.300049              1.300049  0.075154   
3           1.229096              1.229096              1.229096  0.044339   
4           1.188483              1.188483              1.188483  0.035377   
...              ...                   ...                   ...       ...   
1756        0.231965              0.231965              0.231965  0.020683   
1757        0.136596              0.136596              0.136596 -0.099610   
1758        0.157056              0.157056              0.157056 -0.099610   
1759        0.296244              0.296244              0.296244  0.023665   
1760        0.359144              0.359144              0.359144  0.075154   

      weekly_lower  weekly_upper    yearly  yearly_lower  yearly_upper  \
0         0.020683      0.020683  1.458870      1.458870      1.458870   
1         0.023665      0.023665  1.273079      1.273079      1.273079   
2         0.075154      0.075154  1.224895      1.224895      1.224895   
3         0.044339      0.044339  1.184756      1.184756      1.184756   
4         0.035377      0.035377  1.153106      1.153106      1.153106   
...            ...           ...       ...           ...           ...   
1756      0.020683      0.020683  0.211282      0.211282      0.211282   
1757     -0.099610     -0.099610  0.236205      0.236205      0.236205   
1758     -0.099610     -0.099610  0.256666      0.256666      0.256666   
1759      0.023665      0.023665  0.272579      0.272579      0.272579   
1760      0.075154      0.075154  0.283990      0.283990      0.283990   

      multiplicative_terms  multiplicative_terms_lower  \
0                      0.0                         0.0   
1                      0.0                         0.0   
2                      0.0                         0.0   
3                      0.0                         0.0   
4                      0.0                         0.0   
...                    ...                         ...   
1756                   0.0                         0.0   
1757                   0.0                         0.0   
1758                   0.0                         0.0   
1759                   0.0                         0.0   
1760                   0.0                         0.0   

      multiplicative_terms_upper       yhat  
0                            0.0  35.768004  
1                            0.0  35.582844  
2                            0.0  35.585365  
3                            0.0  35.513628  
4                            0.0  35.472232  
...                          ...        ...  
1756                         0.0  46.289398  
1757                         0.0  46.188001  
1758                         0.0  46.202435  
1759                         0.0  46.335595  
1760                         0.0  46.392468  

[1761 rows x 19 columns]

以下の19個の時系列データを持つことが分かりました.

forecast_data

ds … 日付データ

trend … データのトレンド性

yhat_lower … データ予測の下限

yhat_upper … データ予測の上限

trend_lower … トレンド予測の下限

trend_upper … トレンド予測の上限

additive_terms … 周期性部分のトータル(weeklyとyearlyの和)

additive_terms_lower … additive_termsの下限

additive_terms_upper … additive_termsの上限

weekly … 1週間の周期性

weekly_lower … 1週間の周期性の下限

weekly_upper … 1週間の周期性の上限

yearly … 1年間の周期性

yearly_lower … 1年間の周期性の下限

yearly_upper … 1年間の周期性の上限

multiplicative_terms … 増加する周期性

multiplicative_terms_lower … multiplicative_termsの下限

multiplicative_terms_upper … multiplicative_termsの上限

yhat … データ(株価)予測値.yの青い線.

また,グラフの黒い点は実データです.

このうち下のグラフにプロットされているのは,実データ(株価),y(予測値)とその範囲(おそらく分散)です.これをみると,コカ・コーラの株価はいったん下落した後に上昇するという予測が出ていますね.株のドメイン知識を加えてないのであまり信用できませんが…

また,以下のグラフではトレンド,週単位の周期性,年単位の周期性が出力されています.

コカ・コーラの場合,水曜日に向かって上がり,金曜日にはまた下がるようですね.また,4月頃に下がることが多いようです.

これを見ればアノマリーとかは気にしなくてもいいですね(米国株(S&P500)の季節アノマリーを調べてみましたを参照).やはりデータで見る素晴らしさはここにあります(予測値は置いといて).

コード再記

import pandas as pd
import matplotlib.pyplot as plt
from fbprophet import Prophet
import pandas_datareader.data as web
import datetime

start = datetime.date(2015,1,1)
end = datetime.date(2021,1,1)

data = web.DataReader('KO', 'yahoo', start, end) #get data

data['ds'] = data.index
data = data.rename({'Adj Close':'y'},axis=1)

model = Prophet()
model.fit(data)

future_data = model.make_future_dataframe(periods=250, freq = 'd')

forecast_data = model.predict(future_data)

fig = model.plot(forecast_data)
model.plot_components(forecast_data)

せっかく予測をしたので,実際の株価と予測値を比較してどれだけ当たっていたのかを見てみたいと思います.

以下の記事をご覧ください.

また,モデルの精度を確認する方法は,以下の記事をご覧ください.

※本記事では,基本は公式ドキュメント:https://facebook.github.io/prophet/docs/quick_start.htmlに沿っています.