Commit a5f45df

[XPU][Fix] Fix large maxpool index (#2362)
This fixes pytorch/pytorch#167253. It does the following:

1. Use `index_t` instead of `int` and dispatch kernels accordingly (follows pytorch/pytorch#167427).
2. Use NHWC when the output exceeds INT_MAX (follows pytorch/pytorch#167322).
3. Change other related dtypes (such as `num_wg`) to `index_t` to avoid overflow.

# Details

Test case:

```python
import torch

x = torch.zeros(74, 32, 30090, 81, device=torch.device("xpu"), dtype=torch.bfloat16)
torch.nn.functional.max_pool2d(x, kernel_size=(1, 2), stride=(1, 2), ceil_mode=False, padding=0)
```

It throws the following error:

```Bash
[MaxPool2d] Input shape: [74, 32, 30090, 81] output: [74, 32, 30090, 40]
[MaxPool2d] Strides: n=77993280 c=1 h=2592 w=32
[MaxPool2d] Memory format: ChannelsLast
[MaxPool2d Forward] ChannelsLast path: numBatch=74 numPlane=32 inputH=30090 inputW=81 outputH=30090 outputW=40 index_t=int64
[MaxPool2d Forward] Using vec_size=1 num_wg=-72057583935701024
Segmentation fault from GPU at 0xff00000c04e33000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Segmentation fault from GPU at 0xff00000c04e33000, ctx_id: 1 (CCS) type: 0 (NotPresent), level: 1 (PDE), access: 0 (Read), banned: 1, aborting.
Abort was called at 279 line in file: ./shared/source/os_interface/linux/drm_neo.cpp
[1] 77805 IOT instruction (core dumped) python
```

As the log above shows, `num_wg` overflows to a negative value, which causes the segfault.
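To make the scale of the problem concrete, the following is a minimal sketch in plain Python (no XPU required) that counts the elements of the input and output shapes from the test case and simulates what happens when such a count is stored in a 32-bit signed integer. The helpers `numel` and `wrap_int32` are hypothetical names introduced only for this illustration; they are not part of the kernel code, and the sketch does not attempt to reproduce the exact `num_wg` value shown in the log.

```python
INT_MAX = 2**31 - 1  # largest value a 32-bit signed int can hold


def numel(shape):
    """Return the total number of elements for a given tensor shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total


def wrap_int32(value):
    """Simulate two's-complement wraparound of a 32-bit signed integer."""
    return (value + 2**31) % 2**32 - 2**31


# Shapes taken from the test case and the log above.
input_shape = (74, 32, 30090, 81)
output_shape = (74, 32, 30090, 40)

input_numel = numel(input_shape)
output_numel = numel(output_shape)

# Assumed dispatch rule for illustration: use a 64-bit index type
# whenever either element count does not fit in a 32-bit signed int.
needs_int64_index = max(input_numel, output_numel) > INT_MAX

print("input elements: ", input_numel, "-> as int32:", wrap_int32(input_numel))
print("output elements:", output_numel, "-> as int32:", wrap_int32(output_numel))
print("needs 64-bit index:", needs_int64_index)
```

Both element counts exceed INT_MAX, so any flat index or derived quantity computed in `int` can silently wrap, which is why the index type is widened for shapes like this one.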
1 parent 993ab70 commit a5f45df

1 file changed, +433 -267 lines changed
