```python
import torch

x = torch.randn(9, 3)        # input of shape (9, 3)
m = torch.nn.Linear(3, 1)    # input feature dimension is 3, output feature dimension is 1
output = m(x)

print('m.weight.shape:\n', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
print('x:\n', x)
print('m:\n', m.weight)
print('output:\n', output)

# torch.mm(a, b) is the matrix product of matrices a and b;
# result = torch.mm(x, torch.t(m.weight)) + m.bias is equivalent to the line below
result = torch.mm(x, m.weight.t()) + m.bias
print('result.shape:\n', result.shape)

# check whether the two computations produce identical results
print(torch.equal(output, result))
```
```
m.weight.shape:
 torch.Size([1, 3])
m.bias.shape:
 torch.Size([1])
output.shape:
 torch.Size([9, 1])
x:
 tensor([[-0.0222,  0.2094, -0.7877],
        [ 1.4259, -0.1152, -0.2617],
        [ 1.0417, -0.3274, -0.2228],
        [ 0.7585, -0.0021,  0.9533],
        [ 0.5270,  0.4336, -0.6255],
        [ 0.3288,  0.4246,  0.3199],
        [-0.6834,  0.8066,  1.0961],
        [-0.0174,  0.1950,  1.6762],
        [-0.1756,  0.3667, -1.1426]])
m:
 Parameter containing:
tensor([[-0.3427,  0.4085, -0.5723]], requires_grad=True)
output:
 tensor([[ 0.8231],
        [-0.1068],
        [-0.0842],
        [-0.5273],
        [ 0.6336],
        [ 0.1568],
        [ 0.2155],
        [-0.5946],
        [ 1.1431]], grad_fn=<AddmmBackward0>)
result.shape:
 torch.Size([9, 1])
True
```
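Note the `grad_fn=<AddmmBackward0>` on `output`: it hints that `nn.Linear` performs its forward pass with a fused add-matrix-multiply, `torch.addmm`. As a minimal sketch (the variable names are fresh, not reused from the run above), the same result can be written as a third equivalent formulation:

```python
import torch

x = torch.randn(9, 3)
m = torch.nn.Linear(3, 1)

# torch.addmm(bias, mat1, mat2) computes bias + mat1 @ mat2 in one fused call,
# which matches the AddmmBackward0 node seen in output.grad_fn above
result_addmm = torch.addmm(m.bias, x, m.weight.t())

print(torch.equal(m(x), result_addmm))  # prints True if the fused path matches m(x) exactly
```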
Why is m.weight.shape = (1, 3)?
Answer: because the linear transformation is computed as $y = x W^{T} + b$. PyTorch first creates a weight of shape (1, 3), i.e. (out_features, in_features), and then transposes it during the actual computation so that it can be matrix-multiplied with x.
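To check that the stored shape is (out_features, in_features) in general, here is a minimal sketch with a hypothetical layer of 5 output features (the value 5 is just an illustration). It also uses torch.nn.functional.linear, which applies the transpose for you and is what nn.Linear calls in its forward pass:

```python
import torch
import torch.nn.functional as F

x = torch.randn(9, 3)
m = torch.nn.Linear(3, 5)       # hypothetical layer: in_features=3, out_features=5

# weight is stored as (out_features, in_features), not (in_features, out_features)
print(m.weight.shape)           # torch.Size([5, 3])

# F.linear(input, weight, bias) computes input @ weight.T + bias
out1 = m(x)
out2 = F.linear(x, m.weight, m.bias)
out3 = x @ m.weight.t() + m.bias

print(out1.shape)               # torch.Size([9, 5])
print(torch.equal(out1, out2))  # True: both go through the same F.linear path
print(torch.allclose(out1, out3))
```

Storing the weight as (out_features, in_features) means each row of the weight matrix holds the coefficients of one output feature, which is why it must be transposed before multiplying with x.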