|
55 | 55 | }, |
56 | 56 | { |
57 | 57 | "cell_type": "code", |
58 | | - "execution_count": 7, |
| 58 | + "execution_count": null, |
59 | 59 | "id": "ed96287d-8fd1-454b-9bfd-9f0eaff6c56e", |
60 | 60 | "metadata": {}, |
61 | 61 | "outputs": [], |
62 | 62 | "source": [ |
63 | | - "# 定义抽帧策略枚举类\n", |
64 | 63 | "class Strategy(Enum):\n", |
65 | | - " # 固定间隔抽帧策略,例如每1秒抽一帧\n", |
| 64 | + " # sampling strategies\n", |
| 65 | + " # constant interval: sampling at a constant interval, fps sampling\n", |
66 | 66 | " CONSTANT_INTERVAL = \"constant_interval\"\n", |
67 | | - " # 均匀间隔抽帧策略,根据设定的最大帧数均匀从视频全长度抽取\n", |
| 67 | + " # even interval: sampling at an even interval, uniform sampling\n", |
68 | 68 | " EVEN_INTERVAL = \"even_interval\"" |
69 | 69 | ] |
70 | 70 | }, |
|
78 | 78 | }, |
79 | 79 | { |
80 | 80 | "cell_type": "code", |
81 | | - "execution_count": 18, |
| 81 | + "execution_count": null, |
82 | 82 | "id": "974013e8-6436-403f-a5a3-fa245f322939", |
83 | 83 | "metadata": {}, |
84 | 84 | "outputs": [], |
|
92 | 92 | " use_timestamp: bool = True,\n", |
93 | 93 | " keyframe_naming_template: str = \"frame_{:04d}.jpg\",\n", |
94 | 94 | ") -> list[str]:\n", |
95 | | - " \"\"\"将视频按照指定策略抽帧\n", |
96 | | - " 参数:\n", |
97 | | - " video_file_path (str): 视频文件路径\n", |
98 | | - " output_dir (str): 输出目录\n", |
99 | | - " extraction_strategy (Optional[Strategy], optional): 抽帧策略。\n", |
100 | | - " 固定间隔 比如 1s 抽一帧 或\n", |
101 | | - " 均匀间隔 根据设定的最大帧数 均匀从视频全长度均匀抽取\n", |
102 | | - " 默认固定间隔 1s 抽一帧\n", |
103 | | - " interval_in_seconds (Optional[float], optional): 固定间隔抽帧的间隔时间. 默认 1s 抽一帧\n", |
104 | | - " max_frames (Optional[int], optional): 最大抽帧帧数. 默认 10 帧\n", |
105 | | - " use_timestamp (bool): 是否输出视频时间戳, 默认True\n", |
106 | | - " keyframe_naming_template (_type_, optional): 抽帧图片命名模板\n", |
107 | | - " 返回:\n", |
108 | | - " list[str]: 抽帧图片路径列表\n", |
109 | | - " list[float]: 视频采样帧对应的时间戳\n", |
| 95 | + " \"\"\"Sample a video and extract keyframes using the selected strategy.\n", |
| 96 | + " Args:\n", |
| 97 | + " video_file_path (str): video path\n", |
| 98 | + " output_dir (str): output directory for sampled keyframes\n", |
| 99 | + " extraction_strategy (Optional[Strategy], optional): extraction strategy. Defaults to Strategy.EVEN_INTERVAL.\n", |
| 100 | + " interval_in_seconds (Optional[float], optional): sampling interval in seconds; only used by Strategy.CONSTANT_INTERVAL\n", |
| 101 | + " max_frames (Optional[int], optional): maximum number of sampled frames. Defaults to 10.\n", |
| 102 | + " use_timestamp (bool): whether to output video timestamps. Defaults to True.\n", |
| 103 | + " keyframe_naming_template (str, optional): keyframe naming template. Defaults to \"frame_{:04d}.jpg\".\n", |
| 104 | + " Returns:\n", |
| 105 | + " list[str]: sampled keyframe paths\n", |
| 106 | + " Optional[list[float]]: timestamps (in seconds) of sampled keyframes, or None when use_timestamp is False\n", |
110 | 107 | " \"\"\"\n", |
111 | | - " # 检查输出目录是否存在,如果不存在则创建\n", |
112 | 108 | " if not os.path.exists(output_dir):\n", |
113 | 109 | " os.makedirs(output_dir)\n", |
114 | | - " # 使用OpenCV打开视频文件\n", |
115 | 110 | " cap = cv2.VideoCapture(video_file_path)\n", |
116 | | - " # 获取视频的帧率\n", |
117 | 111 | " fps = cap.get(cv2.CAP_PROP_FPS)\n", |
118 | | - " # 获取视频的总帧数\n", |
119 | 112 | " length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", |
120 | 113 | "\n", |
121 | | - " # 根据策略选择抽帧间隔\n", |
122 | 114 | " if extraction_strategy == Strategy.CONSTANT_INTERVAL:\n", |
123 | | - " # 计算固定间隔抽帧的帧间隔\n", |
124 | 115 | " frame_interval = int(fps * interval_in_seconds)\n", |
125 | 116 | " elif extraction_strategy == Strategy.EVEN_INTERVAL:\n", |
126 | | - " # 计算均匀间隔抽帧的帧间隔\n", |
127 | 117 | " frame_interval = int(length / max_frames)\n", |
128 | 118 | " else:\n", |
129 | | - " # 如果策略无效,抛出异常\n", |
130 | 119 | " raise ValueError(\"Invalid extraction strategy\")\n", |
131 | | - " # 初始化帧计数器\n", |
132 | 120 | " frame_count = 0\n", |
133 | | - " # 初始化关键帧列表\n", |
134 | 121 | " keyframes = []\n", |
135 | 122 | " timestamps = []\n", |
136 | | - " # 循环读取视频帧\n", |
137 | 123 | " while True:\n", |
138 | | - " # 读取一帧\n", |
139 | 124 | " ret, frame = cap.read()\n", |
140 | | - " # 如果读取失败,跳出循环\n", |
141 | 125 | " if not ret:\n", |
142 | 126 | " break\n", |
143 | | - " # 如果当前帧是关键帧\n", |
144 | 127 | " if frame_count % frame_interval == 0:\n", |
145 | | - " # 生成关键帧的文件名\n", |
146 | 128 | " image_path = os.path.join(\n", |
147 | 129 | " output_dir, keyframe_naming_template.format(len(keyframes))\n", |
148 | 130 | " )\n", |
149 | | - " # 将关键帧保存为图片\n", |
150 | 131 | " cv2.imwrite(\n", |
151 | 132 | " image_path,\n", |
152 | 133 | " frame,\n", |
153 | 134 | " )\n", |
154 | | - " # 将关键帧路径添加到列表中\n", |
155 | 135 | " keyframes.append(image_path)\n", |
156 | 136 | " timestamps.append(round(frame_count / fps, 1))\n", |
157 | | - " # 增加帧计数器\n", |
158 | 137 | " frame_count += 1\n", |
159 | | - " # 如果关键帧数量达到最大值,跳出循环\n", |
160 | 138 | " if len(keyframes) >= max_frames:\n", |
161 | 139 | " break\n", |
162 | 140 | "\n", |
163 | | - " print(\"抽取帧数:\", len(keyframes))\n", |
164 | | - " # 返回关键帧路径列表\n", |
| 141 | + " print(\"sampled frames:\", len(keyframes))\n", |
165 | 142 | " if use_timestamp:\n", |
166 | 143 | " return keyframes, timestamps\n", |
167 | 144 | " return keyframes, None\n", |
168 | 145 | "\n", |
169 | 146 | "def resize(image):\n", |
170 | | - " \"\"\"\n", |
171 | | - " 调整图片大小以适应指定的尺寸。\n", |
172 | | - " 参数:\n", |
173 | | - " image (numpy.ndarray): 输入的图片,格式为numpy数组。\n", |
174 | | - " 返回:\n", |
175 | | - " numpy.ndarray: 调整大小后的图片。\n", |
176 | | - " \"\"\"\n", |
177 | | - " # 获取图片的原始高度和宽度\n", |
178 | 147 | " height, width = image.shape[:2]\n", |
179 | | - " # 根据图片的宽高比确定目标尺寸\n", |
180 | 148 | " if height < width:\n", |
181 | 149 | " target_height, target_width = 480, 640\n", |
182 | 150 | " else:\n", |
183 | 151 | " target_height, target_width = 640, 480\n", |
184 | | - " # 如果图片尺寸已经小于或等于目标尺寸,则直接返回原图片\n", |
185 | 152 | " if height <= target_height and width <= target_width:\n", |
186 | 153 | " return image\n", |
187 | | - " # 计算新的高度和宽度,保持图片的宽高比\n", |
188 | 154 | " if height / target_height < width / target_width:\n", |
189 | 155 | " new_width = target_width\n", |
190 | 156 | " new_height = int(height * (new_width / width))\n", |
191 | 157 | " else:\n", |
192 | 158 | " new_height = target_height\n", |
193 | 159 | " new_width = int(width * (new_height / height))\n", |
194 | | - " # 调整图片大小\n", |
195 | 160 | " return cv2.resize(image, (new_width, new_height))\n", |
196 | 161 | "\n", |
197 | | - "# 定义方法将指定路径图片resize到合适大小并转为Base64编码\n", |
198 | 162 | "def encode_image(image_path: str) -> str:\n", |
199 | | - " \"\"\"\n", |
200 | | - " 将指定路径的图片进行编码\n", |
201 | | - " 参数:\n", |
202 | | - " image_path (str): 图片文件的路径\n", |
203 | | - " 返回:\n", |
204 | | - " str: 编码后的图片字符串\n", |
205 | | - " \"\"\"\n", |
206 | | - " # 读取图片\n", |
207 | 163 | " image = cv2.imread(image_path)\n", |
208 | | - " # 调整图片大小\n", |
209 | 164 | " image_resized = resize(image)\n", |
210 | | - " # 将图片编码为JPEG格式\n", |
211 | 165 | " _, encoded_image = cv2.imencode(\".jpg\", image_resized)\n", |
212 | | - " # 将编码后的图片转换为Base64字符串\n", |
213 | 166 | " return base64.b64encode(encoded_image).decode(\"utf-8\")\n", |
214 | 167 | "\n", |
215 | 168 | "def construct_messages(image_paths: list[str], timestamps: list[float], prompt: str) -> list[dict]:\n", |
216 | 169 | " \"\"\"\n", |
217 | | - " 构造包含文本和图像的消息列表。\n", |
218 | | - " 参数:\n", |
219 | | - " image_paths (list[str]): 图像文件路径列表。\n", |
220 | | - " timestamps (list[float]): 视频的时间戳。\n", |
221 | | - " prompt (str): 文本提示。\n", |
222 | | - " 返回:\n", |
223 | | - " list[dict]: 包含文本和图像的消息列表。\n", |
| 170 | + " Construct the user message for video understanding: each frame (optionally preceded by its timestamp) followed by the text prompt.\n", |
224 | 171 | " \"\"\"\n", |
225 | | - " # 初始化消息内容列表\n", |
226 | 172 | " content = []\n", |
227 | | - " # 遍历图像路径列表\n", |
228 | 173 | " for idx, image_path in enumerate(image_paths):\n", |
229 | | - " # 为每个图像路径构造一个图像URL消息\n", |
230 | 174 | " if timestamps is not None:\n", |
| 175 | + " # add timestamp for each frame\n", |
231 | 176 | " content.append({\n", |
232 | 177 | " \"type\": \"text\",\n", |
233 | 178 | " \"text\": f'[{timestamps[idx]} second]'\n", |
|
236 | 181 | " {\n", |
237 | 182 | " \"type\": \"image_url\",\n", |
238 | 183 | " \"image_url\": {\n", |
239 | | - " # 使用Base64编码将图像转换为数据URL\n", |
240 | 184 | " \"url\": f\"data:image/jpeg;base64,{encode_image(image_path)}\",\n", |
241 | | - " # 指定图像细节级别为低\n", |
242 | 185 | " \"detail\":\"low\"\n", |
243 | 186 | " },\n", |
244 | 187 | " }\n", |
|
248 | 191 | " \"type\": \"text\",\n", |
249 | 192 | " \"text\": prompt,\n", |
250 | 193 | " })\n", |
251 | | - " # 返回包含文本和图像的消息列表\n", |
252 | 194 | " return [\n", |
253 | 195 | " {\n", |
254 | 196 | " \"role\": \"user\",\n", |
|
274 | 216 | }, |
275 | 217 | { |
276 | 218 | "cell_type": "code", |
277 | | - "execution_count": 12, |
| 219 | + "execution_count": null, |
278 | 220 | "id": "f48fd468-12d6-46c9-ae19-c2c981bdc6c2", |
279 | 221 | "metadata": {}, |
280 | 222 | "outputs": [ |
|
324 | 266 | "# sampling video frames\n", |
325 | 267 | "sampling_fps = 1\n", |
326 | 268 | "max_frames = 30\n", |
| 269 | + "sampling_interval = 1.0 / sampling_fps\n", |
327 | 270 | "selected_images, timestamps = preprocess_video(\n", |
328 | 271 | " video_file_path=video_path,\n", |
329 | 272 | " output_dir=\"video_frames\",\n", |
330 | 273 | " extraction_strategy=Strategy.CONSTANT_INTERVAL,\n", |
331 | | - " interval_in_seconds=sampling_fps,\n", |
| 274 | + " interval_in_seconds=sampling_interval,\n", |
332 | 275 | " use_timestamp=True,\n", |
333 | 276 | " max_frames=max_frames\n", |
334 | 277 | ")\n", |
|
348 | 291 | }, |
349 | 292 | { |
350 | 293 | "cell_type": "code", |
351 | | - "execution_count": 15, |
| 294 | + "execution_count": null, |
352 | 295 | "id": "89845b90-e976-45c1-84cd-7239963101ee", |
353 | 296 | "metadata": {}, |
354 | 297 | "outputs": [ |
|
369 | 312 | "# sampling video frames\n", |
370 | 313 | "sampling_fps = 1\n", |
371 | 314 | "max_frames = 30\n", |
| 315 | + "sampling_interval = 1.0 / sampling_fps\n", |
372 | 316 | "selected_images, timestamps = preprocess_video(\n", |
373 | 317 | " video_file_path=video_path,\n", |
374 | 318 | " output_dir=\"video_frames\",\n", |
375 | 319 | " extraction_strategy=Strategy.CONSTANT_INTERVAL,\n", |
376 | | - " interval_in_seconds=sampling_fps,\n", |
| 320 | + " interval_in_seconds=sampling_interval,\n", |
377 | 321 | " use_timestamp=True,\n", |
378 | 322 | " max_frames=max_frames\n", |
379 | 323 | ")\n", |
|
393 | 337 | }, |
394 | 338 | { |
395 | 339 | "cell_type": "code", |
396 | | - "execution_count": 17, |
| 340 | + "execution_count": null, |
397 | 341 | "id": "2c4fbac9-5c82-447b-b174-bec330bd70df", |
398 | 342 | "metadata": {}, |
399 | 343 | "outputs": [ |
|
418 | 362 | "# sampling video frames\n", |
419 | 363 | "sampling_fps = 1\n", |
420 | 364 | "max_frames = 30\n", |
| 365 | + "sampling_interval = 1.0 / sampling_fps\n", |
421 | 366 | "selected_images, timestamps = preprocess_video(\n", |
422 | 367 | " video_file_path=video_path,\n", |
423 | 368 | " output_dir=\"video_frames\",\n", |
424 | 369 | " extraction_strategy=Strategy.CONSTANT_INTERVAL,\n", |
425 | | - " interval_in_seconds=sampling_fps,\n", |
| 370 | + " interval_in_seconds=sampling_interval,\n", |
426 | 371 | " use_timestamp=True,\n", |
427 | 372 | " max_frames=max_frames\n", |
428 | 373 | ")\n", |
|
0 commit comments