1+ import torch
2+ import torch .nn .parallel
3+ import torch .nn as nn
4+ import torch .nn .functional as F
5+
6+
7+ class AntiAliasDownsampleLayer (nn .Module ):
8+ def __init__ (self , remove_aa_jit : bool = False , filt_size : int = 3 , stride : int = 2 ,
9+ channels : int = 0 ):
10+ super (AntiAliasDownsampleLayer , self ).__init__ ()
11+ if not remove_aa_jit :
12+ self .op = DownsampleJIT (filt_size , stride , channels )
13+ else :
14+ self .op = Downsample (filt_size , stride , channels )
15+
16+ def forward (self , x ):
17+ return self .op (x )
18+
19+
20+ @torch .jit .script
21+ class DownsampleJIT (object ):
22+ def __init__ (self , filt_size : int = 3 , stride : int = 2 , channels : int = 0 ):
23+ self .stride = stride
24+ self .filt_size = filt_size
25+ self .channels = channels
26+
27+ assert self .filt_size == 3
28+ assert stride == 2
29+ a = torch .tensor ([1. , 2. , 1. ])
30+
31+ filt = (a [:, None ] * a [None , :]).clone ().detach ()
32+ filt = filt / torch .sum (filt )
33+ self .filt = filt [None , None , :, :].repeat ((self .channels , 1 , 1 , 1 )).cuda ().half ()
34+
35+ def __call__ (self , input : torch .Tensor ):
36+ if input .dtype != self .filt .dtype :
37+ self .filt = self .filt .float ()
38+ input_pad = F .pad (input , (1 , 1 , 1 , 1 ), 'reflect' )
39+ return F .conv2d (input_pad , self .filt , stride = 2 , padding = 0 , groups = input .shape [1 ])
40+
41+
42+ class Downsample (nn .Module ):
43+ def __init__ (self , filt_size = 3 , stride = 2 , channels = None ):
44+ super (Downsample , self ).__init__ ()
45+ self .filt_size = filt_size
46+ self .stride = stride
47+ self .channels = channels
48+
49+
50+ assert self .filt_size == 3
51+ a = torch .tensor ([1. , 2. , 1. ])
52+
53+ filt = (a [:, None ] * a [None , :])
54+ filt = filt / torch .sum (filt )
55+
56+ # self.filt = filt[None, None, :, :].repeat((self.channels, 1, 1, 1))
57+ self .register_buffer ('filt' , filt [None , None , :, :].repeat ((self .channels , 1 , 1 , 1 )))
58+
59+ def forward (self , input ):
60+ input_pad = F .pad (input , (1 , 1 , 1 , 1 ), 'reflect' )
61+ return F .conv2d (input_pad , self .filt , stride = self .stride , padding = 0 , groups = input .shape [1 ])
0 commit comments