STGCN(时空图卷积网络)详解

1️⃣ STGCN介绍

前面已经介绍过了图卷积(GCN)。这篇论文《Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition》将GCN扩展到时空图模型上,用于实现动作识别。下图展示了STGCN的输入,即一系列骨架图。其中每个节点对应于人体的一个关节,有两种类型的边,①符合节点自然连通性的空间边(图1中淡蓝色线条)②跨越连续时间步长连接相同节点的时间边(淡绿色线条)

在这里插入图片描述


2️⃣ 网络结构

STGCN简单的网络结构如下所示,由三部分构成:

  • 归一化:对输入数据归一化
  • 时空变化:通过多个ST-GCN块,每个块中交替使用GCN和TCN
  • 输出:使用平均池化和全连接层对特征进行分类

在这里插入图片描述

在这里,我们不关注数据部分,只对网络结构进行解析,因此我们来看ST-GCN块的详细结构:

在这里插入图片描述

  • 步骤一:引入一个可学习的权重矩阵,与邻接矩阵大小一致,记作Learnable edge importance weight。让它与邻接矩阵 A A A按位相乘,得到加权后的邻接矩阵。其目的是给重要的边较大的权重,给非重要的边较小的权重。
  • 步骤二:将加权后的邻接矩阵与输入数据送到GCN中进行运算
  • 步骤三:利用TCN网络,实现时间维度信息的聚合

3️⃣ 代码

# 只包含网络结构,“网络的输入(关节坐标)”和“邻接矩阵”都是随机生成的

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# ---------------------------------------------------------------------------------
# 空域图卷积
# ---------------------------------------------------------------------------------
class SpatialGraphConvolution(nn.Module):
    """
    Args:
    in_channels: 输入通道数,表示每个节点的特征维度(例如 3 对应 x, y, z 坐标)。
    out_channels: 输出通道数,表示经过卷积后每个节点的特征维度。
    s_kernel_size: 空间卷积核的大小,等于邻接矩阵的数量,表示多重图卷积的支持
    """
    def __init__(self, in_channels, out_channels, s_kernel_size):
        super().__init__()
        self.s_kernel_size = s_kernel_size
        self.conv = nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels * s_kernel_size,
                            kernel_size=1)

    def forward(self, x, A):
        x = self.conv(x)
        n, kc, t, v = x.size()
        x = x.view(n, self.s_kernel_size, kc//self.s_kernel_size, t, v)
        #对邻接矩阵进行GC,相加特征
        x = torch.einsum('nkctv,kvw->nctw', (x, A))
        return x.contiguous()
  
# ---------------------------------------------------------------------------------
### ST-GCN 模型架构详解 #### 正则化与 Batch Normalization 的应用 ST-GCN(Spatio-Temporal Graph Convolutional Network)是一种针对时空数据建模的深度学习框架,广泛应用于动作识别等领域。为了提高模型稳定性并减少过拟合现象,Batch Normalization 被引入每一层卷积操作之后[^3]。具体而言,BN 层通过标准化激活函数前的输入值来加速收敛过程,并缓解梯度消失或爆炸问题。 此外,Dropout 技术也被采用作为一种显式的正则化手段。通常情况下,在全连接层之前会加入一定比例的 Dropout 来随机丢弃部分神经元,从而增强泛化能力[^4]。 #### 各层通道数的设计原理 关于通道数设置方面,一般遵循从小到大逐渐增加的原则。例如初始几层可能设定为64个滤波器,随后逐步提升至128甚至更高如256等数值。这种渐增模式有助于先提取低级局部特征再过渡向更高级别的抽象表达形式。这样的设计可以有效捕捉不同尺度下的空间与时序关联特性。 #### 时间卷积核数量 对于时间维度上的处理,则主要依赖一维卷积运算完成。这些卷积核负责沿着时间轴滑动扫描节点间相互作用规律。其大小往往依据实验效果调整优化得出最佳参数组合。 #### Residual Connection 实现方式 残差连接允许较深网络仍能保持良好性能表现的关键因素之一就是它解决了深层网络训练过程中可能出现的退化解难题。在实际编码实现当中可以通过简单加法操作将未经变换或者仅经过线性投影转换过的原始输入叠加回最终输出端口处达成目标。 #### Pooling Layer 配置详情 池化层主要用于降低特征地图的空间分辨率进而达到降维目的同时保留最重要信息成分不丢失过多细节内容。按照惯例会在特定层次比如第四七两阶段安排相应操作步骤执行最大值选取策略(Max-Pooling)。 #### 全局平均池化(Global Average Pooling) 及 Softmax 分类环节说明 当所有先前预处理完毕后进入最后决策判断时刻时,全局平均池化将会把整个序列压缩成固定长度表示向量供后续SoftMax分类器利用计算各类别概率分布情况以便做出最有可能的选择结果. #### 优化方法(SGD Optimizer Usage) 至于求解算法选择上,默认推荐使用带momentum项的标准随机梯度下降法(Stochastic Gradient Descent with Momentum),因为它具备较快收敛速度又能较好避开局部极小点陷阱的优势特点. ```python import torch.nn as nn class Model(nn.Module): def __init__(self): super(Model,self).__init__() self.st_gcn_networks = nn.Sequential( st_gcn(in_channels=3,out_channels=64,kernel_size=(9,1)), ... st_gcn(in_channels=128,out_channels=256,kernel_size=(9,1)) ) self.fc = nn.Linear(256,num_class) def forward(self,x): out = self.st_gcn_networks(x) out = F.avg_pool2d(out,(1,T)) # Global average pooling over time dimension T. out = out.view(N,-1) predict_score = self.fc(out) return predict_score ```
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值