nn.Linear()函数详解

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)[原文地址](Linear — PyTorch 1.12 documentation)

其中的参数:
  • in_features – 每个输入样本的大小。
  • out_features – 每个输出样本的大小。
  • bias – 如果设置为False,该层将不会学习附加偏差。默认为:True。
shape:
  • Input:(#,IN),其中#表示表示任意大小的维度。IN表示in_features
  • Outout:(#,OUT),其中除了最后一个OUT外,其余的维度都和输入的shape相同,OUT表示out_features
实例:
>>> m = nn.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
Logo

华为开发者空间,是为全球开发者打造的专属开发空间,汇聚了华为优质开发资源及工具,致力于让每一位开发者拥有一台云主机,基于华为根生态开发、创新。

更多推荐