import numpy as np
from numpy.lib.function_base import vectorize
import pandas as pd
import matplotlib.pyplot as plt
from pandas.core.indexes import datetimes
from pandas.core.indexes.base import Index
import tushare as ts
import datetime
from dateutil.parser import parse
from numba import cuda
# Tushare API token: configure it before running so the ts.pro_api() calls in
# this script can authenticate, e.g. ts.set_token(os.environ['TUSHARE_TOKEN']).
# Never commit the token itself to source control.
CASH = 1000000  # starting cash; must be a plain number (a trailing comma made it a tuple)
START_DATE = 20180101  # backtest start, YYYYMMDD
END_DATE = 20211224  # backtest end, YYYYMMDD (inclusive)
# One-off data download (uncomment to refresh the local CSV files):
# pro = ts.pro_api()
# df = pro.daily(ts_code='603518.SH', start_date='20110101', end_date='20211224')
# df.to_csv("603518.SH.csv")  # save the stock's daily bars
# trade_cal = pro.trade_cal()
# trade_cal.to_csv("trade_cal.csv")  # save the exchange trading calendar
# Trading calendar; its 'is_open' and 'cal_date' columns drive the backtest dates.
trade_cal = pd.read_csv("trade_cal.csv")
class Context:
    """Backtest state: cash, holdings, benchmark and the trading days to replay.

    Args:
        cash: starting cash.
        start_date: first calendar date as a YYYYMMDD value, compared
            against trade_cal['cal_date'].
        end_date: last calendar date (inclusive), same format.
    """

    def __init__(self, cash, start_date, end_date):
        self.cash = cash
        self.start_date = start_date
        self.end_date = end_date
        # security code -> number of shares held
        self.positions = {}
        # benchmark security code, assigned by set_benchmark()
        self.benchmark = None
        cal_dates = trade_cal['cal_date']
        trading_day = trade_cal['is_open'] == 1
        in_window = (cal_dates >= start_date) & (cal_dates <= end_date)
        # open trading days inside [start_date, end_date], in trade_cal row order
        self.date_range = trade_cal[trading_day & in_window]['cal_date'].values
        # date currently being simulated; run() sets it for each day
        self.dt = None
class G:
    """Strategy globals shared by initialize() and handle_data()."""

    def __init__(self):
        # p1: short moving-average window, p2: long window (set by initialize)
        self.p1 = self.p2 = None
        # code of the security the strategy trades
        self.security = None
# Module-level singletons used by every helper below: strategy globals (g)
# and backtest state (context).
g = G()
context = Context(cash=CASH, start_date=START_DATE, end_date=END_DATE)
print(context.date_range)
def set_benchmark(secturity):
    """Store ``secturity`` as ``context.benchmark``.

    run() loads that security's history to draw the benchmark return curve.
    """
    setattr(context, 'benchmark', secturity)
def attribute_daterange_history(secturity, start_date, end_date, fields=('open', 'close', 'high', 'low', 'vol')):
    """Return the daily bars of ``secturity`` between two dates (inclusive).

    Reads ``<secturity>.csv`` from the working directory and falls back to
    the tushare ``pro.daily`` API when that file does not exist.

    Args:
        secturity: security code such as ``'002269.SZ'``; also the CSV file stem.
        start_date: first date as a YYYYMMDD value (int or str).
        end_date: last date as a YYYYMMDD value (int or str).
        fields: accepted for API compatibility. NOTE(review): it is not
            applied, so all columns are returned.

    Returns:
        A DataFrame indexed by ``trade_date``. The CSV slice
        ``loc[end_date:start_date]`` assumes rows are stored newest-first
        (tushare's order). TODO confirm for hand-made CSVs.
    """
    try:
        # Passing the path lets pandas open *and close* the file; the old
        # bare open() leaked one handle per call.
        df = pd.read_csv(secturity + '.csv', index_col='trade_date')
    except FileNotFoundError:
        pro = ts.pro_api()
        # The API takes YYYYMMDD strings, but callers here may pass ints.
        df = pro.daily(ts_code=secturity, start_date=str(start_date), end_date=str(end_date))
        # Index by an integer trade_date, like the CSV branch, so both paths look alike.
        df['trade_date'] = df['trade_date'].astype(int)
        return df.set_index('trade_date')
    return df.loc[end_date:start_date, :]
def attribute_history(secturity, count, n_day, fields=('open', 'close', 'high', 'low', 'vol')):
    """Return ``count`` open trading days of bars ending ``n_day`` days before ``context.dt``.

    Args:
        secturity: security code, forwarded to attribute_daterange_history.
        count: number of open trading days (per trade_cal) in the window.
        n_day: offset in *calendar* days back from ``context.dt``. Because
            the offset counts calendar days, n_day=2 or 3 can resolve to the
            same window as n_day=1 when the skipped days are not trading days
            (for example, a weekend).
        fields: forwarded to attribute_daterange_history.

    Returns:
        Whatever attribute_daterange_history returns for the window.

    Raises:
        IndexError: if trade_cal has no open day on or before the end date.
    """
    # context.dt is a YYYYMMDD value such as 20201201; drop any time component.
    current_day = parse(str(context.dt)).date()
    end_date = int((current_day - datetime.timedelta(days=n_day)).strftime("%Y%m%d"))
    open_days = trade_cal.loc[(trade_cal['is_open'] == 1) & (trade_cal['cal_date'] <= end_date), 'cal_date']
    # Sort so the window holds the `count` most recent open days whatever the
    # calendar file's row order; iloc makes the slice explicitly positional.
    start_date = open_days.sort_values().iloc[-count:].iloc[0]
    return attribute_daterange_history(secturity, start_date, end_date, fields=fields)
def get_today_data(secturity):
    """Return the daily bar of ``secturity`` for ``context.dt``.

    Reads ``<secturity>.csv`` from the working directory on every call.

    Returns:
        The row for ``context.dt`` as a Series, or an empty Series when the
        CSV has no row for that date (_order treats that as a trading halt).

    Raises:
        FileNotFoundError: if ``<secturity>.csv`` does not exist.
    """
    # Passing the path lets pandas close the file; the old bare open() leaked it.
    daily = pd.read_csv(secturity + '.csv', index_col='trade_date')
    try:
        return daily.loc[context.dt, :]
    except KeyError:
        # No bar for today. An explicit dtype avoids pandas' empty-Series warning.
        return pd.Series(dtype=object)
def _order(today_data, security, amount):
    """Fill an order for ``amount`` shares of ``security`` at today's open.

    Positive ``amount`` buys and negative sells. In order, the size is:
    capped to what the cash on hand can pay for, rounded toward zero to a
    multiple of 100 shares, and limited to the shares actually held when
    selling. Updates ``context.positions`` and ``context.cash`` and drops
    positions that reach zero. No fees are modelled.
    """
    # An empty bar means the security did not trade today.
    if not len(today_data):
        print("今日停牌")
        return
    price = today_data['open']
    if context.cash - amount * price < 0:
        amount = int(context.cash / price)
        print("现金不足,已经调整为%d" % amount)
    # Only whole 100-share lots may be traded.
    if amount % 100:
        amount = int(amount / 100) * 100
    held = context.positions.get(security, 0)
    if held < -amount:
        amount = -held
        print("卖出股票不能超过当前持仓数量,已经调整为%d" % amount)
    remaining = held + amount
    context.positions[security] = remaining
    context.cash -= amount * price
    if remaining == 0:
        del context.positions[security]
def order(security, amount):
    """Buy (amount > 0) or sell (amount < 0) ``amount`` shares of ``security`` at today's open."""
    return _order(get_today_data(security), security, amount)
def order_target(security, amount):
    """Trade ``security`` so the position moves toward ``amount`` shares.

    A negative target is clamped to 0 after printing a notice. The delta
    against the current holding is handed to _order, which may still trim it.
    """
    if amount < 0:
        amount = 0
        print("数量不能为负,已经调整为0")
    bar = get_today_data(security)
    delta = amount - context.positions.get(security, 0)
    return _order(bar, security, delta)
def order_value(security, value):
    """Trade roughly ``value`` worth of ``security`` at today's open.

    A positive value buys and a negative value sells. The share count is
    ``int(value / open)`` (truncated toward zero) before _order applies its
    cash, lot-size and holding limits.
    """
    today_data = get_today_data(security)
    if len(today_data) == 0:
        # Suspended today: there is no open price to size the order with.
        # Delegate a zero-share order so _order reports the halt instead of
        # raising KeyError on today_data['open'].
        _order(today_data, security, 0)
        return
    amount = int(value / today_data['open'])
    _order(today_data, security, amount)
def order_target_value(security, value):
    """Trade ``security`` so its holding is worth about ``value`` at today's open.

    A negative target is clamped to 0 after printing a notice. The value gap
    is converted to ``int(gap / open)`` shares and handed to _order, which
    applies its cash, lot-size and holding limits.
    """
    if value < 0:
        value = 0
        print("价值不能为负,已经调整为0")
    today_data = get_today_data(security)
    if len(today_data) == 0:
        # Suspended today: let _order report the halt rather than raising
        # KeyError on today_data['open'].
        _order(today_data, security, 0)
        return
    price = today_data['open']
    delta_value = value - context.positions.get(security, 0) * price
    # Same sizing as order_value, reusing today's bar instead of re-reading the CSV.
    _order(today_data, security, int(delta_value / price))
def _portfolio_value(last_price):
    """Return cash plus holdings priced at today's open, or at the last known open when halted."""
    value = context.cash
    for stock, shares in context.positions.items():
        today_data = get_today_data(stock)
        if len(today_data) == 0:
            price = last_price[stock]
        else:
            price = today_data['open']
            last_price[stock] = price
        value += price * shares
    return value


def run(p1, p2, cash=1000000):
    """Backtest the strategy over ``context.date_range``, print the return and plot it.

    Args:
        p1: short moving-average window, forwarded to initialize() (g.p1).
        p2: long moving-average window, forwarded to initialize() (g.p2).
        cash: starting cash (defaults to the amount that used to be hard-coded).

    Prints the parameters and the final total return. Then it shows two charts:
    the traded security's open price with both moving averages, and the
    strategy return versus the benchmark return.
    """
    # Cash must be a plain number: _order and order_value do arithmetic on it
    # (the previous one-element list made every trade raise TypeError).
    context.cash = cash
    context.positions = {}
    init_value = context.cash
    initialize(p1, p2, context)
    ma_short_col = 'ma%d_short' % g.p1
    ma_long_col = 'ma%d_long' % g.p2
    date_list = [parse(str(date)).strftime("%Y-%m-%d") for date in context.date_range]
    # Pre-create float columns so days with missing data stay NaN and plot cleanly.
    plt_df = pd.DataFrame(
        index=date_list,
        columns=['value', 'ratio', 'benchmark_ratio', 'open', ma_short_col, ma_long_col],
        dtype=float,
    )
    last_price = {}
    bm_df = attribute_daterange_history(context.benchmark, context.start_date, context.end_date)
    # iloc[-1] is the earliest bar, assuming newest-first rows (tushare's order).
    bm_init = bm_df['open'].iloc[-1]
    value = init_value  # keeps the final report defined when date_range is empty
    for raw_dt in context.date_range:
        context.dt = raw_dt
        day = parse(str(raw_dt)).strftime("%Y-%m-%d")
        handle_data(context)
        value = _portfolio_value(last_price)
        plt_df.loc[day, 'value'] = value
        plt_df.loc[day, 'ratio'] = (value - init_value) / init_value
        # The benchmark curve must follow the benchmark, not the traded security.
        bm_today = get_today_data(context.benchmark)
        if len(bm_today):
            plt_df.loc[day, 'benchmark_ratio'] = (bm_today['open'] - bm_init) / bm_init
        sec_today = get_today_data(g.security)
        if len(sec_today):
            plt_df.loc[day, 'open'] = sec_today['open']
        # Same moving averages that handle_data trades on (windows g.p1 / g.p2).
        hist = attribute_history(g.security, g.p2, 1)
        plt_df.loc[day, ma_short_col] = hist['open'].iloc[:g.p1].mean()
        plt_df.loc[day, ma_long_col] = hist['open'].mean()
    print("参数:%f,%f,%f" % (p1, p2, ((value - init_value) / init_value)))
    _, axes = plt.subplots(nrows=2, ncols=1, figsize=(14, 8))
    plt_df[['open', ma_short_col, ma_long_col]].plot(ax=axes[0])
    plt_df[['ratio', 'benchmark_ratio']].plot(ax=axes[1])
    plt.show()
def initialize(p1, p2, context):
    """Configure the strategy before a backtest.

    Sets the short (p1) and long (p2) moving-average windows, the traded
    security and the benchmark. ``context`` is accepted but not used here.
    """
    g.p1, g.p2 = p1, p2
    g.security = '002269.SZ'
    set_benchmark('002269.SZ')
def handle_data(context):
    """Daily strategy callback: stay fully invested while the short MA is above the long MA.

    Both moving averages come from the ``g.p2`` open prices ending one
    calendar day before ``context.dt`` (attribute_history with n_day=1). The
    decision therefore uses only data from before today's open, which is the
    price at which _order fills. The short MA averages the first ``g.p1`` of
    those rows, assuming the daily CSV is stored newest-first (tushare's
    order). TODO confirm.

    The 2- and 3-days-back MA slopes that were computed here previously never
    affected the decision and cost two extra CSV reads per day. If a slope
    filter is added, rebuild them with attribute_history(..., n_day=2/3).
    """
    opens = attribute_history(g.security, g.p2, 1)['open']
    ma_short = opens.iloc[:g.p1].mean()
    ma_long = opens.mean()
    if ma_short > ma_long:
        # Put all remaining cash to work; _order trims to whole 100-share lots.
        order_value(g.security, context.cash)
    else:
        order_target(g.security, 0)
# Parameter sweep over the short (p1) and long (p2) moving-average windows:
# for i in range(6, 10):
#     for j in range(10, 30):
#         run(i, j)
if __name__ == "__main__":
    # Only backtest when executed as a script, not when imported.
    run(5, 20)
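# NOTE: the lines "- 1" through "- 6" and "前往页" ("go to page") that followed
# here were page-navigation residue from the web page this script was copied
# from. They are not part of the program; the bare name 前往页 would raise
# NameError when the script runs.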