Skip to content

Commit e5a616f

Browse files
GdoongMathewBordaSkafteNickijustusschockdeependujha
authored
check the init args only when the given frames are in __init__ method. (#21227)
* check the init args only when the given frames are in `__init__` method. * chlog * add additional testing. --------- Co-authored-by: jirka <jirka.borovec@seznam.cz> Co-authored-by: Nicki Skafte Detlefsen <skaftenicki@gmail.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent b3e0df0 commit e5a616f

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
6161
- Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246))
6262

6363

64+
- Fixed check the init args only when the given frames are in `__init__` method ([#21227](https://github.com/Lightning-AI/pytorch-lightning/pull/21227))
65+
66+
6467
- Fixed how `ThroughputMonitor` calculated training time ([#21291](https://github.com/Lightning-AI/pytorch-lightning/pull/21291))
6568

6669

src/lightning/pytorch/utilities/parsing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def get_init_args(frame: types.FrameType) -> dict[str, Any]: # pragma: no-cover
9191

9292
def _get_init_args(frame: types.FrameType) -> tuple[Optional[Any], dict[str, Any]]:
9393
_, _, _, local_vars = inspect.getargvalues(frame)
94-
if "__class__" not in local_vars:
94+
if "__class__" not in local_vars or frame.f_code.co_name != "__init__":
9595
return None, {}
9696
cls = local_vars["__class__"]
9797
init_parameters = inspect.signature(cls.__init__).parameters

tests/tests_pytorch/models/test_hparams.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,20 @@ def __init__(obj, *more_args, other_arg=300, **more_kwargs):
341341
obj.save_hyperparameters()
342342

343343

344+
class _MetaType(type):
345+
def __call__(cls, *args, **kwargs):
346+
instance = super().__call__(*args, **kwargs) # Create the instance
347+
if hasattr(instance, "_after_init"):
348+
instance._after_init(**kwargs) # Call the method if defined
349+
return instance
350+
351+
352+
class MetaTypeBoringModel(CustomBoringModel, metaclass=_MetaType):
353+
def __init__(self, *args, **kwargs):
354+
super().__init__(*args, **kwargs)
355+
self.save_hyperparameters()
356+
357+
344358
if _OMEGACONF_AVAILABLE:
345359

346360
class DictConfSubClassBoringModel(SubClassBoringModel):
@@ -365,6 +379,7 @@ class DictConfSubClassBoringModel: ...
365379
pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
366380
BoringModelWithMixin,
367381
BoringModelWithMixinAndInit,
382+
MetaTypeBoringModel,
368383
],
369384
)
370385
def test_collect_init_arguments(tmp_path, cls):
@@ -420,6 +435,21 @@ def _raw_checkpoint_path(trainer) -> str:
420435
return os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
421436

422437

438+
def test_collect_init_arguments_in_other_methods():
439+
class _ABCModelCreator:
440+
def init(self, model, **kwargs) -> LightningModule:
441+
self.model = model
442+
return self.model
443+
444+
class ConcreteModelCreator(_ABCModelCreator):
445+
def init(self, model=None, **kwargs) -> LightningModule:
446+
return super().init(model=model or CustomBoringModel(**kwargs))
447+
448+
model_creator = ConcreteModelCreator()
449+
model = model_creator.init(batch_size=123)
450+
assert model.hparams.batch_size == 123
451+
452+
423453
@pytest.mark.parametrize("base_class", [HyperparametersMixin, LightningModule, LightningDataModule])
424454
def test_save_hyperparameters_under_composition(base_class):
425455
"""Test that in a composition where the parent is not a Lightning-like module, the parent's arguments don't get

0 commit comments

Comments
 (0)