77import torch .nn .functional as F
88from typing import List , Tuple , Optional
99
10- from .helpers import tup_pair
10+ from .helpers import to_2tuple
1111from .padding import pad_same , get_padding_value
1212
1313
@@ -22,8 +22,8 @@ class AvgPool2dSame(nn.AvgPool2d):
2222 """ Tensorflow like 'SAME' wrapper for 2D average pooling
2323 """
2424 def __init__ (self , kernel_size : int , stride = None , padding = 0 , ceil_mode = False , count_include_pad = True ):
25- kernel_size = tup_pair (kernel_size )
26- stride = tup_pair (stride )
25+ kernel_size = to_2tuple (kernel_size )
26+ stride = to_2tuple (stride )
2727 super (AvgPool2dSame , self ).__init__ (kernel_size , stride , (0 , 0 ), ceil_mode , count_include_pad )
2828
2929 def forward (self , x ):
@@ -42,9 +42,9 @@ class MaxPool2dSame(nn.MaxPool2d):
4242 """ Tensorflow like 'SAME' wrapper for 2D max pooling
4343 """
4444 def __init__ (self , kernel_size : int , stride = None , padding = 0 , dilation = 1 , ceil_mode = False , count_include_pad = True ):
45- kernel_size = tup_pair (kernel_size )
46- stride = tup_pair (stride )
47- dilation = tup_pair (dilation )
45+ kernel_size = to_2tuple (kernel_size )
46+ stride = to_2tuple (stride )
47+ dilation = to_2tuple (dilation )
4848 super (MaxPool2dSame , self ).__init__ (kernel_size , stride , (0 , 0 ), dilation , ceil_mode , count_include_pad )
4949
5050 def forward (self , x ):
0 commit comments