grid_sampler
时间: 2025-03-04 11:06:34 浏览: 38
### Grid Sampler 使用方法及常见问题
#### 什么是Grid Sampler?
`grid_sampler`函数用于基于采样网格执行双线性插值或最近邻插值来变换输入图像。此操作常应用于计算机视觉任务中,如空间变换网络(STN),其中需要根据特定的几何变换调整图片位置。
该函数接受两个张量作为参数:一个是待变形的源图像(batch, channel, height, width),另一个是指明如何映射原图素至新坐标系下的网格形状(batch, height_out, width_out, 2)[^1]。
#### 如何使用Grid Sampler?
下面是一个简单的Python代码片段展示怎样利用PyTorch框架内的`torch.nn.functional.grid_sample()`来进行基本的空间变换:
```python
import torch
from torch import nn
import matplotlib.pyplot as plt
def show(imgs):
"""辅助显示函数"""
if not isinstance(imgs, list):
imgs = [imgs]
fig, axs = plt.subplots(ncols=len(imgs), squeeze=False)
for i, img in enumerate(imgs):
img = img.detach().cpu()
img = (img - img.min()) / (img.max() - img.min())
axs[0, i].imshow(img.permute(1, 2, 0))
axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
# 创建一个测试用的小批量单通道灰度图像
input_tensor = torch.full((1, 1, 3, 3), 0.5)
# 定义一个简单平移效果的采样网格
theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])
grid = nn.functional.affine_grid(theta, input_tensor.size(), align_corners=True).add_(1.).div_(2.)
output_tensor = nn.functional.grid_sample(input_tensor, grid, mode='bilinear', padding_mode="zeros", align_corners=True)
show(output_tensor)
plt.show()
```
这段脚本创建了一个全为0.5填充的小批次单通道灰度图像,并应用了一个恒等仿射变换矩阵(即不改变任何东西)。为了可视化结果,在调用了`nn.functional.grid_sample()`之后,采用了自定义的帮助函数`show()`来呈现最终输出。
#### 常见问题解答
- **Q:** 如果遇到边界外索引怎么办?
当指定的目标像素位于原始图像之外时,可以通过设置`padding_mode`参数控制行为,默认情况下会返回零(`'zeros'`);也可以选择反射边缘(`'reflection'`)或将最接近的有效值复制过去(`'border'`)。
- **Q:** `align_corners`选项有什么作用?
此标志决定了是否将角点对齐。当设为True时,意味着四个角落会被视为对应于实际坐标的整数部分;而False则采用更严格的中心取样方式,通常推荐后者除非有特殊需求。
阅读全文
相关推荐



















