Commit a5f45df
[XPU][Fix] Fix large maxpool index (#2362)
This fixes pytorch/pytorch#167253. It
does the following:
1. Use `index_t` instead of `int` and dispatch kernels accordingly
(follows pytorch/pytorch#167427).
2. Use NHWC when the output exceeds INT_MAX (follows
pytorch/pytorch#167322).
3. Change the dtype of other related variables (like `num_wg`) to `index_t` to avoid
overflow.
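To illustrate why a 32-bit index is not enough for the shapes in this report, here is a minimal Python sketch. It assumes the choice between 32-bit and 64-bit indexing depends on whether the element counts fit in a signed 32-bit integer; the helpers `numel`, `pick_index_type`, and `wrap_int32` are hypothetical names for this illustration, not functions from the patch, and the actual dispatch condition in the kernels may differ.

```python
INT_MAX = 2**31 - 1  # largest value of a signed 32-bit int


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for d in shape:
        n *= d
    return n


def pick_index_type(input_shape, output_shape):
    """Hypothetical helper: use 32-bit indices only if both tensors fit."""
    largest = max(numel(input_shape), numel(output_shape))
    return "int32" if largest <= INT_MAX else "int64"


def wrap_int32(value):
    """Simulate storing a value in a signed 32-bit int (two's complement)."""
    value &= 0xFFFFFFFF
    return value - 2**32 if value >= 2**31 else value


# Shapes taken from the test case in this commit message.
in_shape = (74, 32, 30090, 81)
out_shape = (74, 32, 30090, 40)

print(numel(out_shape))                     # 2850124800, above INT_MAX
print(wrap_int32(numel(out_shape)))         # -1444842496 after 32-bit wraparound
print(pick_index_type(in_shape, out_shape)) # int64
```

The wrapped value here only demonstrates the general failure mode; it is not meant to reproduce the exact `num_wg` printed in the log.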
# Details
Test case:
```python
import torch
# The output has 74 * 32 * 30090 * 40 = 2,850,124,800 elements, more than INT_MAX.
x = torch.zeros(74, 32, 30090, 81, device=torch.device("xpu"), dtype=torch.bfloat16)
torch.nn.functional.max_pool2d(x, kernel_size=(1,2), stride=(1,2), ceil_mode=False, padding=0)
```
It fails with the following error:
```Bash
[MaxPool2d] Input shape: [74, 32, 30090, 81] output: [74, 32, 30090, 40]
[MaxPool2d] Strides: n=77993280 c=1 h=2592 w=32
[MaxPool2d] Memory format: ChannelsLast
[MaxPool2d Forward] ChannelsLast path: numBatch=74 numPlane=32 inputH=30090 inputW=81 outputH=30090 outputW=40 index_t=int64
[MaxPool2d Forward] Using vec_size=1 num_wg=-72057583935701024
Segmentation fault from GPU at 0xff00000c04e33000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Segmentation fault from GPU at 0xff00000c04e33000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 279 line in file:
./shared/source/os_interface/linux/drm_neo.cpp
[1] 77805 IOT instruction (core dumped) python
```
From the log above, `num_wg` overflows to a negative value, which causes
the segfault.
1 parent 993ab70 commit a5f45df
1 file changed, +433 -267 lines changed