pytorch torch.max 用法

pytorch torch.max 用法

import numpy as np
import torch

a = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
print('a.shape', a.shape)
print('a', a)
# 列 最大
# 返回的值包含两个数据(values, indices) 分别代表最大值的值和所在的索引
a0 = torch.max(a.data, 0)
print('a0', a0)
# 行 最大
a1 = torch.max(a.data, 1)
print('a1', a1)
print('a1.indices', a1[1].data.numpy())

发表回复

您的电子邮箱地址不会被公开。