pytorch torch.max 用法

pytorch torch.max 用法

import numpy as np
import torch

a = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
print('a.shape', a.shape)
print('a', a)
# 列 最大
# 返回的值包含两个数据(values, indices) 分别代表最大值的值和所在的索引(返回最大元素在各列的行索引)
a0 = torch.max(a.data, 0)
print('a0', a0)
# 行 最大
a1 = torch.max(a.data, 1)
print('a1', a1)
print('a1.indices', a1[1].data.numpy())

(1) torch.max(a): 返回输入a中所有元素的最大值。

(2) torch.max(a, 0): 返回每一列的最大值,且返回索引(返回最大元素在各列的行索引)。

(3) torch.max(a, 1): 返回每一行的最大值,且返回索引(返回最大元素在各行的列索引)。

(4) torch.max()[0]: 只返回最大值。

(5) torch.max()[1]: 只返回最大值的索引。

(6) torch.max()[0].data: 只返回variable中的数据部分(去掉Variable containing)。

(7) torch.max()[0].data.numpy(): 把数据转化成 numpy ndarray。

(8) torch.max()[0].numpy(): 把数据转化成 numpy ndarray。

import torch
a = torch.tensor([[1,2,3,14],[5,16,7,8],[9,10,11,12]])
print('a:', a,
      '\n\n torch.max(a):', torch.max(a),
      '\n\n torch.max(a, 0):', torch.max(a, 0),
      '\n\n torch.max(a, 0)[0]:', torch.max(a, 0)[0],
      '\n\n torch.max(a, 0)[1]:', torch.max(a, 0)[1],
      '\n\n torch.max(a, 0)[1].data:', torch.max(a, 0)[1].data,
      '\n\n torch.max(a, 0)[1].data.numpy():', torch.max(a, 0)[1].data.numpy(),
      '\n\n torch.max(a, 0)[1].numpy():', torch.max(a, 0)[1].numpy(),
      '\n\n torch.max(a, 1):', torch.max(a, 1))
a: tensor([[ 1,  2,  3, 14],
        [ 5, 16,  7,  8],
        [ 9, 10, 11, 12]]) 

 torch.max(a): tensor(16) 

 torch.max(a, 0): torch.return_types.max(
values=tensor([ 9, 16, 11, 14]),
indices=tensor([2, 1, 2, 0])) 

 torch.max(a, 0)[0]: tensor([ 9, 16, 11, 14]) 

 torch.max(a, 0)[1]: tensor([2, 1, 2, 0]) 

 torch.max(a, 0)[1].data: tensor([2, 1, 2, 0]) 

 torch.max(a, 0)[1].data.numpy(): [2 1 2 0] 

 torch.max(a, 0)[1].numpy(): [2 1 2 0] 

 torch.max(a, 1): torch.return_types.max(
values=tensor([14, 16, 12]),
indices=tensor([3, 1, 3]))

torch.max()[0].data.numpy() 和 torch.max()[0].numpy() 效果一样。

有些代码会出现 torch.max()[0].data.numpy() 的写法,这是因为在早期的PyTorch版本中,variable 和 tensor 是不同的数据格式,variable 可以进行反向传播,tensor 不可以,需要将 variable 转变成 tensor 再转变成 numpy。新的版本已经将 variable 和 tensor 合并,所以只用 torch.max()[1].numpy() 就可以了。

发表回复

您的电子邮箱地址不会被公开。