[Paper Quick Read + Core Code Locations] (ECCV 2024) ParCo: Part-Coordinating Text-to-Motion Synthesis

ParCo: Part-Coordinating Text-to-Motion Synthesis

Qiran Zou, Shangyuan Yuan, Shian Du, Yu Wang, Chang Liu, Yi Xu, Jie Chen, and Xiangyang Ji

Paper: [2403.18512] ParCo: Part-Coordinating Text-to-Motion Synthesis (https://arxiv.org/abs/2403.18512)

Code: https://github.com/qrzou/ParCo


Research Motivation

ParCo uses a decompose-and-coordinate mechanism over body parts to address the shortcomings of existing text-to-motion methods in semantic alignment, motion coordination, and computational efficiency.


Method & Innovations

The paper proposes the ParCo framework, which lets the motion generation model better understand and coordinate the motions of the individual body parts it generates.

The framework has two stages: (1) per-body-part motion VQ-VAEs; (2) transformers responsible for generating each body part's motion.

Part-Aware Motion Discretization Module

1. Split the body joints into 6 parts (see the figure in the paper): R.Arm, L.Arm, R.Leg, L.Leg, Backbone, and Root.

https://github.com/qrzou/ParCo/blob/main/dataset/dataset_VQ_bodypart.py#L107-L220

import torch

def whole2parts(motion, mode='t2m', window_size=None):
    # motion: (nframes, D) numpy array; D=263 for HumanML3D ('t2m'), D=251 for KIT ('kit')
    if mode == 't2m':
        # The 263-dim motion vector is an augmented motion representation.
        # Split it into its separate components:
        #    root_data, ric_data, rot_data, local_vel, feet
        aug_data = torch.from_numpy(motion)  # (nframes, 263)
        joints_num = 22
        s = 0  # start
        e = 4  # end
        root_data = aug_data[:, s:e]  # [seq_len-1, 4]
        s = e
        e = e + (joints_num - 1) * 3
        ric_data = aug_data[:, s:e]  # [seq_len, (joints_num-1)*3]. (joints_num - 1) means the 0th joint is dropped.
        s = e
        e = e + (joints_num - 1) * 6
        rot_data = aug_data[:, s:e]  # [seq_len, (joints_num-1) *6]
        s = e
        e = e + joints_num * 3
        local_vel = aug_data[:, s:e]  # [seq_len-1, joints_num*3]
        s = e
        e = e + 4
        feet = aug_data[:, s:e]  # [seq_len-1, 4]

        # move the root (0-th joint) out of the following parts
        R_L_idx = torch.Tensor([2, 5, 8, 11]).to(torch.int64)        # right leg
        L_L_idx = torch.Tensor([1, 4, 7, 10]).to(torch.int64)        # left leg
        B_idx = torch.Tensor([3, 6, 9, 12, 15]).to(torch.int64)      # backbone
        R_A_idx = torch.Tensor([9, 14, 17, 19, 21]).to(torch.int64)  # right arm
        L_A_idx = torch.Tensor([9, 13, 16, 18, 20]).to(torch.int64)  # left arm

        nframes = root_data.shape[0]
        if window_size is not None:
            assert nframes == window_size

        # The original shape of root_data and feet
        # root_data: (nframes, 4)
        # feet: (nframes, 4)
        ric_data = ric_data.reshape(nframes, -1, 3)    # (nframes, joints_num - 1, 3)
        rot_data = rot_data.reshape(nframes, -1, 6)    # (nframes, joints_num - 1, 6)
        local_vel = local_vel.reshape(nframes, -1, 3)  # (nframes, joints_num, 3)

        root_data = torch.cat([root_data, local_vel[:,0,:]], dim=1)  # (nframes, 4+3=7)
        R_L = torch.cat([ric_data[:, R_L_idx - 1, :], rot_data[:, R_L_idx - 1, :], local_vel[:, R_L_idx, :]], dim=2)  # (nframes, 4, 3+6+3=12)
        L_L = torch.cat([ric_data[:, L_L_idx - 1, :], rot_data[:, L_L_idx - 1, :], local_vel[:, L_L_idx, :]], dim=2)  # (nframes, 4, 3+6+3=12)
        B = torch.cat([ric_data[:, B_idx - 1, :], rot_data[:, B_idx - 1, :], local_vel[:, B_idx, :]], dim=2)  # (nframes, 5, 3+6+3=12)
        R_A = torch.cat([ric_data[:, R_A_idx - 1, :], rot_data[:, R_A_idx - 1, :], local_vel[:, R_A_idx, :]], dim=2)  # (nframes, 5, 3+6+3=12)
        L_A = torch.cat([ric_data[:, L_A_idx - 1, :], rot_data[:, L_A_idx - 1, :], local_vel[:, L_A_idx, :]], dim=2)  # (nframes, 5, 3+6+3=12)

        Root = root_data  # (nframes, 4+3=7)
        R_Leg = torch.cat([R_L.reshape(nframes, -1), feet[:, 2:]], dim=1)  # (nframes, 4*12+2=50)
        L_Leg = torch.cat([L_L.reshape(nframes, -1), feet[:, :2]], dim=1)  # (nframes, 4*12+2=50)
        Backbone = B.reshape(nframes, -1)  # (nframes, 5*12=60)
        R_Arm = R_A.reshape(nframes, -1)  # (nframes, 5*12=60)
        L_Arm = L_A.reshape(nframes, -1)  # (nframes, 5*12=60)

    elif mode == 'kit':
        # The 251-dim motion vector is an augmented motion representation.
        # Split it into its separate components:
        #    root_data, ric_data, rot_data, local_vel, feet
        aug_data = torch.from_numpy(motion)  # (nframes, 251)
        joints_num = 21
        s = 0  # start
        e = 4  # end
        root_data = aug_data[:, s:e]  # [seq_len-1, 4]
        s = e
        e = e + (joints_num - 1) * 3
        ric_data = aug_data[:, s:e]  # [seq_len, (joints_num-1)*3]. (joints_num - 1) means the 0th joint is dropped.
        s = e
        e = e + (joints_num - 1) * 6
        rot_data = aug_data[:, s:e]  # [seq_len, (joints_num-1) *6]
        s = e
        e = e + joints_num * 3
        local_vel = aug_data[:, s:e]  # [seq_len-1, joints_num*3]
        s = e
        e = e + 4
        feet = aug_data[:, s:e]  # [seq_len-1, 4]

        # move the root joint (0-th) out of the following parts
        R_L_idx = torch.Tensor([11, 12, 13, 14, 15]).to(torch.int64)        # right leg
        L_L_idx = torch.Tensor([16, 17, 18, 19, 20]).to(torch.int64)        # left leg
        B_idx = torch.Tensor([1, 2, 3, 4]).to(torch.int64)      # backbone
        R_A_idx = torch.Tensor([3, 5, 6, 7]).to(torch.int64)  # right arm
        L_A_idx = torch.Tensor([3, 8, 9, 10]).to(torch.int64)  # left arm

        nframes = root_data.shape[0]
        if window_size is not None:
            assert nframes == window_size

        # The original shape of root_data and feet
        # root_data: (nframes, 4)
        # feet: (nframes, 4)
        ric_data = ric_data.reshape(nframes, -1, 3)    # (nframes, joints_num - 1, 3)
        rot_data = rot_data.reshape(nframes, -1, 6)    # (nframes, joints_num - 1, 6)
        local_vel = local_vel.reshape(nframes, -1, 3)  # (nframes, joints_num, 3)

        root_data = torch.cat([root_data, local_vel[:,0,:]], dim=1)  # (nframes, 4+3=7)
        R_L = torch.cat([ric_data[:, R_L_idx - 1, :], rot_data[:, R_L_idx - 1, :], local_vel[:, R_L_idx, :]], dim=2)  # (nframes, 5, 3+6+3=12)
        L_L = torch.cat([ric_data[:, L_L_idx - 1, :], rot_data[:, L_L_idx - 1, :], local_vel[:, L_L_idx, :]], dim=2)  # (nframes, 5, 3+6+3=12)
        B = torch.cat([ric_data[:, B_idx - 1, :], rot_data[:, B_idx - 1, :], local_vel[:, B_idx, :]], dim=2)  # (nframes, 4, 3+6+3=12)
        R_A = torch.cat([ric_data[:, R_A_idx - 1, :], rot_data[:, R_A_idx - 1, :], local_vel[:, R_A_idx, :]], dim=2)  # (nframes, 4, 3+6+3=12)
        L_A = torch.cat([ric_data[:, L_A_idx - 1, :], rot_data[:, L_A_idx - 1, :], local_vel[:, L_A_idx, :]], dim=2)  # (nframes, 4, 3+6+3=12)

        Root = root_data  # (nframes, 4+3=7)
        R_Leg = torch.cat([R_L.reshape(nframes, -1), feet[:, 2:]], dim=1)  # (nframes, 5*12+2=62)
        L_Leg = torch.cat([L_L.reshape(nframes, -1), feet[:, :2]], dim=1)  # (nframes, 5*12+2=62)
        Backbone = B.reshape(nframes, -1)  # (nframes, 4*12=48)
        R_Arm = R_A.reshape(nframes, -1)  # (nframes, 4*12=48)
        L_Arm = L_A.reshape(nframes, -1)  # (nframes, 4*12=48)

    else:
        raise ValueError(f"Unknown mode: {mode}, expected 't2m' or 'kit'.")

    return [Root, R_Leg, L_Leg, Backbone, R_Arm, L_Arm]
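
As a quick usage sketch (the random array below is just a stand-in for a real HumanML3D motion clip; the expected per-part feature dims follow the shape comments above):

import numpy as np

motion = np.random.randn(64, 263).astype(np.float32)  # stand-in for a real 64-frame HumanML3D clip
parts = whole2parts(motion, mode='t2m')
for name, part in zip(['Root', 'R_Leg', 'L_Leg', 'Backbone', 'R_Arm', 'L_Arm'], parts):
    print(name, tuple(part.shape))
# Root (64, 7), R_Leg/L_Leg (64, 50), Backbone/R_Arm/L_Arm (64, 60)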

2. Discretize the continuous motion into motion codes / tokens.

A similar operation appears in MoMask; see the earlier post:

[Paper Reading Notes + Thoughts + Summary] MoMask: Generative Masked Modeling of 3D Human Motions (CSDN blog)

An encoder extracts features; each feature is matched against a learnable codebook, and the index (position) of its nearest codebook entry is the motion code / token.

The motion codes / tokens can also be decoded back into motion by the decoder.

Here, 6 separate VQ-VAEs are used to discretize the motions of the different body parts; the results are the body-part motion codes / tokens.
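
The lookup itself is standard VQ-VAE machinery. Below is a minimal sketch of the nearest-neighbor quantization step (illustrative only, not ParCo's actual quantizer class; the codebook size and feature dim are made up):

import torch

def quantize(z, codebook):
    # z: (T, D) encoder features; codebook: (K, D) learnable entries.
    dist = torch.cdist(z, codebook)      # (T, K) L2 distances to every codebook entry
    tokens = dist.argmin(dim=1)          # (T,) indices = motion codes / tokens
    z_q = codebook[tokens]               # (T, D) quantized features, decodable back to motion
    return tokens, z_q

codebook = torch.randn(512, 256)         # e.g. K=512 entries of dim 256 (made-up sizes)
z = torch.randn(16, 256)                 # temporally downsampled motion features
tokens, z_q = quantize(z, codebook)
print(tokens.shape, z_q.shape)           # torch.Size([16]) torch.Size([16, 256])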

Text-Driven Part-Coordinated Transformer

Here, 6 small transformers generate the motions of the 6 body parts autoregressively (similar to T2M-GPT): given the CLIP feature of the input text and the preceding body-part motion tokens, each transformer predicts the current body-part motion token.
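
Ignoring the coordination mechanism for a moment, each part's predictor behaves like a plain autoregressive decoder. A minimal sketch of that loop (greedy decoding; the transformer signature and the dummy stand-in are illustrative assumptions, not ParCo's actual API):

import torch

@torch.no_grad()
def sample_part_tokens(transformer, text_feat, max_len=50, end_token=512):
    # Greedy autoregressive decoding for one body part.
    # Assumed (illustrative) signature: transformer(text_feat, tokens) -> (1, t+1, vocab) logits.
    tokens = torch.empty(1, 0, dtype=torch.long)
    for _ in range(max_len):
        logits = transformer(text_feat, tokens)
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)  # pick the most likely next token
        if next_tok.item() == end_token:  # stop at the End token
            break
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens  # this part's motion tokens; its VQ-VAE decoder maps them back to motion

# Dummy stand-in: random logits over a 513-token vocab (512 codes + End token).
dummy = lambda text_feat, tokens: torch.randn(1, tokens.shape[1] + 1, 513)
print(sample_part_tokens(dummy, text_feat=torch.randn(1, 512)).shape)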

However, to keep the body parts coordinated with one another, the paper further introduces a part coordination layer.

In the implementation, this part-coordination layer is inserted before each small-transformer layer (except the first).

Concretely (following the code below), each part's token features are first passed through a LayerNorm; an MLP then extracts a fused feature from the other (≠ i) parts' normalized token features, and the result is added to the i-th part's token features before entering its transformer block.

https://github.com/qrzou/ParCo/blob/main/models/t2m_trans_bodypart.py#L142-L159

            for i in range(num_layers):
                # block
                block = Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate)
                setattr(self, f'{name}_block_{i}', block)

                # Fuse module for each layer, except the first transformer layer.
                if use_fuse and i != 0:
                    # LayerNorm applied before sending into the fuse modules.
                    ln = nn.LayerNorm(embed_dim)
                    setattr(self, f'{name}_ln_{i}', ln)
                    # Fuse module: MLP, global info weight
                    # todo: set in the config file
                    if self.fuse_ver == 'V1_3':
                        fuse = FuseModuleV1_3(in_features=other_parts_embed_dim, out_features=embed_dim,
                                          num_mlp_layers=num_mlp_layers, drop_out_rate=drop_out_rate, alpha=alpha)
                    else:
                        raise NotImplementedError()
                    setattr(self, f'{name}_fuse_{i}', fuse)

https://github.com/qrzou/ParCo/blob/main/models/t2m_trans_bodypart.py#L229-L270

        for i in range(self.num_layers):

            # Update the parts_token_embeddings using Transformer and Fuse

            # Keep a separate list of this layer's input tensors,
            #   because parts_token_embeddings is updated in place below.
            if self.use_fuse and i != 0:
                '''
                Apply LayerNorm before feeding into the FusionModule.
                  The LayerNorm cannot easily live inside the FusionModule:
                    each part's normalized output is sent to several different
                    FusionModules, so we normalize once before sending.
                  Inserting a LayerNorm at the end of each transformer block
                    would require modifying a lot of code, so we don't do that.
                '''
                no_modified_input = []
                for j, name in enumerate(self.parts_name):
                    ln = getattr(self, f'{name}_ln_{i}')
                    no_modified_input.append(ln(parts_token_embeddings[j]))
            else:
                no_modified_input = [elem for elem in parts_token_embeddings]  # this layer's inputs, kept fixed while parts_token_embeddings is updated

            for j, name in enumerate(self.parts_name):

                # Fetch the input data and module
                block = getattr(self, f'{name}_block_{i}')
                x = no_modified_input[j]

                # Fuse the information or not.
                if self.use_fuse and i != 0:  # fuse the parts info if not 0-th layer

                    fuse = getattr(self, f'{name}_fuse_{i}')
                    other_parts_emb = [no_modified_input[count]
                                       for count in range(len(self.parts_name)) if count != j]
                    # other_parts_emb = torch.cat(other_parts_emb, dim=2)  # (B, 51, other_parts_embed_dim)
                    x = fuse(x, other_parts_emb)

                # send to transformer block
                x = block(x)

                # update parts_token_embeddings
                parts_token_embeddings[j] = x
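
FuseModuleV1_3 itself is not quoted here. Judging only from its call site above (the current part's features x plus a list of the other parts' normalized features) and the "MLP, global info weight" comment, a plausible sketch might look like the following; treat it as an assumption, not the repo's actual implementation:

import torch
import torch.nn as nn

class FuseSketch(nn.Module):
    # Assumed behavior: concatenate the other parts' features, project them with an
    # MLP down to this part's embed_dim, and add the result with weight alpha.
    def __init__(self, in_features, out_features, num_mlp_layers=2, drop_out_rate=0.1, alpha=1.0):
        super().__init__()
        layers, dim = [], in_features
        for _ in range(num_mlp_layers - 1):
            layers += [nn.Linear(dim, out_features), nn.ReLU(), nn.Dropout(drop_out_rate)]
            dim = out_features
        layers.append(nn.Linear(dim, out_features))
        self.mlp = nn.Sequential(*layers)
        self.alpha = alpha

    def forward(self, x, other_parts_emb):
        # x: (B, T, embed_dim); other_parts_emb: list of 5 tensors, each (B, T, embed_dim_j)
        others = torch.cat(other_parts_emb, dim=2)  # (B, T, other_parts_embed_dim)
        return x + self.alpha * self.mlp(others)

fuse = FuseSketch(in_features=5 * 256, out_features=256)
x = torch.randn(2, 51, 256)                          # one part's token embeddings
others = [torch.randn(2, 51, 256) for _ in range(5)]
print(fuse(x, others).shape)                         # torch.Size([2, 51, 256])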


Selected Experimental Results

See the paper for the quantitative comparisons on the HumanML3D and KIT-ML benchmarks.
