|
7 | 7 | import shutil |
8 | 8 | import sys |
9 | 9 | from importlib.util import module_from_spec, spec_from_file_location |
| 10 | +from pathlib import Path |
10 | 11 | from tempfile import mkdtemp |
11 | 12 | from types import ModuleType |
12 | 13 | from typing import Any, Dict, List, Tuple, Type |
13 | 14 | from uuid import uuid4 |
14 | 15 |
|
15 | | -from pydantic import BaseModel, Extra, create_model |
| 16 | +from pydantic import VERSION, BaseModel, Extra, create_model |
16 | 17 |
|
17 | 18 | try: |
18 | 19 | from pydantic.generics import GenericModel |
|
22 | 23 | logger = logging.getLogger("pydantic2ts") |
23 | 24 |
|
24 | 25 |
|
| 26 | +DEBUG = os.environ.get("DEBUG", False) |
| 27 | + |
| 28 | +V2 = True if VERSION.startswith("2") else False |
| 29 | + |
| 30 | + |
25 | 31 | def import_module(path: str) -> ModuleType: |
26 | 32 | """ |
27 | 33 | Helper which allows modules to be specified by either dotted path notation or by filepath. |
@@ -61,12 +67,15 @@ def is_concrete_pydantic_model(obj) -> bool: |
61 | 67 | Return true if an object is a concrete subclass of pydantic's BaseModel. |
62 | 68 | 'concrete' meaning that it's not a GenericModel. |
63 | 69 | """ |
| 70 | + generic_metadata = getattr(obj, "__pydantic_generic_metadata__", None) |
64 | 71 | if not inspect.isclass(obj): |
65 | 72 | return False |
66 | 73 | elif obj is BaseModel: |
67 | 74 | return False |
68 | | - elif GenericModel and issubclass(obj, GenericModel): |
| 75 | + elif not V2 and GenericModel and issubclass(obj, GenericModel): |
69 | 76 | return bool(obj.__concrete__) |
| 77 | + elif V2 and generic_metadata: |
| 78 | + return not bool(generic_metadata["parameters"]) |
70 | 79 | else: |
71 | 80 | return issubclass(obj, BaseModel) |
72 | 81 |
|
@@ -141,7 +150,7 @@ def clean_schema(schema: Dict[str, Any]) -> None: |
141 | 150 | del schema["description"] |
142 | 151 |
|
143 | 152 |
|
144 | | -def generate_json_schema(models: List[Type[BaseModel]]) -> str: |
| 153 | +def generate_json_schema_v1(models: List[Type[BaseModel]]) -> str: |
145 | 154 | """ |
146 | 155 | Create a top-level '_Master_' model with references to each of the actual models. |
147 | 156 | Generate the schema for this model, which will include the schemas for all the |
@@ -178,6 +187,43 @@ def generate_json_schema(models: List[Type[BaseModel]]) -> str: |
178 | 187 | m.Config.extra = x |
179 | 188 |
|
180 | 189 |
|
| 190 | +def generate_json_schema_v2(models: List[Type[BaseModel]]) -> str: |
| 191 | + """ |
| 192 | + Create a top-level '_Master_' model with references to each of the actual models. |
| 193 | + Generate the schema for this model, which will include the schemas for all the |
| 194 | + nested models. Then clean up the schema. |
| 195 | +
|
| 196 | + One weird thing we do is we temporarily override the 'extra' setting in models, |
| 197 | + changing it to 'forbid' UNLESS it was explicitly set to 'allow'. This prevents |
| 198 | + '[k: string]: any' from being added to every interface. This change is reverted |
| 199 | + once the schema has been generated. |
| 200 | + """ |
| 201 | + model_extras = [m.model_config.get("extra") for m in models] |
| 202 | + |
| 203 | + try: |
| 204 | + for m in models: |
| 205 | + if m.model_config.get("extra") != Extra.allow: |
| 206 | + m.model_config["extra"] = Extra.forbid |
| 207 | + |
| 208 | + master_model: BaseModel = create_model( |
| 209 | + "_Master_", **{m.__name__: (m, ...) for m in models} |
| 210 | + ) |
| 211 | + master_model.model_config["extra"] = Extra.forbid |
| 212 | + master_model.model_config["json_schema_extra"] = staticmethod(clean_schema) |
| 213 | + |
| 214 | + schema: dict = master_model.model_json_schema() |
| 215 | + |
| 216 | + for d in schema.get("$defs", {}).values(): |
| 217 | + clean_schema(d) |
| 218 | + |
| 219 | + return json.dumps(schema, indent=2) |
| 220 | + |
| 221 | + finally: |
| 222 | + for m, x in zip(models, model_extras): |
| 223 | + if x is not None: |
| 224 | + m.model_config["extra"] = x |
| 225 | + |
| 226 | + |
181 | 227 | def generate_typescript_defs( |
182 | 228 | module: str, output: str, exclude: Tuple[str] = (), json2ts_cmd: str = "json2ts" |
183 | 229 | ) -> None: |
@@ -205,13 +251,18 @@ def generate_typescript_defs( |
205 | 251 |
|
206 | 252 | logger.info("Generating JSON schema from pydantic models...") |
207 | 253 |
|
208 | | - schema = generate_json_schema(models) |
| 254 | + schema = generate_json_schema_v2(models) if V2 else generate_json_schema_v1(models) |
| 255 | + |
209 | 256 | schema_dir = mkdtemp() |
210 | 257 | schema_file_path = os.path.join(schema_dir, "schema.json") |
211 | 258 |
|
212 | 259 | with open(schema_file_path, "w") as f: |
213 | 260 | f.write(schema) |
214 | 261 |
|
| 262 | + if DEBUG: |
| 263 | + with open(Path(output).parent / "schema.json", "w") as f: |
| 264 | + f.write(schema) |
| 265 | + |
215 | 266 | logger.info("Converting JSON schema to typescript definitions...") |
216 | 267 |
|
217 | 268 | json2ts_exit_code = os.system( |
|
0 commit comments