import torch
x = torch.randn(9, 3)  # the input has shape (9, 3)
m = torch.nn.Linear(3, 1)  # input feature dimension is 3, output feature dimension is 1
output = m(x)
print('m.weight.shape:\n', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
print('x:\n', x)
print('m:\n', m.weight)
print('output:\n', output)
# torch.mm(a, b) computes the matrix product of a and b
# result = torch.mm(input, torch.t(m.weight)) + m.bias is equivalent to the line below
result = torch.mm(x, m.weight.t()) + m.bias
print('result.shape:\n', result.shape)
# check whether the two computations produce the same result
print(torch.equal(output, result))
m.weight.shape:
torch.Size([1, 3])
m.bias.shape:
torch.Size([1])
output.shape:
torch.Size([9, 1])
x:
 tensor([[-0.0222,  0.2094, -0.7877],
        [ 1.4259, -0.1152, -0.2617],
        [ 1.0417, -0.3274, -0.2228],
        [ 0.7585, -0.0021,  0.9533],
        [ 0.5270,  0.4336, -0.6255],
        [ 0.3288,  0.4246,  0.3199],
        [-0.6834,  0.8066,  1.0961],
        [-0.0174,  0.1950,  1.6762],
        [-0.1756,  0.3667, -1.1426]])
m:
Parameter containing:
tensor([[-0.3427, 0.4085, -0.5723]], requires_grad=True)
output:
 tensor([[ 0.8231],
        [-0.1068],
        [-0.0842],
        [-0.5273],
        [ 0.6336],
        [ 0.1568],
        [ 0.2155],
        [-0.5946],
        [ 1.1431]], grad_fn=<AddmmBackward0>)
result.shape:
torch.Size([9, 1])
True
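A side note on the final True: torch.equal only returns True when the two tensors are bitwise identical. It works here, but nn.Linear dispatches to a fused addmm kernel (visible in the grad_fn above), and in general different kernels can produce results that differ in the last floating-point bits. A safer comparison is torch.allclose; a minimal sketch:

import torch

x = torch.randn(9, 3)
m = torch.nn.Linear(3, 1)

output = m(x)                                 # fused addmm
result = torch.mm(x, m.weight.t()) + m.bias   # separate mm + add

# torch.allclose tolerates tiny floating-point discrepancies
# (rtol=1e-5, atol=1e-8 by default), unlike torch.equal
print(torch.allclose(output, result))  # True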
Why is m.weight.shape = (1, 3)?
Answer: because the linear transformation is computed as

    y = x · Wᵀ + b

where W is stored with shape (out_features, in_features) = (1, 3). The weight is first created as a (1, 3) tensor and then transposed during the actual computation, so that x of shape (9, 3) can be matrix-multiplied with Wᵀ of shape (3, 1), giving an output of shape (9, 1).
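To make the transpose concrete, the same forward pass can be written three equivalent ways. A minimal sketch; torch.nn.functional.linear is the functional form that nn.Linear calls under the hood, and it accepts the weight in its stored (out_features, in_features) layout and handles the transpose itself:

import torch
import torch.nn.functional as F

m = torch.nn.Linear(3, 1)
x = torch.randn(9, 3)

y1 = m(x)                           # module call
y2 = x @ m.weight.t() + m.bias      # explicit transpose and matmul
y3 = F.linear(x, m.weight, m.bias)  # functional form, transposes internally

print(y1.shape, y2.shape, y3.shape)  # all torch.Size([9, 1])
print(torch.allclose(y1, y2), torch.allclose(y1, y3))  # True True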