|
6 | 6 |
|
7 | 7 | # pyre-strict |
8 | 8 |
|
| 9 | +import abc |
9 | 10 | import asyncio |
10 | 11 | import copy |
11 | 12 | import inspect |
|
17 | 18 | import shutil |
18 | 19 | import typing |
19 | 20 | import warnings |
20 | | -from dataclasses import asdict, dataclass, field |
| 21 | +from abc import abstractmethod |
| 22 | +from dataclasses import asdict, dataclass, field, fields |
21 | 23 | from datetime import datetime |
22 | | -from enum import Enum, IntEnum |
| 24 | +from enum import Enum |
23 | 25 | from json import JSONDecodeError |
24 | 26 | from string import Template |
25 | 27 | from typing import ( |
|
36 | 38 | Tuple, |
37 | 39 | Type, |
38 | 40 | TypeVar, |
39 | | - Union, |
40 | 41 | ) |
41 | 42 |
|
42 | 43 | from torchx.util.types import to_dict |
| 44 | +from typing_extensions import Self |
43 | 45 |
|
44 | 46 | _APP_STATUS_FORMAT_TEMPLATE = """AppStatus: |
45 | 47 | State: ${state} |
@@ -877,11 +879,81 @@ def __init__(self, status: AppStatus, *args: object) -> None: |
877 | 879 | self.status = status |
878 | 880 |
|
879 | 881 |
|
880 | | -# valid run cfg values; only support primitives (str, int, float, bool, List[str], Dict[str, str]) |
| 882 | +U = TypeVar("U", bound="StructuredRunOpt") |
| 883 | + |
| 884 | + |
| 885 | +class StructuredRunOpt(abc.ABC): |
| 886 | + """ |
| 887 | + StructuredRunOpt is a class that represents a structured run option. |
| 888 | + This is to allow for more complex types than currently supported. |
| 889 | +
|
| 890 | + Usage |
| 891 | +
|
| 892 | + .. doctest:: |
| 893 | + @dataclass |
| 894 | + class Ulimit(StructuredRunOpt): |
| 895 | + name: str |
| 896 | + hard: int |
| 897 | + soft: int |
| 898 | +
|
| 899 | + def template(self) -> str: |
| 900 | + # The template string should contain the field names of the Ulimit object. |
| 901 | + # The field names are mapped to the keys in the repr string. |
| 902 | + # These are comma seperated and wrapped in curly braces. |
| 903 | + return "{name},{soft},{hard}" |
| 904 | +
|
| 905 | + opts = runopts() |
| 906 | + opts.add("ulimit", type_=self.Ulimit, help="ulimits for the container") |
| 907 | +
|
| 908 | + # .from_repr() is used to create a Ulimit object from a string representation that is the template. |
| 909 | + cfg = opts.resolve( |
| 910 | + { |
| 911 | + "ulimit": self.Ulimit.from_repr( |
| 912 | + "test,50,100", |
| 913 | + ) |
| 914 | + } |
| 915 | + ) |
| 916 | +
|
| 917 | + """ |
| 918 | + |
| 919 | + @abstractmethod |
| 920 | + def template(self) -> str: |
| 921 | + """ |
| 922 | + Returns the template string for the StructuredRunOpt. |
| 923 | + These are mapped to the field names of the StructuredRunOpt object. |
| 924 | + """ |
| 925 | + ... |
| 926 | + |
| 927 | + def __repr__(self) -> str: |
| 928 | + key_value = ", ".join(**asdict(self)) |
| 929 | + return f"{self.__class__.__name__}({key_value})" |
| 930 | + |
| 931 | + def __eq__(self, other: object) -> bool: |
| 932 | + return isinstance(other, type(self)) and asdict(self) == asdict(other) |
| 933 | + |
| 934 | + @classmethod |
| 935 | + def from_repr(cls, repr: str) -> Self: |
| 936 | + """ |
| 937 | + Parses the repr string and returns a StructuredRunOpt object |
| 938 | + """ |
| 939 | + tmpl = cls.__new__(cls).template() |
| 940 | + fields_from_tmpl = [field.strip("{}") for field in tmpl.split(",")] |
| 941 | + values = repr.split(",") |
| 942 | + gd = dict(zip(fields_from_tmpl, values, strict=True)) |
| 943 | + for field_cls in fields(cls): |
| 944 | + name = field_cls.name |
| 945 | + field_type = field_cls.type |
| 946 | + value = gd.get(name) |
| 947 | + gd[name] = field_type(value) |
| 948 | + return cls(**gd) |
| 949 | + |
| 950 | + |
| 951 | +# valid run cfg values; support primitives (str, int, float, bool, List[str], Dict[str, str]) |
| 952 | +# And StructuredRunOpt Type for more complex types. |
881 | 953 | # TODO(wilsonhong): python 3.9+ supports list[T] in typing, which can be used directly |
882 | 954 | # in isinstance(). Should replace with that. |
883 | 955 | # see: https://docs.python.org/3/library/stdtypes.html#generic-alias-type |
884 | | -CfgVal = Union[str, int, float, bool, List[str], Dict[str, str], None] |
| 956 | +CfgVal = str | int | float | bool | List[str] | Dict[str, str] | StructuredRunOpt | None |
885 | 957 |
|
886 | 958 |
|
887 | 959 | T = TypeVar("T") |
|
0 commit comments