1+ """
2+ Visualizing Gradients
3+ =====================
4+
5+ **Author:** `Justin Silver <https://github.com/j-silv>`__
6+
7+ This tutorial explains how to extract and visualize gradients at any
8+ layer in a neural network. By inspecting how information flows from the
9+ end of the network to the parameters we want to optimize, we can debug
10+ issues such as `vanishing or exploding
11+ gradients <https://arxiv.org/abs/1211.5063>`__ that occur during
12+ training.
13+
14+ Before starting, make sure you understand `tensors and how to manipulate
15+ them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
16+ A basic knowledge of `how autograd
17+ works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
18+ would also be useful.
19+
20+ """


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt


######################################################################
# Next, we’ll be creating a network intended for the MNIST dataset,
# similar to the architecture described by the `batch normalization
# paper <https://arxiv.org/abs/1502.03167>`__.
#
# To illustrate the importance of gradient visualization, we will
# instantiate one version of the network with batch normalization
# (BatchNorm), and one without it. Batch normalization is an extremely
# effective technique to resolve `vanishing/exploding
# gradients <https://arxiv.org/abs/1211.5063>`__, and we will be verifying
# that experimentally.
#
# The model we use has a configurable number of repeating fully-connected
# layers which alternate between ``nn.Linear``, ``norm_layer``, and
# ``nn.Sigmoid``. If batch normalization is enabled, then ``norm_layer``
# will use
# `BatchNorm1d <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html>`__,
# otherwise it will use the
# `Identity <https://docs.pytorch.org/docs/stable/generated/torch.nn.Identity.html>`__
# transformation.
#

def fc_layer(in_size, out_size, norm_layer):
    """Return a stack of linear->norm->sigmoid layers"""
    return nn.Sequential(nn.Linear(in_size, out_size), norm_layer(out_size), nn.Sigmoid())

class Net(nn.Module):
    """Define a network that has num_layers of linear->norm->sigmoid transformations"""
    def __init__(self, in_size=28*28, hidden_size=128,
                 out_size=10, num_layers=3, batchnorm=False):
        super().__init__()
        if batchnorm:
            norm_layer = nn.BatchNorm1d
        else:
            norm_layer = nn.Identity

        layers = []
        layers.append(fc_layer(in_size, hidden_size, norm_layer))

        for _ in range(num_layers - 1):
            layers.append(fc_layer(hidden_size, hidden_size, norm_layer))

        layers.append(nn.Linear(hidden_size, out_size))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = torch.flatten(x, 1)
        return self.layers(x)
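

######################################################################
# As a quick sanity check, a batch of random MNIST-sized inputs should be
# mapped to ten class scores per sample. The batch size of 4 below is
# arbitrary and only used for this illustration:
#

print(Net()(torch.randn(4, 28, 28)).shape)  # expected: torch.Size([4, 10])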


######################################################################
# Next we set up some dummy data, instantiate two versions of the model,
# and initialize the optimizers.
#

# set up dummy data
x = torch.randn(10, 28, 28)
y = torch.randint(10, (10,))

# initialize both models
model_bn = Net(batchnorm=True, num_layers=3)
model_nobn = Net(batchnorm=False, num_layers=3)

model_bn.train()
model_nobn.train()

optimizer_bn = optim.SGD(model_bn.parameters(), lr=0.01, momentum=0.9)
optimizer_nobn = optim.SGD(model_nobn.parameters(), lr=0.01, momentum=0.9)



######################################################################
# We can verify that batch normalization is only being applied to one of
# the models by probing one of its internal layers. In the first printout
# below you should see a ``BatchNorm1d`` module between the ``Linear`` and
# ``Sigmoid`` modules, while the second shows an ``Identity`` in its place:
#

print(model_bn.layers[0])
print(model_nobn.layers[0])


######################################################################
# Registering hooks
# -----------------
#


######################################################################
# Because we wrapped up the logic and state of our model in a
# ``nn.Module``, we need another method to access the intermediate
# gradients if we want to avoid modifying the module code directly. This
# is done by `registering a
# hook <https://docs.pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`__.
#
# .. warning::
#
#    Using backward pass hooks attached to output tensors is preferred over using ``retain_grad()`` on the tensors themselves. An alternative method is to directly attach module hooks (e.g. ``register_full_backward_hook()``) so long as the ``nn.Module`` instance does not perform any in-place operations. For more information, please refer to `this issue <https://github.com/pytorch/pytorch/issues/61519>`__.
#
# The following code defines our hooks and gathers descriptive names for
# the network’s layers.
#

# Note that wrapper functions are used to create Python closures,
# so that we can pass extra arguments to the hooks.

def hook_forward(module_name, grads, hook_backward):
    def hook(module, args, output):
        """Forward pass hook which attaches backward pass hooks to intermediate tensors"""
        output.register_hook(hook_backward(module_name, grads))
    return hook

def hook_backward(module_name, grads):
    def hook(grad):
        """Backward pass hook which appends gradients"""
        grads.append((module_name, grad))
    return hook

def get_all_layers(model, hook_forward, hook_backward):
    """Register a forward pass hook (which registers a backward hook) on each leaf module

    Returns:
        - layers: a dict with keys as layer/module and values as layer/module names
            e.g. layers[nn.Linear(...)] = 'layers.0.0'
        - grads: a list of tuples with module name and tensor output gradient
            e.g. grads[0] == ('layers.0.0', torch.Tensor(...))
    """
    layers = dict()
    grads = []
    for name, layer in model.named_modules():
        # skip Sequential and/or wrapper modules
        if not any(layer.children()):
            layers[layer] = name
            layer.register_forward_hook(hook_forward(name, grads, hook_backward))
    return layers, grads

# register hooks
layers_bn, grads_bn = get_all_layers(model_bn, hook_forward, hook_backward)
layers_nobn, grads_nobn = get_all_layers(model_nobn, hook_forward, hook_backward)
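

######################################################################
# As mentioned in the warning above, an alternative is to attach backward
# hooks directly to the modules with ``register_full_backward_hook()``,
# provided the modules do not perform in-place operations. The snippet
# below is a minimal, hypothetical sketch of that approach; the
# ``grads_alt`` list, ``module_hook_backward`` function, and ``model_demo``
# instance are only for illustration and are not used in the rest of the
# tutorial.
#

grads_alt = []

def module_hook_backward(module, grad_input, grad_output):
    """Collect the gradient of the module's output during the backward pass"""
    grads_alt.append((module.__class__.__name__, grad_output[0]))

model_demo = Net(batchnorm=True)
for name, layer in model_demo.named_modules():
    if not any(layer.children()):  # skip Sequential and/or wrapper modules
        layer.register_full_backward_hook(module_hook_backward)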


######################################################################
# Training and visualization
# --------------------------
#
# Let’s now train the models for a few epochs:
#

epochs = 10

for epoch in range(epochs):

    # important to clear the gradient lists, since the
    # hooks append to them on every backward pass
    grads_bn.clear()
    grads_nobn.clear()

    optimizer_bn.zero_grad()
    optimizer_nobn.zero_grad()

    y_pred_bn = model_bn(x)
    y_pred_nobn = model_nobn(x)

    loss_bn = F.cross_entropy(y_pred_bn, y)
    loss_nobn = F.cross_entropy(y_pred_nobn, y)

    loss_bn.backward()
    loss_nobn.backward()

    optimizer_bn.step()
    optimizer_nobn.step()


######################################################################
# After running the forward and backward passes, the gradients of all the
# intermediate tensors should be present in ``grads_bn`` and
# ``grads_nobn``. Because the backward hooks fire from the output layer
# back toward the input, the lists are in reverse layer order; we account
# for this when computing the layer index. We then take the mean absolute
# value of each gradient tensor so that we can compare the two models.
#

def get_grads(grads):
    layer_idx = []
    avg_grads = []
    for idx, (name, grad) in enumerate(grads):
        if grad is not None:
            avg_grad = grad.abs().mean()
            avg_grads.append(avg_grad)
            # idx is backwards since we appended in the backward pass
            layer_idx.append(len(grads) - 1 - idx)
    return layer_idx, avg_grads

layer_idx_bn, avg_grads_bn = get_grads(grads_bn)
layer_idx_nobn, avg_grads_nobn = get_grads(grads_nobn)
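

######################################################################
# Before plotting, it can be helpful to glance at a few of the averaged
# values directly. This check is optional and the exact numbers will vary
# from run to run:
#

print("with BatchNorm:   ", [f"{g.item():.2e}" for g in avg_grads_bn[:3]])
print("without BatchNorm:", [f"{g.item():.2e}" for g in avg_grads_nobn[:3]])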


######################################################################
# With the average gradients computed, we can now plot them and see how
# the values change as a function of the network depth. Notice that when
# we don’t apply batch normalization, the gradient values in the
# intermediate layers fall to zero very quickly. The batch normalization
# model, however, maintains non-zero gradients in its intermediate layers.
#

fig, ax = plt.subplots()
ax.plot(layer_idx_bn, avg_grads_bn, label="With BatchNorm", marker="o")
ax.plot(layer_idx_nobn, avg_grads_nobn, label="Without BatchNorm", marker="x")
ax.set_xlabel("Layer depth")
ax.set_ylabel("Average gradient")
ax.set_title("Gradient flow")
ax.grid(True)
ax.legend()
plt.show()
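

######################################################################
# .. tip::
#
#    If the gradient magnitudes span several orders of magnitude, plotting
#    them on a logarithmic y-axis (for example with ``ax.set_yscale("log")``,
#    applied before ``plt.show()``) can make the difference between the two
#    models easier to see.
#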


######################################################################
# Conclusion
# ----------
#
# In this tutorial, we demonstrated how to visualize the gradient flow
# through a neural network wrapped in a ``nn.Module`` class. We
# qualitatively showed how batch normalization helps to alleviate the
# vanishing gradient issue which occurs with deep neural networks.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#



######################################################################
# (Optional) Additional exercises
# -------------------------------
#
# - Try increasing the number of layers (``num_layers``) in our model and
#   see what effect this has on the gradient flow graph
# - How would you adapt the code to visualize average activations instead
#   of average gradients? (*Hint: in the hook_forward() function we have
#   access to the raw tensor output*)
# - What are some other methods to deal with vanishing and exploding
#   gradients?
#



######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
# - `Batch Normalization: Accelerating Deep Network Training by Reducing
#   Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__
# - `On the difficulty of training Recurrent Neural
#   Networks <https://arxiv.org/abs/1211.5063>`__
#