在学习一篇关于马丁策略回测的博客时,我遇到了以下问题。
原博客地址:(链接在帖子中缺失)
以下是我的代码,报错信息附在最后。
由于使用的是别人的库(tushare),不需要自己爬取数据,但我还是没搞懂问题出在哪里。
```python
import broker
import pandas as pd
import numpy as np
import tushare as ts
import matplotlib.pyplot as plt
from tushare import data
def __init__(self, startcash, start, end):
    """Initialise the back-test account state.

    NOTE(review): written with ``self``, this looks like a method of a
    back-test class whose ``class`` header is not part of this paste -- confirm.

    Args:
        startcash: starting cash balance.
        start: back-test start date, stored as ``self.starttime``.
        end: back-test end date, stored as ``self.endtime``.
    """
    # Account balances.
    self.cash = startcash                  # available cash
    self.hold = 0                          # position value, tracked at cost
    self.holdper = self.hold / startcash   # position share of the account
    # Per-bar records appended by buy / sell / holdstock.
    self.log = []                          # action taken on each bar
    # Position details.
    self.cost = 0                          # average cost per share
    self.stock_num = 0                     # shares currently held
    # Back-test window.
    self.starttime, self.endtime = start, end
    self.quantlog = []                     # traded lots (shares // 100) per bar
    self.earn = []                         # cash + marked-to-market position per bar
    self.num_log = []                      # shares held after each bar
    # Per-bar price drops collected by martin; reset to [0] by buy.
    self.droplog = [0]
def buy(self, currentprice, count):
    """Buy ``count`` shares at ``currentprice`` and record the bar.

    Updates cash, the position value (tracked at cost), the position share,
    the share count and the average cost.  Appends one entry to each per-bar
    log, prints a summary and resets ``self.droplog``.

    Args:
        currentprice: execution price per share.
        count: number of shares to buy (callers pass whole lots of 100).
    """
    amount = currentprice * count
    self.cash -= amount
    self.log.append('buy')
    self.hold += amount
    total = self.cash + self.hold
    # Guard the ratio: an account with nothing left would divide by zero.
    self.holdper = self.hold / total if total else 0
    self.stock_num += count
    # A zero-share order on an empty position (e.g. cash too small for one
    # lot) used to compute 0 / 0 here; keep the previous cost in that case.
    if self.stock_num:
        self.cost = self.hold / self.stock_num
    self.quantlog.append(count // 100)
    print('买入价:%.2f,手数:%d,现在成本价:%.2f,现在持仓:%.2f,现在筹码:%d' % (
        currentprice, count // 100, self.cost, self.holdper, self.stock_num // 100))
    self.earn.append(self.cash + currentprice * self.stock_num)
    self.num_log.append(self.stock_num)
    self.droplog = [0]
def sell(self, currentprice, count):
    """Sell ``count`` shares at ``currentprice`` and record the bar.

    The remaining position is revalued at the unchanged average cost; one
    entry is appended to each per-bar log and a summary is printed.

    Args:
        currentprice: execution price per share.
        count: number of shares to sell.
    """
    proceeds = currentprice * count
    self.cash = self.cash + proceeds
    self.stock_num = self.stock_num - count
    self.log.append('sell')
    # A sale leaves the average cost unchanged.
    self.hold = self.stock_num * self.cost
    self.holdper = self.hold / (self.cash + self.hold)
    lots_sold = count // 100
    print('卖出价:%.2f,手数:%d,现在成本价:%.2f,现在持仓:%.2f,现在筹码:%d' % (
        currentprice, lots_sold, self.cost, self.holdper, self.stock_num // 100))
    self.quantlog.append(lots_sold)
    self.earn.append(self.cash + currentprice * self.stock_num)
    self.num_log.append(self.stock_num)
def holdstock(self, currentprice):
    """Record a bar on which no trade is made.

    Logs ``'hold'``, zero traded lots, the total assets with the position
    marked at ``currentprice`` and the unchanged share count.
    """
    self.log.append('hold')
    self.quantlog.append(0)  # nothing traded
    equity = self.cash + currentprice * self.stock_num
    self.earn.append(equity)
    self.num_log.append(self.stock_num)
def get_stock(self, code):
    """Download daily bars for ``code`` with the legacy ``ts.get_k_data`` API.

    Uses ``autype='qfq'`` over ``self.starttime`` .. ``self.endtime``.

    Returns:
        A DataFrame indexed by the parsed ``date`` column, restricted to the
        columns ``open, high, low, close, volume``.

    NOTE(review): ``get_stock_pro`` passes the same ``starttime``/``endtime``
    to ``pro.daily``; the two APIs may expect different date formats -- confirm.
    """
    raw = ts.get_k_data(code, autype='qfq', start=self.starttime, end=self.endtime)
    raw.index = pd.to_datetime(raw.date)
    wanted = ['open', 'high', 'low', 'close', 'volume']
    return raw[wanted]
# Tushare Pro client used by get_stock_pro.
# SECURITY(review): the literal token below is published with the source, so
# anyone reading this file can use it.  Supply it through the TUSHARE_TOKEN
# environment variable instead, and revoke and remove the literal; it is kept
# here only as a fallback so existing runs keep working.
import os

token = os.environ.get(
    'TUSHARE_TOKEN', '495dc9cb17fb0c7e93f5402255a4aacee14a18eb7a312913e0850cc6')
ts.set_token(token)
pro = ts.pro_api()
def get_stock_pro(self, code, api=None):
    """Download daily bars for ``code`` from the Tushare Pro ``daily`` endpoint.

    Args:
        code: six-digit stock code; ``'.SH'`` is appended unless the code
            already carries an exchange suffix (e.g. ``'000001.SZ'``).
        api: Tushare Pro client to query; defaults to the module-level ``pro``.

    Returns:
        The DataFrame returned by ``api.daily`` over ``self.starttime`` ..
        ``self.endtime``, sorted oldest-first by ``trade_date`` and renumbered
        0..n-1.
    """
    if api is None:
        api = pro
    # Appending '.SH' to a code that already has a suffix gives an invalid ts_code.
    ts_code = code if '.' in code else code + '.SH'
    df = api.daily(ts_code=ts_code, start_date=self.starttime, end_date=self.endtime)
    # Tushare Pro's daily endpoint returns the newest bar first, while a
    # back-test walks forward in time.  Sort ascending, and renumber so that
    # integer labels equal positions for callers that index with data[i].
    if 'trade_date' in df.columns:
        df = df.sort_values('trade_date').reset_index(drop=True)
    return df
def startback(self, data, everyChange, accDropday, backtesting=None):
    """
    Run the back-test: walk the price series bar by bar.

    Args:
        data: 1-D prices (list, array or pandas Series), oldest first; read
            by position.
        everyChange: per-bar changes, forwarded to ``accumulateVar`` and ``martin``.
        accDropday: look-back window (in bars) passed to ``accumulateVar``.
        backtesting: object providing ``accumulateVar`` / ``computeCon``;
            ``self`` is used when omitted.

    Every bar, including the first, is meant to record exactly one action, so
    the per-bar logs stay aligned with ``data``.  Any position still open
    on the last two bars is sold at that bar's price.
    """
    # NOTE(review): accumulateVar/computeCon are not defined in this paste;
    # falling back to ``self`` assumes they are methods of the same object.
    # With the old default ``None`` every call raised AttributeError.
    helper = self if backtesting is None else backtesting
    # Read prices by position: on a pandas Series, data[i] is a label lookup
    # and data[-1] raises KeyError on a RangeIndex.
    prices = list(data)
    n_bars = len(prices)
    for i, price in enumerate(prices):
        if i == 0:
            # No previous bar to compare with.  Record a plain hold so the
            # per-bar logs have one entry per bar; skipping it left them one
            # short of ``data``.
            self.holdstock(price)
            continue
        if i < accDropday:
            # Fewer than accDropday bars of history: use the i bars available.
            drop = helper.accumulateVar(everyChange, i, i)
        elif i < n_bars - 2:
            drop = helper.accumulateVar(everyChange, i, accDropday)
        else:
            # Last two bars: close out at the *current* price.  The old code
            # used data[-1], i.e. the future final price, on bar n-2.
            if self.stock_num > 0:
                self.sell(price, self.stock_num)
            else:
                self.holdstock(price)
            continue
        self.martin(price, prices[i - 1], drop, everyChange, i, backtesting=helper)
def enter(self, currentprice, ex_price, accuDrop):
    """Open a position when the accumulated change is below -0.01.

    Buys 24 % of (cash + position at cost), rounded down to whole lots of
    100 shares.  The bar is recorded as a hold otherwise, or when that
    amount is less than one lot.

    Args:
        currentprice: this bar's price.
        ex_price: previous bar's price (currently unused).
        accuDrop: accumulated change computed by the caller.
    """
    if accuDrop < -0.01:
        count = (self.cash + self.hold) * 0.24 // currentprice // 100 * 100
        # Too little money for one lot: a 0-share "buy" would log a spurious
        # trade, and on an empty position divide 0 by 0 in buy().
        if count > 0:
            print('再次入场')
            self.buy(currentprice, count)
            return
    self.holdstock(currentprice)
def martin(self, currentprice, ex_price, accuDrop, everyChange, i, backtesting=None):
    """Martingale step for one bar: add to, cut, trim or hold the position.

    Args:
        currentprice: this bar's price.
        ex_price: previous bar's price.
        accuDrop: accumulated change, passed on to ``enter``.
        everyChange: per-bar changes; entries ``i - 2`` and ``i - 1`` are read.
        i: index of the current bar.
        backtesting: object providing ``computeCon``; ``self`` when omitted.

    Exactly one of buy / sell / holdstock / enter is called for the bar.
    """
    # NOTE(review): computeCon is not defined in this paste; falling back to
    # ``self`` assumes it is a method of the same object.  With the old
    # default ``None`` the add-on branch below raised AttributeError.
    helper = self if backtesting is None else backtesting

    # Collect this bar's relative drop; a non-positive running sum resets it.
    diff = (ex_price - currentprice) / ex_price
    self.droplog.append(diff)
    if sum(self.droplog) <= 0:
        self.droplog = [0]

    # At most one lot held: look for a fresh entry instead.
    if self.stock_num // 100 <= 1:
        self.enter(currentprice, ex_price, accuDrop)
        return

    if sum(self.droplog) >= 0.04:
        # Summed drops since the last buy/reset reached 0.04: add or cut.
        if self.holdper * 2 < 0.24:
            count = (self.cash + self.hold) * (0.25 - self.holdper) // currentprice // 100 * 100
            self.buy(currentprice, count)
        elif (self.holdper * 2 < 1
              and (self.hold / currentprice) // 100 * 100 > 0
              and helper.computeCon(self.log) < 5):
            self.buy(currentprice, (self.hold / currentprice) // 100 * 100)
        else:
            self.sell(currentprice, self.stock_num // 100 * 100)
            print('及时止损')
    elif everyChange[i - 2] < 0 and everyChange[i - 1] < 0 and self.cost < currentprice:
        # Two negative changes while above average cost: sell half.  At least
        # 200 shares are held here, so half always rounds to >= one lot (the
        # former 100-share and hold fallbacks could never run).
        self.sell(currentprice, self.stock_num * (1 / 2) // 100 * 100)
    else:
        self.holdstock(currentprice)
# ---- Plot the back-test results ---------------------------------------------
# NOTE(review): `broker` is expected to be the object `startback` ran on, and
# `data` the 1-D close-price Series it was given.  The top-of-file lines
# `import broker` and `from tushare import data` bind both names to *modules*
# (they look like IDE auto-imports).  So `broker.log` / `data.copy()` below
# fail until the back-test object is created and run and the price Series is
# loaded under these names.  TODO: confirm, and drop those two imports.

# Configure a CJK-capable font before any figure or text is created so the
# Chinese title and labels render.  SimHei lacks the U+2212 minus glyph, so
# draw minus signs as ASCII '-'.
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

close = data.copy()
buylog = pd.Series(broker.log)

# One log entry is recorded per processed bar.  If the back-test skipped its
# first bar(s), the logs are shorter than `close` and belong to the *last*
# len(log) bars, so align them to those dates.  Assigning `close.index`
# directly raised a length-mismatch ValueError, and indexing from the first
# date misplaced every marker.
offset = len(close) - len(buylog)
if offset < 0:
    raise ValueError('broker.log has more entries than there are price bars')
dates = close.index[offset:]

buy = np.zeros(len(buylog))
sell = np.zeros(len(buylog))
for k, action in enumerate(buylog):
    # .iloc: plain close[i] is a label lookup on a pandas Series.
    if action == 'buy':
        buy[k] = close.iloc[offset + k]
    elif action == 'sell':
        sell[k] = close.iloc[offset + k]
buy = pd.Series(buy, index=dates)
sell = pd.Series(sell, index=dates)
quantlog = pd.Series(broker.quantlog, index=dates)
earn = pd.Series(broker.earn, index=dates)
# Keep only the bars on which a trade happened.
buy = buy.loc[buy > 0]
sell = sell.loc[sell > 0]

# Overview figure: price with trade markers.
plt.plot(close)
plt.scatter(buy.index, buy, label='buy')
plt.scatter(sell.index, sell, label='sell')
plt.title('马丁策略')
plt.legend()

# Detail figure: price with markers, traded lots, total assets.
fig, (ax1, ax2, ax3) = plt.subplots(3, figsize=(15, 8))
ax1.plot(close)
ax1.scatter(buy.index, buy, label='buy', color='red')
ax1.scatter(sell.index, sell, label='sell', color='green')
ax1.set_ylabel('Price')
ax1.grid(True)
ax1.legend()
ax1.xaxis_date()
ax2.bar(quantlog.index, quantlog, width=5)
ax2.set_ylabel('Volume')
ax2.xaxis_date()
ax2.grid(True)
ax3.xaxis_date()
ax3.plot(earn)
ax3.set_ylabel('总资产包括浮盈')
plt.show()