Usage of torch.nn.Linear in PyTorch


import torch

x = torch.randn(9, 3)  # input of shape (9, 3)
m = torch.nn.Linear(3, 1)  # in_features=3, out_features=1
output = m(x)
print('m.weight.shape:\n', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)

print('x:\n', x)
print('m:\n', m.weight)
print('output:\n', output)

# torch.mm(a, b) computes the matrix product of a and b.
# The line below reproduces nn.Linear's computation by hand:
result = torch.mm(x, m.weight.t()) + m.bias
print('result.shape:\n', result.shape)

# check that both computations produce the same result
print(torch.equal(output, result))
m.weight.shape:
 torch.Size([1, 3])
m.bias.shape:
 torch.Size([1])
output.shape:
 torch.Size([9, 1])
x:
 tensor([[-0.0222,  0.2094, -0.7877],
        [ 1.4259, -0.1152, -0.2617],
        [ 1.0417, -0.3274, -0.2228],
        [ 0.7585, -0.0021,  0.9533],
        [ 0.5270,  0.4336, -0.6255],
        [ 0.3288,  0.4246,  0.3199],
        [-0.6834,  0.8066,  1.0961],
        [-0.0174,  0.1950,  1.6762],
        [-0.1756,  0.3667, -1.1426]])
m:
 Parameter containing:
tensor([[-0.3427,  0.4085, -0.5723]], requires_grad=True)
output:
 tensor([[ 0.8231],
        [-0.1068],
        [-0.0842],
        [-0.5273],
        [ 0.6336],
        [ 0.1568],
        [ 0.2155],
        [-0.5946],
        [ 1.1431]], grad_fn=<AddmmBackward0>)
result.shape:
 torch.Size([9, 1])
True


Why is m.weight.shape = (1, 3)?
Answer: because nn.Linear computes the linear transformation

    y = x @ W^T + b

The weight W is stored with shape (out_features, in_features) = (1, 3) and is transposed during the actual computation, so that it can be matrix-multiplied with x of shape (9, 3).
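As a further check, the same transformation is exposed as torch.nn.functional.linear, which computes x @ W^T + b directly. A minimal sketch (the layer sizes mirror the example above):

```python
import torch
import torch.nn.functional as F

x = torch.randn(9, 3)      # same input shape as above: (9, 3)
m = torch.nn.Linear(3, 1)  # weight shape: (out_features, in_features) = (1, 3)

# nn.Linear's forward pass is exactly F.linear(input, weight, bias)
out_module = m(x)
out_functional = F.linear(x, m.weight, m.bias)

print(m.weight.shape)                             # torch.Size([1, 3])
print(torch.equal(out_module, out_functional))    # True
```

This also explains the stored shape: F.linear transposes the (out_features, in_features) weight internally before the matrix multiplication.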
