Time-Series-Library源码详解:模型架构与实现原理
项目概述
Time-Series-Library是一个专注于深度时间序列分析的开源代码库,支持五种主流任务:长短期预测、缺失值填补、异常检测、分类以及含外部变量的预测。该库整合了30+种先进时间序列模型,提供统一的实验框架和标准化评估流程,已成为学术界广泛使用的基准测试平台。本文将从代码架构、核心模型实现、关键技术原理三个维度,深入解析库的设计思想与技术细节。
整体架构设计
模块化架构
库采用分层设计思想,通过五个核心模块实现功能解耦:
- 数据层(data_provider): 提供统一数据接口,支持ETT、M4、UEA等10+数据集
- 模型层(models): 实现30+种时间序列模型,包括Autoformer、TimesNet、Mamba等SOTA方法
- 实验层(exp): 封装训练/验证/测试流程,支持多任务统一接口
- 工具层(utils): 提供指标计算、早停机制、数据标准化等基础功能
- 脚本层(scripts): 存储任务配置文件,支持一键复现论文实验
核心目录结构
Time-Series-Library/
├── data_provider/ # 数据加载与预处理
├── exp/ # 实验管理模块
├── layers/ # 基础网络层组件
├── models/ # 模型实现
├── scripts/ # 任务配置脚本
├── utils/ # 工具函数
└── run.py # 程序入口
核心模型架构解析
1. Autoformer: 自相关分解Transformer
Autoformer作为库中的经典模型,创新性地提出自相关机制(AutoCorrelation)替代传统注意力,解决长期依赖建模问题。
自相关机制实现
# layers/AutoCorrelation.py
class AutoCorrelation(nn.Module):
def forward(self, queries, keys, values, attn_mask):
# 1. FFT获取周期特征
q_fft = torch.fft.rfft(queries.permute(0,2,3,1).contiguous(), dim=-1)
k_fft = torch.fft.rfft(keys.permute(0,2,3,1).contiguous(), dim=-1)
res = q_fft * torch.conj(k_fft) # 复数域相乘
corr = torch.fft.irfft(res, dim=-1) # 逆变换获取自相关图谱
# 2. 时间延迟聚合
if self.training:
V = self.time_delay_agg_training(values.permute(0,2,3,1).contiguous(), corr)
else:
V = self.time_delay_agg_inference(...)
return V.contiguous(), corr.permute(0,3,1,2)
自相关机制通过傅里叶变换将时间域信号转换至频率域,捕捉序列的周期特性。核心创新点在于:
- 周期发现: 通过FFT将序列分解为不同频率分量,选择Top-K显著周期
- 时间延迟聚合: 基于周期特性聚合相关时序信息,替代传统注意力的O(n²)计算
序列分解模块
# layers/Autoformer_EncDec.py
def series_decomp(x, kernel_size):
"""使用移动平均进行序列分解"""
avg = F.avg_pool1d(x, kernel_size=kernel_size, stride=1, padding=kernel_size//2)
trend = avg
seasonal = x - trend
return seasonal, trend
Autoformer采用移动平均实现序列分解,将原始序列分离为趋势(trend)和季节(seasonal)分量,分别建模后重构输出。
2. TimesNet: 时序二维变化建模
TimesNet通过傅里叶变换提取周期特征,并创新性地将一维时序转化为二维图谱进行建模。
核心模块设计
# models/TimesNet.py
class TimesBlock(nn.Module):
def forward(self, x):
B, T, N = x.size()
# 1. FFT获取Top-K周期
period_list, period_weight = FFT_for_Period(x, self.k)
res = []
for i in range(self.k):
period = period_list[i]
# 2. 序列填充与reshape
if (self.seq_len + self.pred_len) % period != 0:
length = ((self.seq_len + self.pred_len) // period + 1) * period
padding = torch.zeros(...)
out = torch.cat([x, padding], dim=1)
# 3. 转化为2D图谱 [B, N, T/period, period]
out = out.reshape(B, length//period, period, N).permute(0,3,1,2).contiguous()
# 4. 2D卷积捕捉周期内和周期间依赖
out = self.conv(out)
res.append(out[...,:(self.seq_len+self.pred_len),:])
# 5. 自适应聚合多周期特征
res = torch.stack(res, dim=-1)
period_weight = F.softmax(period_weight, dim=1)
res = torch.sum(res * period_weight.unsqueeze(1).unsqueeze(1), -1)
return res + x # 残差连接
TimesBlock的核心创新在于:
- 周期感知建模: 通过FFT自动发现序列周期特性
- 2D卷积应用: 将时序问题转化为图像问题,同时捕捉局部和全局依赖
- 多周期聚合: 自适应融合不同周期的特征表示
3. Mamba: 状态空间模型新范式
Mamba作为新兴的线性时间序列模型,采用选择性状态空间架构,实现O(n)复杂度的长序列建模。
模型实现要点
# models/Mamba.py
class Model(nn.Module):
def __init__(self, configs):
super().__init__()
self.d_inner = configs.d_model * configs.expand
self.mamba = Mamba(
d_model=configs.d_model,
d_state=configs.d_ff,
d_conv=configs.d_conv,
expand=configs.expand
)
def forecast(self, x_enc, x_mark_enc):
# 1. 数据标准化
mean_enc = x_enc.mean(1, keepdim=True).detach()
x_enc = x_enc - mean_enc
std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True) + 1e-5).detach()
x_enc = x_enc / std_enc
# 2. 嵌入层与Mamba编码
x = self.embedding(x_enc, x_mark_enc)
x = self.mamba(x)
x_out = self.out_layer(x)
# 3. 逆标准化
x_out = x_out * std_enc + mean_enc
return x_out
Mamba的核心优势在于:
- 线性复杂度: 相比Transformer的O(n²)注意力,实现O(n)计算
- 长序列建模: 通过状态空间模型(SSM)有效捕捉长程依赖
- 高效推理: 适合实时性要求高的时间序列应用场景
关键技术原理
时间序列特征工程
库中实现了丰富的时序特征编码方式:
# layers/Embed.py
class DataEmbedding(nn.Module):
def __init__(self, c_in, d_model, embed_type='fixed', freq='h'):
super().__init__()
self.value_embedding = nn.Linear(c_in, d_model)
self.position_embedding = PositionalEmbedding(d_model)
self.temporal_embedding = TemporalEmbedding(d_model, freq=freq)
def forward(self, x, x_mark):
x = self.value_embedding(x) + self.temporal_embedding(x_mark)
if self.embed_type == 'learned':
x += self.position_embedding(x)
return x
支持三种特征嵌入方式:
- 值嵌入(Value Embedding): 将原始特征投影到高维空间
- 位置嵌入(Position Embedding): 编码时序位置信息
- 时间特征嵌入(Temporal Embedding): 编码时间戳特征(年/月/日/时等)
多任务统一框架
通过Exp基类实现五种任务的统一接口:
# exp/exp_basic.py
class Exp_Basic(object):
def __init__(self, args):
self.model_dict = {
'TimesNet': TimesNet,
'Autoformer': Autoformer,
'Mamba': Mamba,
# ... 30+模型
}
self.device = self._acquire_device()
self.model = self._build_model().to(self.device)
def _build_model(self):
raise NotImplementedError
def train(self):
pass
def test(self):
pass
不同任务通过继承Exp_Basic实现特化逻辑:
Exp_Long_Term_Forecast
: 长期预测任务Exp_Imputation
: 缺失值填补任务Exp_Anomaly_Detection
: 异常检测任务Exp_Classification
: 时序分类任务
评估指标体系
库中实现了时间序列任务的完整评估指标集:
# utils/metrics.py
def metric(pred, true):
mae = MAE(pred, true)
mse = MSE(pred, true)
rmse = RMSE(pred, true)
mape = MAPE(pred, true)
mspe = MSPE(pred, true)
return mae, mse, rmse, mape, mspe
针对不同任务提供专用指标:
- 预测任务: MAE/MSE/RMSE/MAPE
- 异常检测: F1-score/精确率/召回率
- 分类任务: 准确率/混淆矩阵
实验工作流
以长期预测任务为例,完整实验流程如下:
关键实现代码
# exp/exp_long_term_forecasting.py
class Exp_Long_Term_Forecast(Exp_Basic):
def train(self, setting):
train_data, train_loader = self._get_data(flag='train')
vali_data, vali_loader = self._get_data(flag='val')
early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
for epoch in range(self.args.train_epochs):
train_loss = []
self.model.train()
for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
model_optim.zero_grad()
batch_x = batch_x.float().to(self.device)
batch_y = batch_y.float().to(self.device)
outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
f_dim = -1 if self.args.features == 'MS' else 0
outputs = outputs[:, -self.args.pred_len:, f_dim:]
batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)
loss = criterion(outputs, batch_y)
train_loss.append(loss.item())
loss.backward()
model_optim.step()
vali_loss = self.vali(vali_data, vali_loader, criterion)
early_stopping(vali_loss, self.model, path)
if early_stopping.early_stop:
break
快速上手指南
环境准备
# 克隆仓库
git clone https://gitcode.com/GitHub_Trending/ti/Time-Series-Library
cd Time-Series-Library
# 安装依赖
pip install -r requirements.txt
模型训练示例
以Autoformer在ETTh1数据集上的长期预测为例:
# 执行训练脚本
bash ./scripts/long_term_forecast/ETT_script/Autoformer_ETTh1.sh
脚本内容解析:
# scripts/long_term_forecast/ETT_script/Autoformer_ETTh1.sh
python -u run.py \
--task_name long_term_forecast \
--is_training 1 \
--root_path ./dataset/ETT-small/ \
--data_path ETTh1.csv \
--model_id ETTh1_96_96 \
--model Autoformer \
--data ETTh1 \
--features M \
--seq_len 96 \ # 输入序列长度
--label_len 48 \ # 编码器输出长度
--pred_len 96 \ # 预测序列长度
--e_layers 2 \ # 编码器层数
--d_layers 1 \ # 解码器层数
--factor 3 \ # 自相关因子
--enc_in 7 \ # 输入特征数
--dec_in 7 \ # 解码器输入特征数
--c_out 7 \ # 输出特征数
--des 'Exp' \
--itr 1
模型对比与选型建议
根据库中基准测试结果,不同模型的性能特点如下:
模型 | 复杂度 | 长序列预测能力 | 多变量支持 | 关键优势 |
---|---|---|---|---|
Autoformer | O(n log n) | ★★★★☆ | ★★★★★ | 周期分解能力强 |
TimesNet | O(n) | ★★★★★ | ★★★★☆ | 捕捉复杂频率模式 |
Mamba | O(n) | ★★★★★ | ★★★☆☆ | 线性复杂度,推理快 |
PatchTST | O(n) | ★★★★☆ | ★★★★★ | 可解释性强 |
iTransformer | O(n²) | ★★★☆☆ | ★★★★★ | 特征交互建模好 |
选型建议:
- 高实时性场景:优先选择Mamba或TimesNet
- 多变量长序列:推荐Autoformer或PatchTST
- 异常检测任务:TimesNet综合性能最优
- 资源受限环境:LightTS或DLinear(MLP类模型)
总结与展望
Time-Series-Library通过模块化设计和统一接口,为时间序列研究提供了高效实验平台。其核心价值在于:
- 模型多样性:覆盖30+主流时序模型,便于对比研究
- 任务全面性:支持五种时序任务,满足不同应用需求
- 代码规范性:统一实现框架,降低新模型集成难度
未来发展方向:
- 引入大语言模型(LLM)与时间序列的融合方法
- 强化可解释性工具,可视化模型决策过程
- 优化大规模时序数据处理效率,支持分布式训练
通过本文的解析,读者可深入理解库的设计思想与实现细节,快速开展时间序列相关研究与应用开发。建议结合源码和教程进一步实践,探索不同模型在具体场景下的性能表现。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考