Skip to content

Commit e95dc7f

Browse files
dmarek-flexyaugenst-flex
authored andcommitted
fix: improve type hints for Tidy3dBaseModel
1 parent 205e83d commit e95dc7f

File tree

4 files changed

+38
-14
lines changed

4 files changed

+38
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5050
- In `Tidy3dBaseModel` the hash (and cached `.json_string`) are now sensitive to changes in `.attrs`.
5151
- More accurate frequency range for ``GaussianPulse`` when DC is removed.
5252
- Bug in `TerminalComponentModelerData.get_antenna_metrics_data()` where `WavePort` mode indices were not properly handled. Improved docstrings and type hints to make the usage clearer.
53+
- Improved type hints for `Tidy3dBaseModel`, so that all derived classes will have more accurate return types.
5354

5455
## [v2.10.0rc2] - 2025-10-01
5556

poetry.lock

Lines changed: 22 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ pyjwt = "*"
4646
click = "^8.1.0"
4747
responses = "*"
4848
joblib = "*"
49+
typing-extensions = "*"
4950
### END NOT CORE
5051

5152
### Optional dependencies ###

tidy3d/components/base.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from autograd.tracer import isbox
2424
from pydantic.v1.fields import ModelField
2525
from pydantic.v1.json import custom_pydantic_encoder
26+
from typing_extensions import Self
2627

2728
from tidy3d.exceptions import FileError
2829
from tidy3d.log import log
@@ -271,7 +272,7 @@ def _default(o):
271272

272273
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
273274

274-
def copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Tidy3dBaseModel:
275+
def copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Self:
275276
"""Copy a Tidy3dBaseModel. With ``deep=True`` and ``validate=True`` as default."""
276277
kwargs.update(deep=deep)
277278
new_copy = pydantic.BaseModel.copy(self, **kwargs)
@@ -284,7 +285,7 @@ def copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Tidy3dBase
284285

285286
def updated_copy(
286287
self, path: Optional[str] = None, deep: bool = True, validate: bool = True, **kwargs
287-
) -> Tidy3dBaseModel:
288+
) -> Self:
288289
"""Make copy of a component instance with ``**kwargs`` indicating updated field values.
289290
290291
Note
@@ -342,7 +343,7 @@ def updated_copy(
342343

343344
return self._updated_copy(deep=deep, validate=validate, **{field_name: new_component})
344345

345-
def _updated_copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Tidy3dBaseModel:
346+
def _updated_copy(self, deep: bool = True, validate: bool = True, **kwargs) -> Self:
346347
"""Make copy of a component instance with ``**kwargs`` indicating updated field values."""
347348
return self.copy(update=kwargs, deep=deep, validate=validate)
348349

@@ -368,7 +369,7 @@ def from_file(
368369
lazy: bool = False,
369370
on_load: Optional[Callable] = None,
370371
**parse_obj_kwargs,
371-
) -> Tidy3dBaseModel:
372+
) -> Self:
372373
"""Loads a :class:`Tidy3dBaseModel` from .yaml, .json, .hdf5, or .hdf5.gz file.
373374
374375
Parameters
@@ -391,7 +392,7 @@ def from_file(
391392
392393
Returns
393394
-------
394-
:class:`Tidy3dBaseModel`
395+
Self
395396
An instance of the component class calling ``load``.
396397
397398
Example
@@ -469,7 +470,7 @@ def to_file(self, fname: str) -> None:
469470
return converter(fname=fname)
470471

471472
@classmethod
472-
def from_json(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel:
473+
def from_json(cls, fname: str, **parse_obj_kwargs) -> Self:
473474
"""Load a :class:`Tidy3dBaseModel` from .json file.
474475
475476
Parameters
@@ -479,7 +480,7 @@ def from_json(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel:
479480
480481
Returns
481482
-------
482-
:class:`Tidy3dBaseModel`
483+
Self
483484
An instance of the component class calling `load`.
484485
**parse_obj_kwargs
485486
Keyword arguments passed to pydantic's ``parse_obj`` method.
@@ -532,7 +533,7 @@ def to_json(self, fname: str) -> None:
532533
file_handle.write(json_string)
533534

534535
@classmethod
535-
def from_yaml(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel:
536+
def from_yaml(cls, fname: str, **parse_obj_kwargs) -> Self:
536537
"""Loads :class:`Tidy3dBaseModel` from .yaml file.
537538
538539
Parameters
@@ -544,7 +545,7 @@ def from_yaml(cls, fname: str, **parse_obj_kwargs) -> Tidy3dBaseModel:
544545
545546
Returns
546547
-------
547-
:class:`Tidy3dBaseModel`
548+
Self
548549
An instance of the component class calling `from_yaml`.
549550
550551
Example
@@ -747,7 +748,7 @@ def from_hdf5(
747748
group_path: str = "",
748749
custom_decoders: Optional[list[Callable]] = None,
749750
**parse_obj_kwargs,
750-
) -> Tidy3dBaseModel:
751+
) -> Self:
751752
"""Loads :class:`Tidy3dBaseModel` instance to .hdf5 file.
752753
753754
Parameters
@@ -882,7 +883,7 @@ def from_hdf5_gz(
882883
group_path: str = "",
883884
custom_decoders: Optional[list[Callable]] = None,
884885
**parse_obj_kwargs,
885-
) -> Tidy3dBaseModel:
886+
) -> Self:
886887
"""Loads :class:`Tidy3dBaseModel` instance to .hdf5.gz file.
887888
888889
Parameters
@@ -1084,7 +1085,7 @@ def handle_value(x: Any, path: tuple[str, ...]) -> None:
10841085
# convert the resulting field_mapping to an autograd-traced dictionary
10851086
return dict_ag(field_mapping)
10861087

1087-
def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Tidy3dBaseModel:
1088+
def _insert_traced_fields(self, field_mapping: AutogradFieldMap) -> Self:
10881089
"""Recursively insert a map of paths to autograd-traced fields into a copy of this obj."""
10891090

10901091
self_dict = self.dict()
@@ -1129,7 +1130,7 @@ def _serialized_traced_field_keys(
11291130
tracer_keys = TracerKeys.from_field_mapping(field_mapping)
11301131
return tracer_keys.json(separators=(",", ":"), ensure_ascii=True)
11311132

1132-
def to_static(self) -> Tidy3dBaseModel:
1133+
def to_static(self) -> Self:
11331134
"""Version of object with all autograd-traced fields removed."""
11341135

11351136
# get dictionary of all traced fields

0 commit comments

Comments
 (0)