from pyalgotrade.barfeed.csvfeed import GenericBarFeed
from pyalgotrade.bar import Frequency
from collections import deque
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from pyalgotrade import strategy
from pyalgotrade.stratanalyzer import returns
from pyalgotrade.stratanalyzer import sharpe
from pyalgotrade.stratanalyzer import drawdown
from pyalgotrade.stratanalyzer import trades
from pyalgotrade import plotter
import matplotlib as mpl
# Select matplotlib's 'classic' style sheet before any figure is created.
mpl.style.use('classic')
class MyDTStrategy(strategy.BacktestingStrategy):
    """Long-only backtesting strategy driven by a decision-tree classifier.

    On every bar the latest price is pushed into a rolling window.  Once the
    window is full, the signs of the successive price changes form one
    training sample: the first ``window_length`` up/down flags are the
    features and the final flag is the label.  When enough samples have been
    collected the classifier is refit and asked to predict the direction of
    the next change from the most recent ``window_length`` flags.  A
    predicted rise opens (or keeps) a long position; a predicted fall closes
    it.
    """

    def __init__(self, feed, instrument):
        """Set up the strategy state.

        :param feed: bar feed handed to ``BacktestingStrategy``.
        :param instrument: identifier of the instrument to trade; used to
            look up the bar in ``onBars`` and to place orders.
        """
        # The second argument is forwarded to BacktestingStrategy unchanged.
        super(MyDTStrategy, self).__init__(feed, 100000)
        self.__position = None
        self.__instrument = instrument
        self.__window_length = 10
        # window_length + 2 prices yield window_length + 1 price changes:
        # window_length features plus one label.
        self.__recent_prices = deque(maxlen=self.__window_length + 2)
        self.__classifier = DecisionTreeClassifier()
        # Independent (input) variables; only the newest 500 samples are kept.
        self.__X = deque(maxlen=500)
        # Dependent (output) variable; only the newest 500 labels are kept.
        self.__Y = deque(maxlen=500)
        # Most recent prediction: 1 for a rise, 0 for a fall (0 until the
        # classifier has been fit for the first time).
        self.__prediction = 0
        self.__predictionList = []

    def onEnterOk(self, position):
        """Log the fill price of an entry order."""
        execInfo = position.getEntryOrder().getExecutionInfo()
        self.info("BUY at $%.2f" % (execInfo.getPrice()))

    def onEnterCanceled(self, position):
        """Forget the position when its entry order is canceled."""
        self.__position = None

    def onExitOk(self, position):
        """Log the fill price of an exit order and clear the position."""
        execInfo = position.getExitOrder().getExecutionInfo()
        self.info("SELL at $%.2f" % (execInfo.getPrice()))
        self.__position = None

    def onExitCanceled(self, position):
        """Re-submit the exit order when it is canceled."""
        # If the exit was canceled, re-submit it.
        self.__position.exitMarket()

    def getPredictionList(self):
        """Return the list of predictions made so far, oldest first.

        Each element is a plain int: 1 for a predicted rise, 0 for a
        predicted fall.
        """
        return self.__predictionList

    def handle_data(self, bars):
        """Update the model with the newest bar and trade on its prediction.

        :param bars: the current bars; the bar for the configured instrument
            is looked up by its identifier.
        """
        bar = bars[self.__instrument]
        self.__recent_prices.append(bar.getPrice())

        # Wait until the price window is full before building samples.
        if len(self.__recent_prices) < self.__window_length + 2:
            return

        # True where the price rose from one bar to the next.
        changes = np.diff(self.__recent_prices) > 0
        self.__X.append(changes[:-1])
        self.__Y.append(changes[-1])

        # Do not fit or trade until at least 66 samples are available.
        if len(self.__Y) < 66:
            return

        # Hand the classifier real arrays rather than deques.
        self.__classifier.fit(np.array(list(self.__X)),
                              np.array(list(self.__Y)))
        # Predict whether the next period rises or falls.  predict() returns
        # a one-element array; keep a plain int so comparisons and the
        # recorded series stay scalar.
        latest = changes[1:].reshape(1, -1)
        self.__prediction = int(self.__classifier.predict(latest)[0])
        print('self.__prediction')
        print(self.__prediction)
        self.__predictionList.append(self.__prediction)

        expects_rise = self.__prediction >= 0.5
        if self.__position is None:
            if expects_rise:
                self.__position = self.enterLong(self.__instrument, 10, True)
        elif not expects_rise and not self.__position.exitActive():
            # Predicted fall while holding: close, unless an exit order is
            # already pending.
            self.__position.exitMarket()

    def onBars(self, bars):
        """Bar callback: delegate to :meth:`handle_data`."""
        self.handle_data(bars)
# Build a daily-frequency feed from the local AAPL CSV file and hand it to
# the strategy.
feed = GenericBarFeed(Frequency.DAY, None, None)
feed.addBarsFromCSV("AAPL", "./AAPLnew.csv")
myDTStrategy = MyDTStrategy(feed, "AAPL")

# Create the analyzers first, then attach them in a fixed order.
returnsAnalyzer = returns.Returns()
sharpeRatioAnalyzer = sharpe.SharpeRatio()
drawdownAnalyzer = drawdown.DrawDown()
tradesAnalyzer = trades.Trades()
for analyzer in (returnsAnalyzer, sharpeRatioAnalyzer,
                 drawdownAnalyzer, tradesAnalyzer):
    myDTStrategy.attachAnalyzer(analyzer)

# The plotter gets two extra subplots: simple returns and the predictions
# recorded by the strategy.
plt = plotter.StrategyPlotter(myDTStrategy)
returnsSubplot = plt.getOrCreateSubplot("returns")
returnsSubplot.addDataSeries("Simple returns", returnsAnalyzer.getReturns())
predictSubplot = plt.getOrCreateSubplot("predict")
predictSubplot.addDataSeries("predict", myDTStrategy.getPredictionList())

# Run the backtest, then report the results.
myDTStrategy.run()

finalEquity = myDTStrategy.getBroker().getEquity()
print("Final portfolio value1: %.2f" % finalEquity)
finalResult = myDTStrategy.getResult()
print("Final portfolio value2: %.2f" % finalResult)
cumulativePct = returnsAnalyzer.getCumulativeReturns()[-1] * 100
print("Cumulative returns: %.2f %%" % cumulativePct)
sharpeRatio = sharpeRatioAnalyzer.getSharpeRatio(0.03)
print("Sharpe ratio: %.2f" % sharpeRatio)
maxDrawdownPct = drawdownAnalyzer.getMaxDrawDown() * 100
print("Max. drawdown: %.2f %%" % maxDrawdownPct)
longestDrawdown = drawdownAnalyzer.getLongestDrawDownDuration()
print("Longest drawdown duration: %s" % longestDrawdown)

plt.plot()