Debug Embedding Collection to check for NaNs in backward (#3519)
Summary:
Pull Request resolved: #3519
This diff adds support to Embedding classes to detect NaNs in backward. It adds the following:
- `DebugEmbeddingCollection`
- `DebugEmbeddingCollectionClass`
Currently it checks whether gradients contain a NaN during backward. Before `.backward()` is applied to `EmbeddingCollection`, this class catches the issue first. It works by wrapping all the tensors (inside `KeyedJaggedTensor`) with an autograd function. This autograd function is an identity during forward but checks for NaNs during backward. The same applies to `EmbeddingBagCollection`.
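The identity-forward, check-on-backward idea can be sketched in plain PyTorch as follows. This is only an illustration under assumptions, not the code from this diff: `NanCheckIdentity` is a hypothetical name, and the sketch wraps a bare tensor rather than the tensors inside a `KeyedJaggedTensor`.

```python
import torch


class NanCheckIdentity(torch.autograd.Function):
    """Hypothetical helper: identity in forward, NaN check in backward."""

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Return a view of the input so the values pass through unchanged.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Fail loudly if the incoming gradient contains a NaN.
        if torch.isnan(grad_output).any():
            raise RuntimeError("NaN detected in gradient during backward")
        # Otherwise pass the gradient through untouched.
        return grad_output


weights = torch.ones(3, requires_grad=True)

# Clean case: the gradient flows through the wrapper unchanged.
NanCheckIdentity.apply(weights).sum().backward()
clean_grad = weights.grad.clone()

# Faulty case: multiplying by a NaN makes the upstream gradient contain a NaN.
weights.grad = None
wrapped = NanCheckIdentity.apply(weights)
loss = (wrapped * torch.tensor([1.0, float("nan"), 1.0])).sum()
try:
    loss.backward()
    caught = False
except RuntimeError:
    caught = True
```

In this sketch the check fires inside the wrapper's backward, so the error surfaces before any gradient is accumulated into `weights`.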
This diff adds 3 tests alongside the debug embedding classes:
- `test_embedding`
- `test_embedding_bag`
- `test_model` (a reference DLRM model that uses `DebugEmbeddingCollectionClass`). The test adds a NaN to the logits, which is then caught by `DebugEmbeddingCollectionClass` before backward can proceed.
Addresses the issue which was previously seen in S542457 https://docs.google.com/presentation/d/1soiz7UxALa_hsgCnOEw_OL4yg8oK4z-VNEIVMB7_v7U/edit?slide=id.g37205c3166e_1_135#slide=id.g37205c3166e_1_135
Reviewed By: jeffkbkim
Differential Revision: D86233629
fbshipit-source-id: 4f620d84c90c01c045cc4b69e1c5564ed2839ff3