# DistributeMF_Part.py
import time
import copy
import numpy as np
from load_data import ratings_dict, item_id_list, user_id_list
from shared_parameter import *
"""
Part & Full
full: 用户上传所有项目的梯度,如果一个用户没有对一个item评价,梯度为零
part: 用户只上传已评项目的梯度
part会泄露已评项目的信息但是会有更高的计算效率
full不会泄露信息但是需要更多的计算时间
"""
def user_update(single_user_vector, user_rating_list, item_vector):
    gradient = {}
    for item_id, rate, _ in user_rating_list:
        error = rate - np.dot(single_user_vector, item_vector[item_id])
        # local step on this user's own factor
        single_user_vector = single_user_vector - lr * (-2 * error * item_vector[item_id] + 2 * reg_u * single_user_vector)
        gradient[item_id] = error * single_user_vector
    # "part" scheme: gradient contains only the items this user has rated; all other items are treated as zero
    return single_user_vector, gradient
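

# A minimal sketch of the "full" variant described in the module docstring, shown here
# only for contrast (the helper name user_update_full_sketch is made up; this is not
# the project's actual full implementation): the upload is padded with zero gradients
# for every unrated item, so the server cannot tell which items this user has rated.
def user_update_full_sketch(single_user_vector, user_rating_list, item_vector):
    # start with a zero gradient for every item in the catalogue
    gradient = {item_id: np.zeros_like(single_user_vector) for item_id in range(len(item_vector))}
    for item_id, rate, _ in user_rating_list:
        error = rate - np.dot(single_user_vector, item_vector[item_id])
        single_user_vector = single_user_vector - lr * (-2 * error * item_vector[item_id] + 2 * reg_u * single_user_vector)
        gradient[item_id] = error * single_user_vector
    return single_user_vector, gradient
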
def mse():
    # training MSE over the observed ratings only
    loss = []
    for i in range(len(user_id_list)):
        for r in range(len(ratings_dict[user_id_list[i]])):
            item_id, rate, _ = ratings_dict[user_id_list[i]][r]
            error = (rate - np.dot(user_vector[i], item_vector[item_id])) ** 2
            loss.append(error)
    return np.mean(loss)
if __name__ == '__main__':

    user_vector = np.random.normal(size=[len(user_id_list), hidden_dim])
    item_vector = np.random.normal(size=[len(item_id_list), hidden_dim])

    '''
    Debug snippet (disabled): dump item_vector element by element and check its shape.
    i = 0
    for vector in item_vector:
        for e in vector:
            print(e)
        i = i + 1
        print('===============')
    print(i)
    print('================')
    print(len(item_vector[0]))  # how many elements (columns) the first row has
    '''
    start_time = time.time()
    for iteration in range(max_iteration):
        print('###################')
        t = time.time()

        # step 1: each user updates its own factor locally and prepares gradients
        # for the items it has rated
        gradient_from_user = []
        for i in range(len(user_id_list)):
            user_vector[i], gradient = user_update(user_vector[i], ratings_dict[user_id_list[i]], item_vector)
            gradient_from_user.append(gradient)

        # step 2: the server applies the uploaded gradients to the item factors
        tmp_item_vector = copy.deepcopy(item_vector)
        for g in gradient_from_user:
            for item_id in g:
                '''
                full:
                item_vector = item_vector - lr * (-2 * g + 2 * reg_v * item_vector)
                (the whole item matrix is updated at once; see the sketch after the training loop)
                part:
                only the items this user has rated are updated, not the whole matrix
                '''
                item_vector[item_id] = item_vector[item_id] - lr * (-2 * g[item_id] + 2 * reg_v * item_vector[item_id])

        # stop when the item factors barely move between iterations
        if np.mean(np.abs(item_vector - tmp_item_vector)) < 1e-4:
            print('Converged')
            break

        print('Time', time.time() - t, 's')
        print('loss', mse())

    print('Converged using', time.time() - start_time)
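
    # For reference, a sketch of the server-side step in the "full" scheme (illustration
    # only; the helper name and the dense upload format are assumptions, not this
    # project's actual full implementation): each user would upload a dense array of
    # shape [len(item_id_list), hidden_dim] with zero rows for unrated items, so the
    # whole item matrix is updated at once and rated items are not revealed.
    def server_update_full_sketch(item_matrix, dense_gradients):
        for g_full in dense_gradients:
            item_matrix = item_matrix - lr * (-2 * g_full + 2 * reg_v * item_matrix)
        return item_matrix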