BackTrader:性能优化之多股策略速度优化

前言:

谈及BackTrader的回测速度优化,最常见的说法是从底层使用numpy等计算库来替换,但这种优化无疑非常新手不友好。因此本文着眼于如何最简单的优化多股情况下回测慢这一情况。考虑测试效率,本文使用100支股票回测。经过测试,优化后策略执行速度提升59%(93->38.4)。

策略描述:

前一天非一字涨停的股票进入候选池。

第二天10~11点若涨幅大于4%买入。

持仓股在14:30时若未涨停卖出。

V1策略及运行时间:

v1代码设计思路:

使用5分数据进行交易,而使用日线数据进行候选池判断及涨幅判断。添加定时器只在每天15:00点筛选候选池,然后在next中根据时间与涨幅判断是否需要买入或卖出。策略部分代码如下:

class MyStrategy(bt.Strategy):
    params = dict(
        when=bt.timer.SESSION_START,
        end=bt.timer.SESSION_END,
        timer=True,
        cheat=False,
        offset=timedelta(),
        repeat=timedelta(),
        weekdays=[],
        period=3,
    )

    def log(self, txt, dt=None):
        ''' Logging function fot this strategy'''
        dt = dt or self.datas[0].datetime.datetime(0)
        print('%s, %s' % (dt.isoformat(), txt))

    def __init__(self):
        self.order = None

        self.add_timer(
                when=time(15, 0),
                offset=self.p.offset,
                repeat=self.p.repeat,
                weekdays=self.p.weekdays,
        )

        s_m = []
        for i, d in enumerate(self.datas):
            if not d._name.endswith('_day'):
                s_m.append([d._name, i, None])
        self.st_df = pd.DataFrame(data=s_m, columns=['code', 'min', 'day'])
        for i, d in enumerate(self.datas):
            if d._name.endswith('_day'):
                n = d._name.split('_')[0]
                self.st_df.loc[self.st_df.code == n, 'day'] = i
        #         self.stock_names.append(d._name)
        # self.min_stocks = self.datas[:int(len(self.datas)/2)]
        # self.day_stocks = self.datas[-int(len(self.datas)/2):]

        self.zt_list = []
        self.last_hold = []
        self.new_hold = []
        self.zt_num = 0

    def notify_order(self, order):
        if order.status in [order.Submitted, order.Accepted]:
            # Buy/Sell order submitted/accepted to/by broker - Nothing to do
            return

        # Check if an order has been completed
        # Attention: broker could reject order if not enough cash
        idx = self.st_df.loc[self.st_df.code==order.data._name].index.values[0]
        if order.status in [order.Completed]:
            if order.isbuy():
                self.log(
                    'BUY EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                    (order.executed.price,
                     order.executed.value,
                     order.executed.comm))

                self.new_hold.append(idx)
                self.zt_list.remove(idx)
            else:  # Sell
                self.log('SELL EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                         (order.executed.price,
                          order.executed.value,
                          order.executed.comm))
                self.last_hold.remove(idx)
        elif order.status in [order.Canceled, order.Expired, order.Margin, order.Rejected]:
            self.log('Order Canceled/Expired/Margin/Rejected')
            self.new_hold.remove(idx)

        # Write down: no pending order
        self.order = None

    def next(self):
        t = self.datetime.time(0)
        #1.每天早上10:05至11:05买入
        len_for_new = 10 - len(self.last_hold) - len(self.new_hold)
        if len(self.zt_list) > 0 and len_for_new > 0:
            if t >= time(9,40) and t <= time(14,30):
                for i in self.zt_list:
                    if i in self.last_hold:
                        continue
                    d = self.datas[self.st_df.loc[i, 'min']]
                    if len_for_new <= 0:
                        break
                    last_close = self.datas[self.st_df.loc[i, 'day']].close[0]
                    if 1.045 * last_close < d.close[0] < 1.09 * last_close:
                        len_for_new -= 1
                        targetvalue = 0.1 * self.broker.getvalue()
                        size = targetvalue/(last_close*1.09)//100*100
                        self.buy(data=d, size=size, price=last_close*1.09, exectype=bt.Order.Limit,
                                 valid=self.datetime.datetime(0)+timedelta(minutes=5))

        #2.每天14:35卖出
        if len(self.last_hold) > 0:
            if t == time(14, 35):
                for i in self.last_hold:
                    m = self.datas[self.st_df.loc[i, 'min']]
                    d = self.datas[self.st_df.loc[i, 'day']]
                    if m.close[0] < d.high_limit[0]: #14:30时 day bar最新是昨天的
                        print('sell 平仓', m._name, self.getposition(m).size)
                        self.close(data=m)

    
    def notify_timer(self, timer, when, *args, **kwargs):
        # 2.合并买入卖出结果
        self.last_hold += self.new_hold
        self.new_hold = []
        # 1.根据涨停预选股票池
        self.zt_list = []
        for i, row in self.st_df.iterrows():
            d = self.datas[row['day']]
            if d.close[0] > d.low[0] and d.pctChg[0] > 9.9:
                self.log('zhangting ' + str(d.close[0]) + d._name)
                self.zt_list.append(i)
        # 3.删除已买入
        self.zt_list = list(set(self.zt_list)-set(self.last_hold))
        self.zt_num += len(self.zt_list)
        #print('平均涨停数', self.zt_num/len(self.data0))

运行时间:

总时间:72秒

读取csvcerebro.adddata执行完成
5.8462

可以看到耗时主要集中在cerebro添加数据完成到执行完成,[3]中所提及的优化数据读取的方式便不适用。而根据[2]中提出,Observers和Analyzers耗时能达到执行的一半,去掉以后重新运行得到总时间:71秒,没有明显提升,可能是本例中添加的Observers和Analyzers都比较简单。

V2策略及运行时间:

v2代码改进思路:

为了提高运行效率,考虑尽量减少next中的判断,将其放到cerebro之外,同时将信号直接附加到5min数据上,不再传入日数据。代码如下:

class PandasDataExtendInd(bt.feeds.PandasData):
    # 增加线
    lines = ('ind','high_limit','buy_ind', 'sell_ind',)
    params = (('ind', -1),('high_limit', -1),('buy_ind', -1),('sell_ind', -1), )  # 机构持股数量合计


class MyStrategy(bt.Strategy):
    params = dict(
        when=bt.timer.SESSION_START,
        end=bt.timer.SESSION_END,
        timer=True,
        cheat=False,
        offset=timedelta(),
        repeat=timedelta(),
        weekdays=[],
        period=3,
    )

    def log(self, txt, dt=None):
        ''' Logging function fot this strategy'''
        dt = dt or self.datas[0].datetime.datetime(0)
        print('%s, %s' % (dt.isoformat(), txt))

    def __init__(self):
        self.order = None

        self.add_timer(
                when=time(15, 0),
                offset=self.p.offset,
                repeat=self.p.repeat,
                weekdays=self.p.weekdays,
        )

        self.zt_list = []
        self.last_hold = []
        self.new_hold = []
        self.zt_num = 0

    def notify_order(self, order):
        if order.status in [order.Submitted, order.Accepted]:
            # Buy/Sell order submitted/accepted to/by broker - Nothing to do
            return

        # Check if an order has been completed
        # Attention: broker could reject order if not enough cash
        #idx = self.st_df.loc[self.st_df.code==order.data._name].index.values[0]
        if order.status in [order.Completed]:
            if order.isbuy():
                self.log(
                    'BUY EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                    (order.executed.price,
                     order.executed.value,
                     order.executed.comm))

                self.new_hold.append(order.data)
                self.zt_list.remove(self.datas.index(order.data))
            else:  # Sell
                self.log('SELL EXECUTED, Price: %.2f, Cost: %.2f, Comm %.2f' %
                         (order.executed.price,
                          order.executed.value,
                          order.executed.comm))
                self.last_hold.remove(order.data)
        elif order.status in [order.Canceled, order.Expired, order.Margin, order.Rejected]:
            self.log('Order Canceled/Expired/Margin/Rejected')
            self.new_hold.remove(order.data)

        # Write down: no pending order
        self.order = None

    def next(self):
        t = self.datetime.time(0)
        #1.每天早上10:05至11:05买入
        len_for_new = 10 - len(self.last_hold) - len(self.new_hold)
        if len(self.zt_list) > 0 and len_for_new > 0:
            if t >= time(9,40) and t <= time(14,30):
                for i in self.zt_list:
                    if i in self.last_hold:
                        continue
                    d = self.datas[i]
                    if len_for_new <= 0:
                        break
                    if d.buy_ind:
                        len_for_new -= 1
                        targetvalue = 0.1 * self.broker.getvalue()
                        size = targetvalue/(d.high_limit*0.99)//100*100
                        self.buy(data=d, size=size, price=d.high_limit*0.99, exectype=bt.Order.Limit,
                                 valid=self.datetime.datetime(0)+timedelta(minutes=5))

        #2.每天14:35卖出
        if len(self.last_hold) > 0:
            if t == time(14, 35):
                for m in self.last_hold:
                    if m.sell_ind: #14:30时 day bar最新是昨天的
                        print('sell 平仓', m._name, self.getposition(m).size)
                        self.close(data=m)

    
    def notify_timer(self, timer, when, *args, **kwargs):
        # 2.合并买入卖出结果
        self.last_hold += self.new_hold
        self.new_hold = []
        # 1.根据涨停预选股票池
        self.zt_list = []
        for i, d in enumerate(self.datas):
            if d.ind[0]:
                self.zt_list.append(i)
        # 3.删除已买入
        self.zt_list = list(set(self.zt_list)-set(self.last_hold))
        self.zt_num += len(self.zt_list)
        #print('平均涨停数', self.zt_num/len(self.data0))

运行时间:

总时间:103秒

读取csvcerebro.adddata执行完成
5.84.593

反向优化效果显著,也就是next中的比较操作+少传入日数据的效果远远小于传入了复杂的5分钟数据。详细打印运行时间,可以看到next第一次开始时为80秒,中间接近70秒的时间是cerebro进行各种初始化。

V3最终优化

优化思路:

 详细分析代码后可以得到其中最耗时的部分为:

# cerebro.py -> runstrategies()
for data in self.datas:
    data.preload()

# feed.py -> preload()
def preload(self):
    while self.load():
        pass
    self._last()
    self.home()

preload本身不好优化,但是对于runstrategies可以采用多线程执行进行优化,采用cerebro本身使用的Multiprocessing.Pool完成。

运行时间:

总时间:49秒

读取csvcerebro.adddata执行完成
5.94.338.4

数据读取、载入耗时不变,执行速度大幅提升。

结论

利用多线程可以大幅提升策略回测速度,同时修改难度较低。

电脑参数:

i7-10510U 2.30GHz, 4核8线程

15G内存

win10

参考:

[1] https://zhuanlan.zhihu.com/p/345815425

[2] https://www.zhihu.com/question/440467223

[3] https://community.backtrader.com/topic/2263/which-line-code-function-consume-more-time-when-doing-a-backtest/13

from jqdata import * from jqfactor import get_factor_values import datetime import math from scipy.optimize import minimize import pandas as pd # 初始化函数,设定基准等等 def initialize(context): # 设定沪深300作为基准 set_benchmark("399303.XSHE") # 打开防未来函数 set_option("avoid_future_data", True) # 开启动态复权模式(真实价格) set_option("use_real_price", True) # 输出内容到日志 log.info() log.info("初始函数开始运行") # 过滤掉order系列API产生的比error级别低的log log.set_level("order", "error") # 固定滑点设置ETF 0.001(即交易对手方一档价) set_slippage(FixedSlippage(0.002), type="fund") # 股票交易总成本0.3%(含固定滑点0.02) set_order_cost( OrderCost( open_tax=0, close_tax=0.001, open_commission=0.0003, close_commission=0.0003, close_today_commission=0, min_commission=5, ), type="stock", ) g.hold_list = [] # 记录策略的持仓股票 g.positions = {} # 记录策略的持仓股票 # 持仓股票数 g.stock_sum = 6 # 判断买卖点的行业数量 g.num = 1 # 空仓的月份 g.pass_months = [] # 策略执行计划 run_weekly(adjust, 1, "9:31") run_daily(check, "14:50") # 获取昨日涨停票并卖出 def check(context): # 获取已持有列表 g.hold_list = list(g.positions.keys()) banner_stocks = [] # 获取昨日涨停列表 if g.hold_list != []: df = get_price( g.hold_list, end_date=context.previous_date, frequency="daily", fields=["close", "high_limit"], count=1, panel=False, fill_paused=False, ) df = df[df["close"] == df["high_limit"]] banner_stocks = list(df.code) for stock in banner_stocks: order_target_value_(context, stock, 0) # 获取昨日跌停列表 if g.hold_list != []: df = get_price( g.hold_list, end_date=context.previous_date, frequency="daily", fields=["close", "low_limit"], count=1, panel=False, fill_paused=False, ) df = df[df["close"] == df["low_limit"]] banner_stocks = list(df.code) for stock in banner_stocks: order_target_value_(context, stock, 0) # 获取策略当前持仓市值 def get_total_value(context): return sum(context.portfolio.positions[key].price * value for key, value in g.positions.items()) # 调仓 def adjust(context): target = select(context) # 获取前stock_sum个标的 target = target[:min(len(target), g.stock_sum)] # 获取已持有列表 g.hold_list = list(g.positions.keys()) portfolio = context.portfolio # 调仓卖出 for stock in g.hold_list: if stock not in target: order_target_value_(context, stock, 0) # 调仓买入 count = len(set(target) - set(g.hold_list)) if count == 0: return # 目标市值 target_value = portfolio.total_value # 当前市值 position_value = get_total_value(context) # 可用现金:当前现金 available_cash = portfolio.available_cash # 买入股票的总市值 value = max(0, min(target_value - position_value, available_cash)) # 等价值买入每一个未买入的标的 for security in target: if security not in g.hold_list: order_target_value_(context, security, value / count) # 择时 def select(context): I = get_market_breadth(context) industries = {"银行I", "煤炭I", "采掘I", "钢铁I"} if not industries.intersection(I) and not is_empty_month(context): return filter(context) return [] # 获取市场 def get_market_breadth(context): # 指定日期防止未来数据 yesterday = context.previous_date # 获取初始列表 中证全指(000985.XSHG) stocks = get_index_stocks("000985.XSHG") count = 1 h = get_price( stocks, end_date=yesterday, frequency="1d", fields=["close"], count=count + 20, panel=False, ) h["date"] = pd.DatetimeIndex(h.time).date # 将长表格转换为宽表格,方便按日期分析股票价格。 df_close = h.pivot(index="code", columns="date", values="close").dropna(axis=0) # 计算20日均线 df_ma20 = df_close.rolling(window=20, axis=1).mean().iloc[:, -count:] # 计算偏离程度 df_bias = df_close.iloc[:, -count:] > df_ma20 df_bias["industry_name"] = getStockIndustry(stocks) # 计算行业偏离比例 df_ratio = ((df_bias.groupby("industry_name").sum() * 100.0) / df_bias.groupby("industry_name").count()).round() # 获取偏离程度最高的行业 top_values = df_ratio.loc[:, yesterday].nlargest(g.num) I = top_values.index.tolist() return I # 基础过滤(过滤科创北交、ST、停牌、次新股) def filter_basic_stock(context, stock_list): # 30开头的是深交所的创业板, # 68开头的是上交所的科创板, # 8开头的股票可能指的是北交所的, # 新三板北交所的股票代码通常以43、83、87等开头 # 4开头的股票可能属于退市板块 current_data = get_current_data() return [ stock for stock in stock_list if not current_data[stock].paused and not current_data[stock].is_st and "ST" not in current_data[stock].name and "*" not in current_data[stock].name and "退" not in current_data[stock].name and not (stock[0] == "4" or stock[0] == "8" or stock[:2] == "68") and not context.previous_date - get_security_info(stock).start_date < datetime.timedelta(375) ] # 过滤当前时间涨跌停的股票 def filter_limitup_limitdown_stock(stock_list): current_data = get_current_data() return [ stock for stock in stock_list if current_data[stock].last_price < current_data[stock].high_limit and current_data[stock].last_price > current_data[stock].low_limit ] # 判断今天是在空仓月 def is_empty_month(context): month = context.current_dt.month return month in g.pass_months def getStockIndustry(stocks): # 第一步:获取原始行业数据(假设stocks是股票代码列表) industry = get_industry(stocks) # 第二步:提取申万一级行业名称 return pd.Series({stock: info["sw_l1"]["industry_name"] for stock, info in industry.items() if "sw_l1" in info}) # 过滤股票 def filter(context): stocks = get_index_stocks("399303.XSHE") # 这里的有问题,需要由399303.XSHE代替 stocks = filter_basic_stock(context, stocks) stocks = ( get_fundamentals( query( valuation.code, ) .filter( valuation.code.in_(stocks), # 从现有股票池中筛选 indicator.adjusted_profit > 0, # 要求调整后净利润>0 ) .order_by(valuation.market_cap.asc()) # 按市值升序排列(从小市值开始) ) .head(20) # 取前20只股票 .code # 提取股票代码 ) stocks = filter_limitup_limitdown_stock(stocks) return stocks # 自定义下单(涨跌停不交易) def order_target_value_(context, security, value): current_data = get_current_data() # 检查标的是否停牌、涨停、跌停 if current_data[security].paused: log.info(f"{security}: 今日停牌") return False # 检查是否涨停 if current_data[security].last_price == current_data[security].high_limit: log.info(f"{security}: 当前涨停") return False # 检查是否跌停 if current_data[security].last_price == current_data[security].low_limit: log.info(f"{security}: 当前跌停") return False # 获取当前标的的价格 price = current_data[security].last_price # 获取当前策略的持仓数量 current_position = g.positions.get(security, 0) # 计算目标持仓数量 target_position = (int(value / price) // 100) * 100 if price != 0 else 0 # 计算需要调整的数量 adjustment = target_position - current_position # 检查是否当天买入卖出 closeable_amount = context.portfolio.positions[security].closeable_amount if security in context.portfolio.positions else 0 if adjustment < 0 and closeable_amount == 0: log.info(f"{security}: 当天买入不可卖出") return False # 下单并更新持仓 if adjustment != 0: o = order(security, adjustment) if o: # 更新持仓数量 amount = o.amount if o.is_buy else -o.amount g.positions[security] = amount + current_position # 如果目标持仓为零,移除该证券 if target_position == 0: g.positions.pop(security, None) # 更新持有列表 g.hold_list = list(g.positions.keys()) return True return False (把这个聚宽的代码迁移到backtrade
最新发布
07-20
### Backtrader 性能优化技术 Backtrader 是一种用于回测交易策略的强大框架,其灵活性和可扩展性使其成为许多量化分析师的选择。然而,在处理大规模数据集或复杂策略时,性能可能成为一个瓶颈。以下是几种常见的 backtrader 性能优化方法: #### 数据预处理 通过减少输入到 backtrader 中的数据量可以显著提高运行速度。例如,可以通过移除不必要的列或将时间序列降采样来降低计算负担[^1]。 ```python import pandas as pd def downsample_data(df, rule='D'): """ 对 DataFrame 进行降采样 """ return df.resample(rule).last() ``` #### 使用 Cerebro 配置选项 Cerebro 提供了一些配置参数可以帮助提升效率。比如 `stdstats=False` 可以禁用标准统计指标的计算;另外设置较小的 `cheat-on-close=True` 参数能够提前知道收盘价从而避免某些延迟操作带来的额外开销。 ```python cerebro = bt.Cerebro(stdstats=False) cerebro.broker.set_coc(True) # cheat on close ``` #### 并行化执行 如果硬件资源允许的话,则考虑利用多核 CPU 来加速多个独立测试案例或者不同参数组合下的评估过程。这通常涉及到修改默认单线程模式为并行模式。 ```python from concurrent.futures import ProcessPoolExecutor def runstrats_parallel(strategies): with ProcessPoolExecutor() as executor: results = list(executor.map(lambda s: cerebro.run(), strategies)) return results ``` #### 替代算法实现 对于特别耗时的部分逻辑(如复杂的信号生成函数),尝试寻找更高效的替代方案可能是有益处的。有时候简单的数学变换就能带来意想不到的效果改进。 ---
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值