import numpy as np import torch a = torch.tensor([[1,2,3,4],[5,6,7,8],[9,10,11,12]]) print('a.shape', a.shape) print('a', a) # 列 最大 # 返回的值包含两个数据(values, indices) 分别代表最大值的值和所在的索引 a0 = torch.max(a.data, 0) print('a0', a0) # 行 最大 a1 = torch.max(a.data, 1) print('a1', a1) print('a1.indices', a1[1].data.numpy())