|
49 | 49 |
|
50 | 50 | @contextmanager |
51 | 51 | def hf_load_state_dict_with_device(device: DeviceLikeType): |
52 | | - """Patch HF load_state_dict to use provided device. |
53 | | -
|
54 | | - NOTE (lucaslie): this function is called by ``load_checkpoint_in_model``. We provide the device |
55 | | - map here as a patch instead of going through ``load_checkpoint_in_model``. This is because |
56 | | - otherwise ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of |
57 | | - calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks in |
58 | | - ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device map here, |
59 | | - we can ensure that ``load_checkpoint_in_model`` will call ``nn.Module.load_state_dict``. |
| 52 | + """Patch HF loading utilities according to our needs. |
| 53 | +
|
| 54 | + Following patches are applied: |
| 55 | + 1. load_state_dict to use provided device. NOTE (lucaslie): this function is called by |
| 56 | + ``load_checkpoint_in_model``. We provide the device map here as a patch instead of going |
| 57 | + through ``load_checkpoint_in_model``. This is because otherwise |
| 58 | + ``load_checkpoint_in_model`` will execute its own state_dict loading logic instead of |
| 59 | + calling ``nn.Module.load_state_dict``. However, we rely on the state dict loading hooks |
| 60 | + in ``nn.Module.load_state_dict`` to correctly load the weights. By providing the device |
| 61 | + map here, we can ensure that ``load_checkpoint_in_model`` will call |
| 62 | + ``nn.Module.load_state_dict``. |
| 63 | + 2. change logging level of logger to ERROR to avoid logging warnings from HF state_dict |
| 64 | + loading for missing/unexpected keys (happens for MoE expert-sharded layers for example). |
60 | 65 | """ |
61 | 66 | # save the original load_state_dict method |
62 | 67 | original_load_state_dict = modeling.load_state_dict |
63 | 68 |
|
| 69 | + # save the original logger level |
| 70 | + original_logger_level = modeling.logger.level |
| 71 | + |
64 | 72 | # Define and apply the patched version |
65 | 73 | def load_state_dict_with_device(checkpoint_file, device_map=None): |
66 | 74 | return original_load_state_dict(checkpoint_file, device_map={"": device}) |
67 | 75 |
|
68 | 76 | # Apply the patch |
69 | 77 | modeling.load_state_dict = load_state_dict_with_device |
70 | 78 |
|
| 79 | + # Change the logger level to ERROR |
| 80 | + modeling.logger.setLevel("ERROR") |
| 81 | + |
71 | 82 | try: |
72 | 83 | yield |
73 | 84 | finally: |
74 | 85 | # Restore the original method, even if an exception occurred |
75 | 86 | modeling.load_state_dict = original_load_state_dict |
| 87 | + # Restore the original logger level |
| 88 | + modeling.logger.setLevel(original_logger_level) |
76 | 89 |
|
77 | 90 |
|
78 | 91 | # TODO (lucaslie): continue working on the base class |
|
0 commit comments