pytorch torch.unsqueeze 和 torch.squeeze 用法

pytorch torch.unsqueeze 和 torch.squeeze 用法

1. torch.unsqueeze 详解

# torch.unsqueeze(input, dim, out=None)
# 作用:扩展维度
# 返回一个新的张量,对输入的既定位置插入维度 1
# 注意:返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
# 参数:
# tensor (Tensor) – 输入张量
# dim (int) – 插入维度的索引
# out (Tensor, optional) – 结果张量

import torch
x = torch.Tensor([1, 2, 3, 4])

print('-' * 100)
print('步骤')
print('x', x)
print('x.size()', x.size())
print('x.dim()', x.dim())
print('x.numpy()', x.numpy())

print('-' * 100)
print('步骤 在 0维 插入维度1')
print(torch.unsqueeze(x, 0))
print(torch.unsqueeze(x, 0).size())
print(torch.unsqueeze(x, 0).dim())
print(torch.unsqueeze(x, 0).numpy())

print('-' * 100)
print('步骤 在 1维 插入维度1')
print(torch.unsqueeze(x, 1))
print(torch.unsqueeze(x, 1).size())
print(torch.unsqueeze(x, 1).dim())

2. torch.squeeze 详解

# torch.squeeze(input, dim=None, out=None)
#作用:降维
#将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
#当给定dim时,那么挤压操作只在给定维度上。例如 输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

#注意:返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
#参数:
#input (Tensor) – 输入张量
#dim (int, optional) – 如果给定,则input只会在给定维度挤压
#out (Tensor, optional) – 输出张量

import torch

print("-" * 100)

m = torch.zeros(2, 1, 2, 1, 2)
print(m.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m)
print(n.size())  # torch.Size([2, 2, 2])

n = torch.squeeze(m, 0)  # 当给定dim时,那么挤压操作只在给定维度上
print(n.size())  # torch.Size([2, 1, 2, 1, 2])

n = torch.squeeze(m, 1)
print(n.size())  # torch.Size([2, 2, 1, 2])

3. unsqueeze_ 和 unsqueeze 的区别
#unsqueeze_和 unsqueeze 的区别
#unsqueeze_ 和 unsqueeze 实现一样的功能,区别在于 unsqueeze_ 是 in_place 操作。unsqueeze_ 会对自己改变。
#unsqueeze 不会对使用 unsqueeze 的 tensor 进行改变,想要获取 unsqueeze 后的值必须赋予个新值。

print("-" * 100)
a = torch.Tensor([1, 2, 3, 4])
print('a', a)

b = torch.unsqueeze(a, 1)
print('b', b)

print('a', a)

print("-" * 100)
a = torch.Tensor([1, 2, 3, 4])
print('a', a)

print('a', a.unsqueeze_(1))

print('a', a)

 

发表回复

您的电子邮箱地址不会被公开。