目录
在计算机视觉领域,人物抠图是一个经典且极具价值的任务,广泛应用于影视后期制作、图像合成、视频会议背景替换等场景。传统的人工抠图方法耗时费力,而借助深度学习技术,我们可以实现高效、精准的自动人物抠图。本文将介绍基于 MobileSAM 和 YOLOv11n 的自动人物抠图系统 Python 实现,该系统结合了目标检测与分割的优势,能够实现对人物的快速、精细抠图。
一、技术背景
1.1移动设备适配的 SAM(MobileSAM)
Segment Anything Model(SAM)是 Meta 推出的一种通用分割模型,能够为给定提示(如点、框、文字描述等)的任意输入图像生成像素级的分割掩码。MobileSAM 是为移动设备优化的 SAM 版本,具有更小的模型尺寸和较低的计算资源消耗,同时在分割性能上仍保持较高水平,非常适合在资源受限的环境中进行部署和应用,为我们在移动平台实现高效、便捷的人物抠图提供了基础。
1.2 YOLOv11n 算法
YOLO(You Only Look Once)系列算法是一类经典的实时目标检测算法,以其快速、高效的检测能力而闻名。YOLOv11n 作为该系列的改进版本,在模型结构、训练策略等方面进行了优化,能够更加准确地检测出图像中的目标物体(包括人物)并给出其边界框位置。在本系统中,我们利用 YOLOv11n 对图像中的人物进行初步的粗略定位和分割,为后续的精细化抠图操作提供基础区域信息。
二、实现代码解析
2.1环境搭建与依赖库
- 代码运行环境基于 Python,主要依赖的库包括 OpenCV(用于图像处理)、PyTorch(作为深度学习框架)、PIL(Python Imaging Library,用于图像加载和基本操作)、ultralytics(提供了 YOLO 系列模型的实现和接口)以及 mobile_sam(集成了 MobileSAM 模型的定义、加载和预测功能)等。需确保这些库已正确安装并配置好相应的环境。
2.2 MobileSAM 模型加载
# 加载 MobileSAM
sam_checkpoint = "weights/mobile_sam.pt"
model_type = "vit_t"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()
predictor = SamPredictor(mobile_sam)
通过上述代码,加载预训练的 MobileSAM 模型,其中 sam_checkpoint
指向模型权重文件路径,model_type
指定模型类型。将模型转移到可用的设备(GPU 或 CPU)上,并设置为评估模式,以便进行推理。同时,创建 SamPredictor
实例,用于后续的图像分割预测操作。
2.3 图像预处理与尺寸调整函数
def resizeed_img(result,max_dim = 512):
# 调整图片尺寸,确保长和宽最多不超过1024
h, w = result.shape[:2]
if h > max_dim or w > max_dim:
# 计算缩放比例
scale = min(max_dim / h, max_dim / w)
new_h, new_w = int(h * scale), int(w * scale)
# 按比例缩放图片
result_resized = cv2.resize(result, (new_w, new_h))
return result_resized
else:
return result
定义 resizeed_img
函数,用于将图像尺寸调整至不超过指定的最大维度(默认为 512),以适应模型输入要求并减少计算量。在实际处理过程中,会根据原始图像的宽高比例进行等比例缩放,确保图像内容不失真。
2.4 主函数实现
- 图像加载与转换
# 1. 加载图片
im0 = cv2.imread(img_path)
h, w = im0.shape[:2]
im_rgb = cv2.cvtColor(im0, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(im_rgb)
使用 OpenCV 的 cv2.imread
函数加载待处理图像,将其从 BGR 颜色空间转换为 RGB 颜色空间,并创建一个对应的 PIL 图像对象,以便后续作为 MobileSAM 的输入。
- YOLO 人物检测与粗略分割
# 2. 用YOLO分割大致人物区域
yolo_model = YOLO(os.path.join(os.path.dirname(os.path.abspath(__file__)), "yolo11n-seg.pt"))
results = yolo_model(im0)
person_masks = []
for result in results[0]:
if hasattr(result, "boxes") and hasattr(result.boxes, "cls"):
if int(result.boxes.cls[0]) == 0: # 0 是 person
# 获取掩码
mask