How several recommender-system models implement the InfoNCE contrastive loss

1. SGL implementation (Self-supervised Graph Learning for Recommendation)
SGL augments data by perturbing the graph structure, and builds augmented views for every node. The authors treat the views augmented from the same node as positive pairs, $\left\{ \left( z_u', z_u'' \right) \mid u \in U \right\}$, and the views produced from any two different nodes as negative pairs, $\left\{ \left( z_u', z_v'' \right) \mid u, v \in U, u \ne v \right\}$. Following the SimCLR paper, the authors use the InfoNCE loss.
The auxiliary supervision of positive pairs encourages the consistency between different views of the same node for prediction, while the supervision of negative pairs enforces the divergence among different nodes.
$$L_{ssl}^{user} = \sum_{u \in U} -\log \frac{\exp\left( s\left( z_u', z_u'' \right) / \tau \right)}{\sum_{v \in U} \exp\left( s\left( z_u', z_v'' \right) / \tau \right)}$$
Here $s(\cdot)$ measures the similarity between two vectors (the authors use cosine similarity), and $\tau$ is the temperature in the softmax. The item-side contrastive loss is obtained in the same way, so the self-supervised task can be written as $L_{ssl} = L_{ssl}^{user} + L_{ssl}^{item}$.
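To make the formula concrete, here is a minimal sketch that evaluates the user-side loss literally, one user at a time, in plain Python. The helper names `cosine` and `info_nce_user` and the toy vectors are made up for illustration; this is not SGL's own code.

```python
import math

def cosine(a, b):
    """Cosine similarity s(a, b) between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def info_nce_user(view1, view2, tau):
    """Sum over u of -log(exp(s(z'_u, z''_u)/tau) / sum_v exp(s(z'_u, z''_v)/tau))."""
    loss = 0.0
    for u, z1_u in enumerate(view1):
        # Numerator: the positive pair, i.e. the same user in both views.
        pos = math.exp(cosine(z1_u, view2[u]) / tau)
        # Denominator: this user's first view against every user's second view.
        total = sum(math.exp(cosine(z1_u, z2_v) / tau) for z2_v in view2)
        loss += -math.log(pos / total)
    return loss

# Toy data (made up): the two views agree per user and are orthogonal across users.
view1 = [[1.0, 0.0], [0.0, 1.0]]
view2 = [[2.0, 0.0], [0.0, 3.0]]
loss = info_nce_user(view1, view2, tau=0.5)
print(loss)
```

With these vectors every positive pair has similarity 1 and every negative pair has similarity 0, so each user contributes log(1 + e^-2) and the total is twice that.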

# LightGCN forward pass (graph convolution part)
    def forward(self, sub_graph1, sub_graph2, users, items, neg_items):
        user_embeddings, item_embeddings = self._forward_gcn(self.norm_adj)
        user_embeddings1, item_embeddings1 = self._forward_gcn(sub_graph1)
        user_embeddings2, item_embeddings2 = self._forward_gcn(sub_graph2)

        # Normalize embeddings learnt from sub-graph to construct SSL loss
        user_embeddings1 = F.normalize(user_embeddings1, dim=1)
        item_embeddings1 = F.normalize(item_embeddings1, dim=1)
        user_embeddings2 = F.normalize(user_embeddings2, dim=1)
        item_embeddings2 = F.normalize(item_embeddings2, dim=1)
        # The sub-graph representations are L2-normalized first


        user_embs = F.embedding(users, user_embeddings)
        item_embs = F.embedding(items, item_embeddings)
        neg_item_embs = F.embedding(neg_items, item_embeddings)
        user_embs1 = F.embedding(users, user_embeddings1)
        item_embs1 = F.embedding(items, item_embeddings1)
        user_embs2 = F.embedding(users, user_embeddings2)
        item_embs2 = F.embedding(items, item_embeddings2)
        # Look up the embeddings needed for this batch


        # Next, compute the logits for each loss
        sup_pos_ratings = inner_product(user_embs, item_embs)       # [batch_size]
        sup_neg_ratings = inner_product(user_embs, neg_item_embs)   # [batch_size]
        sup_logits = sup_pos_ratings - sup_neg_ratings              # [batch_size]
        # The BPR logits come first: BPR is the model's main loss and drives its optimization

        # Next comes the contrastive loss, computed separately for the user side and the item side.
        tot_ratings_user = torch.matmul(user_embs1,torch.transpose(user_embeddings2, 0, 1)) # [batch_size,num_users]
        pos_ratings_user = inner_product(user_embs1, user_embs2)    # [batch_size]


        pos_ratings_item = inner_product(item_embs1, item_embs2)    # [batch_size]
        tot_ratings_item = torch.matmul(item_embs1, torch.transpose(item_embeddings2, 0, 1))  # [batch_size, num_items]

        ssl_logits_user = tot_ratings_user - pos_ratings_user[:, None]                  # [batch_size, num_users]
        ssl_logits_item = tot_ratings_item - pos_ratings_item[:, None]                  # [batch_size, num_items]
        return sup_logits, ssl_logits_user, ssl_logits_item


    sup_logits, ssl_logits_user, ssl_logits_item = self.lightgcn(
        sub_graph1, sub_graph2, bat_users, bat_pos_items, bat_neg_items
    )
    # InfoNCE loss: logsumexp((tot - pos) / tau) equals -log(exp(pos / tau) / sum_v exp(tot_v / tau))
    clogits_user = torch.logsumexp(ssl_logits_user / self.ssl_temp, dim=1)
    clogits_item = torch.logsumexp(ssl_logits_item / self.ssl_temp, dim=1)
    infonce_loss = torch.sum(clogits_user + clogits_item)
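The logsumexp form in the training code can look unrelated to the InfoNCE formula. The sketch below checks numerically that it matches the textbook form on random data. The tensor names, sizes, temperature, and user ids are hypothetical stand-ins rather than values from SGL.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_users, dim, tau = 6, 4, 0.2

# Hypothetical stand-ins for the normalized user embeddings of the two views.
emb1 = F.normalize(torch.randn(num_users, dim), dim=1)
emb2 = F.normalize(torch.randn(num_users, dim), dim=1)
users = torch.tensor([0, 2, 5])  # a batch of user ids

batch1 = emb1[users]
batch2 = emb2[users]
pos = (batch1 * batch2).sum(dim=1)   # [batch_size]
tot = batch1 @ emb2.t()              # [batch_size, num_users]

# Form used in the training code: logsumexp over the shifted logits.
loss_logsumexp = torch.logsumexp((tot - pos[:, None]) / tau, dim=1).sum()

# Textbook form: -log(exp(pos / tau) / sum_v exp(tot_v / tau)).
loss_direct = -torch.log(torch.exp(pos / tau) / torch.exp(tot / tau).sum(dim=1)).sum()

print(loss_logsumexp.item(), loss_direct.item())
```

The two values agree up to floating-point error. Subtracting the positive score inside the logsumexp is what moves the numerator of the fraction into the exponent.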

2. SimGCL implementation

    def InfoNCE(self,view1, view2, temperature):
        view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
        pos_score = (view1 * view2).sum(dim=-1)
        pos_score = torch.exp(pos_score / temperature)
        ttl_score = torch.matmul(view1, view2.transpose(0, 1))
        ttl_score = torch.exp(ttl_score / temperature).sum(dim=1)
        cl_loss = -torch.log(pos_score / ttl_score)
        return torch.mean(cl_loss)
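One way to read the function above is as a cross-entropy over the similarity matrix, where the "correct class" for row i is column i. The sketch below compares the two on random inputs. `info_nce` restates the method above as a free function, and `info_nce_ce` is a hypothetical alternative written for this comparison, not part of SimGCL.

```python
import torch
import torch.nn.functional as F

def info_nce(view1, view2, temperature):
    # Same computation as the InfoNCE method above, written as a free function.
    view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    pos_score = torch.exp((view1 * view2).sum(dim=-1) / temperature)
    ttl_score = torch.exp(view1 @ view2.t() / temperature).sum(dim=1)
    return torch.mean(-torch.log(pos_score / ttl_score))

def info_nce_ce(view1, view2, temperature):
    # Cross-entropy over the similarity matrix; the target for row i is column i.
    view1, view2 = F.normalize(view1, dim=1), F.normalize(view2, dim=1)
    logits = view1 @ view2.t() / temperature
    labels = torch.arange(view1.size(0))
    return F.cross_entropy(logits, labels)

torch.manual_seed(0)
v1, v2 = torch.randn(8, 16), torch.randn(8, 16)  # made-up batch of 8 embeddings
a = info_nce(v1, v2, 0.2)
b = info_nce_ce(v1, v2, 0.2)
print(a.item(), b.item())
```

Both functions take negatives only from the rows of `view2` that are passed in, whereas the SGL code above scores each batch user against the full embedding table. The cross-entropy version also avoids the explicit `exp`, which may be more stable numerically at small temperatures.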