ParCo: Part-Coordinating Text-to-Motion Synthesis
Qiran Zou, Shangyuan Yuan, Shian Du, Yu Wang, Chang Liu, Yi Xu, Jie Chen, and Xiangyang Ji
Paper: ParCo: Part-Coordinating Text-to-Motion Synthesis (arXiv:2403.18512)
Code: https://github.com/qrzou/ParCo
Motivation
Through a decompose-and-coordinate mechanism over body parts, the paper addresses the weaknesses of existing text-to-motion generation methods in semantic alignment, motion coordination, and computational efficiency.
Method & Innovations
The paper proposes the ParCo framework, which helps the motion generation model better understand and coordinate the motions of the individual body parts.
The framework consists of two stages: (1) part-aware motion VQ-VAEs; (2) transformers that generate the motion of each body part.
Part-Aware Motion Discretization Module
1. The body joints are divided into the 6 parts shown in the figure above: R.Arm, L.Arm, R.Leg, L.Leg, Backbone, and Root.
import torch

def whole2parts(motion, mode='t2m', window_size=None):
    if mode == 't2m':
        # The 263-dim motion is an augmented motion representation.
        # Split it into its components:
        #   root_data, ric_data, rot_data, local_vel, feet
        aug_data = torch.from_numpy(motion)  # (nframes, 263)
        joints_num = 22
        s = 0  # start
        e = 4  # end
        root_data = aug_data[:, s:e]  # (nframes, 4)
        s = e
        e = e + (joints_num - 1) * 3
        ric_data = aug_data[:, s:e]  # (nframes, (joints_num-1)*3); joint 0 is dropped
        s = e
        e = e + (joints_num - 1) * 6
        rot_data = aug_data[:, s:e]  # (nframes, (joints_num-1)*6)
        s = e
        e = e + joints_num * 3
        local_vel = aug_data[:, s:e]  # (nframes, joints_num*3)
        s = e
        e = e + 4
        feet = aug_data[:, s:e]  # (nframes, 4)

        # Joint indices of each part; the root (joint 0) is handled separately.
        # Note that joint 9 appears in the backbone and both arms (shared).
        R_L_idx = torch.tensor([2, 5, 8, 11], dtype=torch.int64)        # right leg
        L_L_idx = torch.tensor([1, 4, 7, 10], dtype=torch.int64)        # left leg
        B_idx = torch.tensor([3, 6, 9, 12, 15], dtype=torch.int64)      # backbone
        R_A_idx = torch.tensor([9, 14, 17, 19, 21], dtype=torch.int64)  # right arm
        L_A_idx = torch.tensor([9, 13, 16, 18, 20], dtype=torch.int64)  # left arm

        nframes = root_data.shape[0]
        if window_size is not None:
            assert nframes == window_size

        ric_data = ric_data.reshape(nframes, -1, 3)    # (nframes, joints_num-1, 3)
        rot_data = rot_data.reshape(nframes, -1, 6)    # (nframes, joints_num-1, 6)
        local_vel = local_vel.reshape(nframes, -1, 3)  # (nframes, joints_num, 3)
        root_data = torch.cat([root_data, local_vel[:, 0, :]], dim=1)  # (nframes, 4+3=7)

        # Per-joint features: ric (3) + rot (6) + local_vel (3) = 12 dims.
        # ric_data/rot_data drop joint 0, hence the idx - 1 offset.
        R_L = torch.cat([ric_data[:, R_L_idx - 1, :], rot_data[:, R_L_idx - 1, :], local_vel[:, R_L_idx, :]], dim=2)  # (nframes, 4, 12)
        L_L = torch.cat([ric_data[:, L_L_idx - 1, :], rot_data[:, L_L_idx - 1, :], local_vel[:, L_L_idx, :]], dim=2)  # (nframes, 4, 12)
        B = torch.cat([ric_data[:, B_idx - 1, :], rot_data[:, B_idx - 1, :], local_vel[:, B_idx, :]], dim=2)          # (nframes, 5, 12)
        R_A = torch.cat([ric_data[:, R_A_idx - 1, :], rot_data[:, R_A_idx - 1, :], local_vel[:, R_A_idx, :]], dim=2)  # (nframes, 5, 12)
        L_A = torch.cat([ric_data[:, L_A_idx - 1, :], rot_data[:, L_A_idx - 1, :], local_vel[:, L_A_idx, :]], dim=2)  # (nframes, 5, 12)

        Root = root_data                                                   # (nframes, 4+3=7)
        R_Leg = torch.cat([R_L.reshape(nframes, -1), feet[:, 2:]], dim=1)  # (nframes, 4*12+2=50)
        L_Leg = torch.cat([L_L.reshape(nframes, -1), feet[:, :2]], dim=1)  # (nframes, 4*12+2=50)
        Backbone = B.reshape(nframes, -1)                                  # (nframes, 5*12=60)
        R_Arm = R_A.reshape(nframes, -1)                                   # (nframes, 5*12=60)
        L_Arm = L_A.reshape(nframes, -1)                                   # (nframes, 5*12=60)

    elif mode == 'kit':
        # The 251-dim motion is the same augmented representation with 21 joints.
        aug_data = torch.from_numpy(motion)  # (nframes, 251)
        joints_num = 21
        s = 0
        e = 4
        root_data = aug_data[:, s:e]  # (nframes, 4)
        s = e
        e = e + (joints_num - 1) * 3
        ric_data = aug_data[:, s:e]  # (nframes, (joints_num-1)*3); joint 0 is dropped
        s = e
        e = e + (joints_num - 1) * 6
        rot_data = aug_data[:, s:e]  # (nframes, (joints_num-1)*6)
        s = e
        e = e + joints_num * 3
        local_vel = aug_data[:, s:e]  # (nframes, joints_num*3)
        s = e
        e = e + 4
        feet = aug_data[:, s:e]  # (nframes, 4)

        # Joint indices of each part; the root (joint 0) is handled separately.
        R_L_idx = torch.tensor([11, 12, 13, 14, 15], dtype=torch.int64)  # right leg
        L_L_idx = torch.tensor([16, 17, 18, 19, 20], dtype=torch.int64)  # left leg
        B_idx = torch.tensor([1, 2, 3, 4], dtype=torch.int64)            # backbone
        R_A_idx = torch.tensor([3, 5, 6, 7], dtype=torch.int64)          # right arm
        L_A_idx = torch.tensor([3, 8, 9, 10], dtype=torch.int64)         # left arm

        nframes = root_data.shape[0]
        if window_size is not None:
            assert nframes == window_size

        ric_data = ric_data.reshape(nframes, -1, 3)    # (nframes, joints_num-1, 3)
        rot_data = rot_data.reshape(nframes, -1, 6)    # (nframes, joints_num-1, 6)
        local_vel = local_vel.reshape(nframes, -1, 3)  # (nframes, joints_num, 3)
        root_data = torch.cat([root_data, local_vel[:, 0, :]], dim=1)  # (nframes, 4+3=7)

        R_L = torch.cat([ric_data[:, R_L_idx - 1, :], rot_data[:, R_L_idx - 1, :], local_vel[:, R_L_idx, :]], dim=2)  # (nframes, 5, 12)
        L_L = torch.cat([ric_data[:, L_L_idx - 1, :], rot_data[:, L_L_idx - 1, :], local_vel[:, L_L_idx, :]], dim=2)  # (nframes, 5, 12)
        B = torch.cat([ric_data[:, B_idx - 1, :], rot_data[:, B_idx - 1, :], local_vel[:, B_idx, :]], dim=2)          # (nframes, 4, 12)
        R_A = torch.cat([ric_data[:, R_A_idx - 1, :], rot_data[:, R_A_idx - 1, :], local_vel[:, R_A_idx, :]], dim=2)  # (nframes, 4, 12)
        L_A = torch.cat([ric_data[:, L_A_idx - 1, :], rot_data[:, L_A_idx - 1, :], local_vel[:, L_A_idx, :]], dim=2)  # (nframes, 4, 12)

        Root = root_data                                                   # (nframes, 4+3=7)
        R_Leg = torch.cat([R_L.reshape(nframes, -1), feet[:, 2:]], dim=1)  # (nframes, 5*12+2=62)
        L_Leg = torch.cat([L_L.reshape(nframes, -1), feet[:, :2]], dim=1)  # (nframes, 5*12+2=62)
        Backbone = B.reshape(nframes, -1)                                  # (nframes, 4*12=48)
        R_Arm = R_A.reshape(nframes, -1)                                   # (nframes, 4*12=48)
        L_Arm = L_A.reshape(nframes, -1)                                   # (nframes, 4*12=48)

    else:
        raise ValueError(f'Unknown mode: {mode}')

    return [Root, R_Leg, L_Leg, Backbone, R_Arm, L_Arm]
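A quick sanity check of the resulting per-part dimensions (dummy random input, assuming the 263-dim HumanML3D representation; not from the repo):

import numpy as np

motion = np.random.randn(64, 263).astype(np.float32)  # 64 frames of dummy motion
parts = whole2parts(motion, mode='t2m')
for name, p in zip(['Root', 'R_Leg', 'L_Leg', 'Backbone', 'R_Arm', 'L_Arm'], parts):
    print(name, tuple(p.shape))
# Root (64, 7), R_Leg (64, 50), L_Leg (64, 50),
# Backbone (64, 60), R_Arm (64, 60), L_Arm (64, 60)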
2. Discretize the continuous motion into motion codes / tokens.
MoMask performs a similar operation; for details see:
[Paper Reading Notes + Thoughts + Summary] MoMask: Generative Masked Modeling of 3D Human Motions (CSDN blog)
An encoder extracts features; each feature is matched against a learnable codebook to find its closest entry, and the index (position) of that entry is the motion code / token.
The motion codes / tokens can also be decoded back into motion by the decoder.
Here, 6 separate VQ-VAEs discretize the motions of the different body parts, producing body-part motion codes / tokens.
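As a rough sketch (simplified, not ParCo's exact code), the nearest-neighbor lookup at the core of the VQ step might look like this; ParCo trains one such encoder/codebook/decoder triple per body part:

import torch

def quantize(z, codebook):
    """z: (T, d) encoder features; codebook: (K, d) learnable entries."""
    dist = torch.cdist(z, codebook)  # (T, K) pairwise L2 distances
    tokens = dist.argmin(dim=1)      # (T,) motion codes / tokens
    z_q = codebook[tokens]           # (T, d) quantized features
    # Straight-through estimator so gradients flow back to the encoder.
    z_q = z + (z_q - z).detach()
    return tokens, z_q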
Text-Driven Part-Coordinated Transformer
Here, 6 small transformers generate the motions of the 6 body parts autoregressively (similar to T2M-GPT): conditioned on the CLIP feature of the input text and the previously generated body-part motion tokens, each transformer predicts the current body-part motion token.
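A minimal greedy-decoding sketch for a single part, with a hypothetical transformer(text_feat, tokens) signature (the real model runs all six part transformers in lockstep so information can be fused between them at every layer, as described next):

import torch

@torch.no_grad()
def sample_part_tokens(transformer, text_feat, max_len=50, end_id=512):
    """Greedy autoregressive decoding for ONE part (hypothetical interface)."""
    tokens = torch.empty(1, 0, dtype=torch.long)  # start from the text condition only
    for _ in range(max_len):
        logits = transformer(text_feat, tokens)   # (1, t+1, vocab_size)
        next_tok = logits[:, -1].argmax(dim=-1)   # greedy; could sample from softmax
        if next_tok.item() == end_id:             # stop at the end token
            break
        tokens = torch.cat([tokens, next_tok[:, None]], dim=1)
    return tokens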
But to keep the body parts coordinated with each other, the paper additionally introduces a part coordination module.
In the implementation, this module is inserted before every small-transformer layer (except the first).
Concretely, each part's token features are first normalized with a LayerNorm; an MLP then extracts features from the other (≠ i) parts' normalized tokens, and the result is added to the token features of the i-th body-part transformer (see the code below).
# Inside __init__ (executed once per part `name`): build the transformer
# blocks and the fuse modules for that part.
for i in range(num_layers):
    # Transformer block for layer i.
    block = Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate)
    setattr(self, f'{name}_block_{i}', block)

    # Fuse modules for every layer except the first transformer layer.
    if use_fuse and i != 0:
        # LayerNorm applied before features are sent into the fuse modules.
        ln = nn.LayerNorm(embed_dim)
        setattr(self, f'{name}_ln_{i}', ln)

        # Fuse module: an MLP over the other parts' features, weighted by alpha.
        # TODO: set in the config file
        if self.fuse_ver == 'V1_3':
            fuse = FuseModuleV1_3(in_features=other_parts_embed_dim, out_features=embed_dim,
                                  num_mlp_layers=num_mlp_layers, drop_out_rate=drop_out_rate, alpha=alpha)
        else:
            raise NotImplementedError()
        setattr(self, f'{name}_fuse_{i}', fuse)
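FuseModuleV1_3 itself is not shown in the snippet above. Based on the description (an MLP over the concatenated features of the other parts, scaled by a weight alpha and added to the current part's features), a hypothetical minimal version might look like this; it is a sketch, not the repo's exact code:

import torch
from torch import nn

class FuseModule(nn.Module):
    """Sketch: inject the other parts' information into one part's tokens."""
    def __init__(self, in_features, out_features, num_mlp_layers=2,
                 drop_out_rate=0.1, alpha=1.0):
        super().__init__()
        self.alpha = alpha  # weight of the fused cross-part information
        layers, dim = [], in_features
        for _ in range(num_mlp_layers - 1):
            layers += [nn.Linear(dim, out_features), nn.ReLU(), nn.Dropout(drop_out_rate)]
            dim = out_features
        layers.append(nn.Linear(dim, out_features))
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, other_parts_emb):
        # x: (B, T, embed_dim); other_parts_emb: list of 5 tensors (B, T, embed_dim)
        others = torch.cat(other_parts_emb, dim=2)  # (B, T, in_features)
        return x + self.alpha * self.mlp(others)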
# Inside forward: run all parts through layer i, fusing information between parts.
for i in range(self.num_layers):
    # Snapshot the current parts_token_embeddings into a separate list,
    # because parts_token_embeddings will be updated in place below.
    if self.use_fuse and i != 0:
        # Apply LayerNorm before feeding features into the fuse modules.
        # The LayerNorm cannot live inside the fuse module, because each part's
        # normalized output is sent to several different fuse modules, so it
        # must be applied once before dispatch. Appending a LayerNorm to the end
        # of the transformer block would require modifying a lot of code,
        # so we avoid that.
        no_modified_input = []
        for j, name in enumerate(self.parts_name):
            ln = getattr(self, f'{name}_ln_{i}')
            no_modified_input.append(ln(parts_token_embeddings[j]))
    else:
        no_modified_input = [elem for elem in parts_token_embeddings]  # unmodified parts token embeddings

    for j, name in enumerate(self.parts_name):
        # Fetch this part's block and input.
        block = getattr(self, f'{name}_block_{i}')
        x = no_modified_input[j]

        # Fuse information from the other parts (skipped at the 0-th layer).
        if self.use_fuse and i != 0:
            fuse = getattr(self, f'{name}_fuse_{i}')
            other_parts_emb = [no_modified_input[count]
                               for count in range(len(self.parts_name)) if count != j]
            x = fuse(x, other_parts_emb)

        # Send through this part's transformer block.
        x = block(x)

        # Update parts_token_embeddings.
        parts_token_embeddings[j] = x
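So at every layer except the first, each part's token embeddings are refreshed with LayerNorm-ed information from the five other parts before entering that part's own transformer block; this repeated cross-part exchange is what enforces coordination between body parts during autoregressive generation.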