
An In-Depth Analysis of PyTorch-Based Recommendation System Models

Based on the given file information, we can summarize the following IT knowledge points.

### Recommendation Systems and PyTorch

Recommendation systems are one of the key technologies in modern information technology for improving the user experience. By analyzing a user's historical behavior, preferences, and contextual information, a recommender predicts what the user is likely to be interested in and serves personalized recommendations accordingly. Recommendation systems play a central role in e-commerce, social media, video and music streaming, and many other domains. PyTorch is an open-source, Python-based machine learning library that is widely used in computer vision and natural language processing; in recent years it has also been applied to building recommendation system models.

### Recommendation System Models

The title mentions NGCF, i.e. Neural Graph Collaborative Filtering, proposed by Wang Xiang et al. at SIGIR 2019. NGCF is a recommendation model based on graph neural networks: it represents user-item interactions as a graph and propagates features over that graph to learn deeper embedding representations of users and items. This lets NGCF capture complex user-item interaction patterns, and it performs particularly well in implicit-feedback recommendation scenarios (a minimal sketch of one propagation layer is given after this summary).

Also mentioned is the autoencoder, an unsupervised learning algorithm that learns an effective representation (encoding) of its input data and can then reconstruct the input from that representation. In recommender systems, autoencoders are often used to reduce the dimensionality of, and denoise, user and item features, improving the generalization ability of the recommendation model (see the autoencoder sketch below).

### Datasets

The file description mentions the MovieLens datasets, specifically MovieLens-1M and MovieLens-100K. These are standard benchmark datasets in recommender systems research and contain users' ratings of movies: MovieLens-1M holds roughly one million ratings, while MovieLens-100K is correspondingly smaller at roughly one hundred thousand. Both are widely used to test and compare different recommendation algorithms and models (a loading example follows below).

### The Python Language

The "Python" tag refers to one of the most popular and widely used programming languages in data science, machine learning, and artificial intelligence. Python is favored by developers for its concise, readable syntax, its strong standard library and third-party ecosystem, and its cross-platform nature. For developing and experimenting with recommendation models, Python offers a wealth of libraries and frameworks, such as PyTorch, TensorFlow, and Scikit-learn, which are a great convenience for researchers and developers.

### PyTorch Model Development

In the listed file "Recommend_System_Pytorch-master", "master" usually denotes the project's main branch, containing the latest development work and code. This suggests a project that builds recommendation models on the PyTorch framework, is likely still being updated and maintained, and probably includes the modules a recommender system needs: data processing, model training, evaluation, and inference. In practice, that involves steps such as model definition, data loading, the training loop, saving and loading the model, and computing evaluation metrics (a minimal training-loop sketch closes this summary).

In short, the file covers the concept of recommendation systems, the use of PyTorch to build recommendation models, common models such as NGCF and autoencoders, the standard benchmark datasets, and Python and PyTorch as the usual tools for developing recommender systems. These are fundamental, important topics in data science and machine learning, and useful guidance for anyone aiming to go deeper in the field.
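As a purely illustrative sketch of the NGCF propagation rule described above (not the authors' reference implementation; the names `NGCFLayer`, `norm_adj`, and `dim` are our own), one propagation layer might look like this in PyTorch:

```python
import torch
import torch.nn as nn

class NGCFLayer(nn.Module):
    """One NGCF-style propagation layer (illustrative sketch only)."""
    def __init__(self, dim):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)   # transform for the aggregated-neighbor term
        self.w2 = nn.Linear(dim, dim)   # transform for the element-wise interaction term
        self.act = nn.LeakyReLU(0.2)

    def forward(self, emb, norm_adj):
        # emb:      (N, dim) stacked user + item embeddings
        # norm_adj: (N, N) symmetrically normalized sparse adjacency matrix
        side = torch.sparse.mm(norm_adj, emb)   # aggregate neighbor embeddings
        sum_term = self.w1(side + emb)          # (L + I) E W1
        bi_term = self.w2(side * emb)           # (L E) ⊙ E W2
        return self.act(sum_term + bi_term)
```

Stacking several such layers and concatenating their outputs yields the multi-depth embeddings that the NGCF paper scores with an inner product.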
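Similarly, a minimal AutoRec-style autoencoder over a user's rating vector might look like the following sketch (the class name and layer sizes are assumptions, not code from the project):

```python
import torch.nn as nn

class RatingAutoencoder(nn.Module):
    """AutoRec-style autoencoder over a rating vector (illustrative sketch)."""
    def __init__(self, n_items, hidden=128):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_items, hidden), nn.Sigmoid())
        self.decoder = nn.Linear(hidden, n_items)

    def forward(self, ratings):
        # ratings: (batch, n_items), with zeros for unobserved entries
        return self.decoder(self.encoder(ratings))
```

In training, the reconstruction loss is usually masked so that it is computed only over the observed entries of each rating vector.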
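For reference, the two MovieLens datasets can be loaded with pandas as follows, assuming the standard file layout of the official GroupLens archives (`u.data` is tab-separated; `ratings.dat` uses `::` as its separator):

```python
import pandas as pd

# MovieLens-100K: u.data is tab-separated with no header row
ratings_100k = pd.read_csv(
    "ml-100k/u.data", sep="\t",
    names=["user_id", "item_id", "rating", "timestamp"],
)

# MovieLens-1M: ratings.dat uses "::" as the separator
ratings_1m = pd.read_csv(
    "ml-1m/ratings.dat", sep="::", engine="python",
    names=["user_id", "item_id", "rating", "timestamp"],
)
```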
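Finally, here is a generic PyTorch training loop of the kind such a project typically contains. This is a hedged sketch: the batch structure `(user, item, rating)` and the model's call signature are assumptions, not taken from the repository:

```python
import torch

def train(model, loader, epochs=10, lr=1e-3, device="cpu", ckpt="model.pt"):
    """Generic training loop sketch: optimize, log the loss, save the weights."""
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        total = 0.0
        for users, items, ratings in loader:   # assumed (user, item, rating) batches
            users, items, ratings = users.to(device), items.to(device), ratings.to(device)
            opt.zero_grad()
            pred = model(users, items)         # model assumed to score (user, item) pairs
            loss = loss_fn(pred, ratings)
            loss.backward()
            opt.step()
            total += loss.item()
        print(f"epoch {epoch}: mean loss {total / len(loader):.4f}")
    torch.save(model.state_dict(), ckpt)       # save for later evaluation/inference
```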

Related recommendations


```python
import argparse
import os
import sys
import json

import numpy as np
import torch
from PIL import Image

sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap

# segment anything
# NOTE: SamPredictor is used below but was missing from the posted import line
from segment_anything import sam_model_registry, SamPredictor
import cv2
import matplotlib.pyplot as plt


def load_image(image_path):
    # load image
    image_pil = Image.open(image_path).convert("RGB")

    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image, _ = transform(image_pil, None)  # 3, h, w
    return image_pil, image


def load_model(model_config_path, model_checkpoint_path, bert_base_uncased_path, device):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    args.bert_base_uncased_path = bert_base_uncased_path
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    load_res = model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    print(load_res)
    _ = model.eval()
    return model


def get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device="cpu"):
    caption = caption.lower()
    caption = caption.strip()
    if not caption.endswith("."):
        caption = caption + "."
    model = model.to(device)
    image = image.to(device)
    with torch.no_grad():
        outputs = model(image[None], captions=[caption])
    logits = outputs["pred_logits"].cpu().sigmoid()[0]  # (nq, 256)
    boxes = outputs["pred_boxes"].cpu()[0]  # (nq, 4)

    # filter output
    logits_filt = logits.clone()
    boxes_filt = boxes.clone()
    filt_mask = logits_filt.max(dim=1)[0] > box_threshold
    logits_filt = logits_filt[filt_mask]  # num_filt, 256
    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4

    # get phrase
    tokenlizer = model.tokenizer
    tokenized = tokenlizer(caption)
    # build pred
    pred_phrases = []
    for logit, box in zip(logits_filt, boxes_filt):
        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
        if with_logits:
            pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
        else:
            pred_phrases.append(pred_phrase)

    return boxes_filt, pred_phrases


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax, label):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
    ax.text(x0, y0, label)


def save_mask_data(output_dir, mask_list, box_list, label_list):
    value = 0  # 0 for background

    mask_img = torch.zeros(mask_list.shape[-2:])
    for idx, mask in enumerate(mask_list):
        mask_img[mask.cpu().numpy()[0] == True] = value + idx + 1
    plt.figure(figsize=(10, 10))
    plt.imshow(mask_img.numpy())
    plt.axis('off')
    plt.savefig(os.path.join(output_dir, 'mask.jpg'), bbox_inches="tight", dpi=300, pad_inches=0.0)

    json_data = [{
        'value': value,
        'label': 'background'
    }]
    for label, box in zip(label_list, box_list):
        value += 1
        name, logit = label.split('(')
        logit = logit[:-1]  # the last is ')'
        json_data.append({
            'value': value,
            'label': name,
            'logit': float(logit),
            'box': box.numpy().tolist(),
        })
    with open(os.path.join(output_dir, 'mask.json'), 'w') as f:
        json.dump(json_data, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
    parser.add_argument("--config", type=str, default="./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py", help="path to config file")
    parser.add_argument(
        "--grounded_checkpoint", type=str, default="./groundingdino_swint_ogc.pth", help="path to checkpoint file"
    )
    parser.add_argument(
        "--sam_version", type=str, default="vit_h", required=False, help="SAM ViT version: vit_b / vit_l / vit_h"
    )
    parser.add_argument(
        "--sam_checkpoint", type=str, default="./sam_vit_h_4b8939.pth", help="path to sam checkpoint file"
    )
    parser.add_argument(
        "--sam_hq_checkpoint", type=str, default=None, help="path to sam-hq checkpoint file"
    )
    parser.add_argument(
        "--use_sam_hq", action="store_true", help="using sam-hq for prediction"
    )
    parser.add_argument("--input_image", type=str, required=True, help="path to image file")
    parser.add_argument("--text_prompt", type=str, required=True, help="text prompt")
    parser.add_argument(
        "--output_dir", "-o", type=str, default="./outputs", help="output directory"
    )
    parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
    parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
    parser.add_argument("--device", type=str, default="cpu", help="running on cpu only!, default=False")
    parser.add_argument("--bert_base_uncased_path", type=str, required=False, help="bert_base_uncased model path, default=False")
    args = parser.parse_args()

    # cfg
    config_file = args.config  # change the path of the model config file
    grounded_checkpoint = args.grounded_checkpoint  # change the path of the model
    sam_version = args.sam_version
    sam_checkpoint = args.sam_checkpoint
    sam_hq_checkpoint = args.sam_hq_checkpoint
    use_sam_hq = args.use_sam_hq
    image_path = args.input_image
    text_prompt = args.text_prompt
    output_dir = args.output_dir
    box_threshold = args.box_threshold
    text_threshold = args.text_threshold
    device = args.device
    bert_base_uncased_path = args.bert_base_uncased_path

    # make dir
    os.makedirs(output_dir, exist_ok=True)
    # load image
    image_pil, image = load_image(image_path)
    # load model
    model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)

    # visualize raw image
    image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

    # run grounding dino model
    boxes_filt, pred_phrases = get_grounding_output(
        model, image, text_prompt, box_threshold, text_threshold, device=device
    )

    # initialize SAM
    if use_sam_hq:
        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))
    else:
        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    size = image_pil.size
    H, W = size[1], size[0]
    for i in range(boxes_filt.size(0)):
        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])
        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
        boxes_filt[i][2:] += boxes_filt[i][:2]

    boxes_filt = boxes_filt.cpu()
    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(device),
        multimask_output=False,
    )

    # draw output image
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box, label in zip(boxes_filt, pred_phrases):
        show_box(box.numpy(), plt.gca(), label)

    plt.axis('off')
    plt.savefig(
        os.path.join(output_dir, "grounded_sam_output.jpg"),
        bbox_inches="tight", dpi=300, pad_inches=0.0
    )

    save_mask_data(output_dir, masks, boxes_filt, pred_phrases)
```

Running it reports the following error:

```text
C:\Users\29386\.conda\envs\grounded_sam\python.exe C:\Users\29386\segment-anything\Grounded-Segment-Anything\grounded_sam_demo.py --config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py --grounded_checkpoint weights/groundingdino_swint_ogc.pth --sam_checkpoint weights/sam_vit_h_4b8939.pth --input_image assets/demo1.jpg --output_dir outputs --text_prompt cat
C:\Users\29386\.conda\envs\grounded_sam\lib\site-packages\timm\models\layers\__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)
C:\Users\29386\.conda\envs\grounded_sam\lib\site-packages\torch\functional.py:513: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\TensorShape.cpp:3610.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
final text_encoder_type: bert-base-uncased
C:\Users\29386\.conda\envs\grounded_sam\lib\site-packages\huggingface_hub\file_download.py:143: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\Users\29386\.cache\huggingface\hub\models--bert-base-uncased. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations. To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
  warnings.warn(message)
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
C:\Users\29386\segment-anything\Grounded-Segment-Anything\grounded_sam_demo.py:47: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
Traceback (most recent call last):
  File "C:\Users\29386\segment-anything\Grounded-Segment-Anything\grounded_sam_demo.py", line 187, in <module>
    model = load_model(config_file, grounded_checkpoint, bert_base_uncased_path, device=device)
  File "C:\Users\29386\segment-anything\Grounded-Segment-Anything\grounded_sam_demo.py", line 47, in load_model
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
  File "C:\Users\29386\.conda\envs\grounded_sam\lib\site-packages\torch\serialization.py", line 1065, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "C:\Users\29386\.conda\envs\grounded_sam\lib\site-packages\torch\serialization.py", line 468, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "C:\Users\29386\.conda\envs\grounded_sam\lib\site-packages\torch\serialization.py", line 449, in __init__
    super().__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: 'weights/groundingdino_swint_ogc.pth'
```

How can this problem be solved on Windows?
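The last line of the traceback shows that `torch.load` could not open the relative path `weights/groundingdino_swint_ogc.pth`; relative paths are resolved against the current working directory, not the script's location. A minimal diagnostic sketch (the path literal is copied from the command above; nothing else here is from the original post):

```python
import os

ckpt = "weights/groundingdino_swint_ogc.pth"
print(os.getcwd())            # the directory the relative path is resolved against
print(os.path.exists(ckpt))   # False in the failing run, per the traceback

# One possible fix: resolve the path relative to the script's own directory
ckpt_abs = os.path.join(os.path.dirname(os.path.abspath(__file__)), ckpt)
```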
