[pytorch] Error when computing max() on the mps device

2023-12-07 355 views

🐛 Describe the bug

I tried the following variants; hope this helps~

import torch

print(torch.__version__)
# float32 tensor: max() on the CPU (a) vs. on MPS (b)
a = torch.Tensor([1, 2, 2, 3])
print("a: ", a)
b = a.to('mps')
print(f"a.max(): {a.max()}, b.max(): {b.max()}")

# Cast to int32 on MPS (c), then copy back to the CPU (d)
c = a.to('mps').to(torch.int)
d = c.to('cpu')
print("c: ", c, "d: ", d)
print(f"a.max(): {a.max()}, c.max(): {c.max()}, d.max(): {d.max()}")

print('============')
# Cast with Python int, which maps to torch.int64
c = a.to('mps').to(int)
d = c.to('cpu')
print("c: ", c, "d: ", d)
print(f"a.max(): {a.max()}, c.max(): {c.max()}, d.max(): {d.max()}")

print('============clear')
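# Count duplicates in column 0 with unique(return_counts=True), then take the max count on the CPU (c) and on MPS (d)
a = torch.Tensor(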
    [[1.0000, 0.0000, 377.2475, 237.9320, 640.0000, 420.5970],
     [2.0000, 0.0000, 523.8516, 55.7052, 640.0000, 219.5220],
     [2.0000, 0.0000, 367.9471, 78.8564, 455.8632, 198.4222],
     [3.0000, 0.0000, 46.5807, 58.2641, 466.7800, 439.6348]],
).to('mps')
b = a[:, 0]
c = b.unique(return_counts=True)[1].cpu().max()
d = b.unique(return_counts=True)[1].max()
print(f"b: {b}\nc: {c}\nd: {d}")
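To tell whether an MPS result is wrong, it helps to have expected values that do not depend on any device. A minimal sketch in plain Python (no torch needed; the variable names `rows`, `expected_max` and `expected_max_count` are illustrative) recomputes what the repro should print: 3 for every max(), and 2 for the largest count returned by unique(return_counts=True).

```python
from collections import Counter

# Reference values for the repro, computed in plain Python (no torch, no device).
a = [1.0, 2.0, 2.0, 3.0]

# Same data as the 4x6 matrix in the last part of the repro.
rows = [
    [1.0, 0.0, 377.2475, 237.9320, 640.0, 420.5970],
    [2.0, 0.0, 523.8516, 55.7052, 640.0, 219.5220],
    [2.0, 0.0, 367.9471, 78.8564, 455.8632, 198.4222],
    [3.0, 0.0, 46.5807, 58.2641, 466.7800, 439.6348],
]

# What a.max(), b.max(), c.max() and d.max() should all report.
expected_max = max(a)

# What b.unique(return_counts=True)[1].max() should report: count each
# distinct value in column 0, then take the largest count.
counts = Counter(row[0] for row in rows)
expected_max_count = max(counts.values())

print(expected_max, expected_max_count)  # 3.0 2
```

Any device that prints something other than these values for the repro is computing max() incorrectly.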

Versions

(mm2) ➜  mmyolo git:(main) ✗ python collect_env.py
Collecting environment information...
PyTorch version: 1.13.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.0.1 (arm64)
GCC version: Could not collect
Clang version: 14.0.0 (clang-1400.0.29.202)
CMake version: version 3.26.4
Libc version: N/A

Python version: 3.8.15 (default, Nov 24 2022, 08:57:44)  [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-13.0.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.23.0
[pip3] pytorch-lightning==1.9.0
[pip3] torch==1.13.1
[pip3] torchmetrics==0.8.0
[pip3] torchstat==0.0.7
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.14.1
[conda] numpy                     1.23.0                   pypi_0    pypi
[conda] pytorch-lightning         1.9.0                    pypi_0    pypi
[conda] torch                     1.13.1                   pypi_0    pypi
[conda] torchmetrics              0.8.0                    pypi_0    pypi
[conda] torchstat                 0.0.7                    pypi_0    pypi
[conda] torchsummary              1.5.1                    pypi_0    pypi
[conda] torchvision               0.14.1                   pypi_0    pypi

cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr @abhudev

Answers


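A lot of things were broken in 1.13, when MPS was still a prototype feature. This works as expected in 2.0.0 (feel free to look up the issue that fixed it; in the nightlies, max over int64 values is even supported):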

2.0.0
a:  tensor([1., 2., 2., 3.])
a.max(): 3.0, b.max(): 3.0
c:  tensor([1, 2, 2, 3], device='mps:0', dtype=torch.int32) d:  tensor([1, 2, 2, 3], dtype=torch.int32)
a.max(): 3.0, c.max(): 3, d.max(): 3
============
c:  tensor([1, 2, 2, 3], device='mps:0') d:  tensor([1, 2, 2, 3])
/Users/nshulga/test/baz.py:18: UserWarning: MPS: no support for int64 min/max ops, casting it to int32 (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/mps/operations/ReduceOps.mm:1271.)
  print(f"a.max(): {a.max()}, c.max(): {c.max()}, d.max(): {d.max()}")
a.max(): 3.0, c.max(): 3, d.max(): 3
============clear
b: tensor([1., 2., 2., 3.], device='mps:0')
c: 2
d: 2
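The UserWarning above says that in 2.0.0 MPS has no int64 min/max support and casts the input to int32 first. For the repro's values that is harmless, but values outside the int32 range cannot survive such a cast. A minimal sketch in plain Python models this, assuming the narrowing wraps around in two's complement (typical for integer casts, but not verified here for MPS, which might saturate instead); `to_int32` and `max_after_int32_cast` are hypothetical helpers, not PyTorch APIs.

```python
from typing import Iterable


def to_int32(x: int) -> int:
    """Narrow a Python int to signed 32-bit, assuming two's-complement wrap-around."""
    x &= 0xFFFFFFFF  # keep only the low 32 bits
    return x - (1 << 32) if x >= (1 << 31) else x


def max_after_int32_cast(values: Iterable[int]) -> int:
    """Hypothetical model of 'cast int64 to int32, then take max'."""
    return max(to_int32(v) for v in values)


# The repro's values fit in int32, so the cast does not change the result.
print(max_after_int32_cast([1, 2, 2, 3]))
# 2**31 does not fit in int32 and wraps to -2**31, so the max comes out as 2.
print(max_after_int32_cast([1, 2, 2, 2**31]))
```

This is why the nightlies' native int64 max matters for large integer tensors even though the small repro already passes in 2.0.0.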

Please use the latest version; if the problem still reproduces, please open a new issue.
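Before filing a new issue, it can help to check programmatically whether the installed version predates the fix. A minimal sketch, assuming only what the answer states (max() on MPS broken in 1.13, working in 2.0.0); `mps_max_expected_to_work` is a hypothetical helper, and it takes the version string as an argument so it can be tested without torch installed.

```python
import re


def mps_max_expected_to_work(version: str) -> bool:
    """Hypothetical helper: True if a torch version string is 2.0 or newer."""
    # Use only the leading "major.minor"; ignore suffixes such as
    # ".dev20230601" or "+cpu".
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) >= (2, 0)


print(mps_max_expected_to_work("1.13.1"))  # version from the report
print(mps_max_expected_to_work("2.0.0"))   # version used in the answer
```

In practice you would pass `torch.__version__`; for the 1.13.1 install in the report it returns False.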