Skip to content

Commit 74507d6

Browse files
committed
chore: add TypedDict for Insert and Update
new class is needed because typeddict uses NonRequired for missing attributes
1 parent 5a514cc commit 74507d6

File tree

2 files changed

+384
-114
lines changed

2 files changed

+384
-114
lines changed

src/server/templates/python.ts

Lines changed: 131 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,60 @@ import type {
88
} from '../../lib/index.js'
99
import type { GeneratorMetadata } from '../../lib/generators.js'
1010

11+
export const apply = ({
12+
schemas,
13+
tables,
14+
views,
15+
materializedViews,
16+
columns,
17+
types,
18+
}: GeneratorMetadata): string => {
19+
const ctx = new PythonContext(types, columns, schemas);
20+
const py_tables = tables
21+
.filter((table) => schemas.some((schema) => schema.name === table.schema))
22+
.flatMap((table) => {
23+
const py_class_and_methods = ctx.tableToClass(table);
24+
return py_class_and_methods;
25+
});
26+
const composite_types = types
27+
.filter((type) => type.attributes.length > 0)
28+
.map((type) => ctx.typeToClass(type));
29+
const py_views = views.map((view) => ctx.viewToClass(view));
30+
const py_matviews = materializedViews.map((matview) => ctx.matViewToClass(matview));
31+
32+
let output = `
33+
from __future__ import annotations
34+
35+
import datetime
36+
from typing import (
37+
Annotated,
38+
Any,
39+
List,
40+
Literal,
41+
NotRequired,
42+
Optional,
43+
TypeAlias,
44+
TypedDict,
45+
)
46+
47+
from pydantic import BaseModel, Field, Json
48+
49+
${concatLines(Object.values(ctx.user_enums))}
50+
51+
${concatLines(py_tables)}
52+
53+
${concatLines(py_views)}
54+
55+
${concatLines(py_matviews)}
56+
57+
${concatLines(composite_types)}
58+
59+
`.trim()
60+
61+
return output
62+
}
63+
64+
1165
interface Serializable {
1266
serialize(): string
1367
}
@@ -63,7 +117,7 @@ class PythonContext {
63117
}
64118
}
65119

66-
typeToClass(type: PostgresType) : PythonClass {
120+
typeToClass(type: PostgresType) : PythonBaseModel {
67121
const types = Object.values(this.types);
68122
const attributes = type.attributes.map((attribute) => {
69123
const type = types.find((type) => type.id === attribute.type_id)
@@ -72,41 +126,49 @@ class PythonContext {
72126
type,
73127
}
74128
});
75-
const attributeEntries: PythonClassAttribute[] = attributes
129+
const attributeEntries: PythonBaseModelAttr[] = attributes
76130
.map((attribute) => {
77131
const type = this.parsePgType(attribute.type!.name);
78-
return new PythonClassAttribute(attribute.name, type, false, false, false, false);
132+
return new PythonBaseModelAttr(attribute.name, type, false);
79133
});
134+
80135
const schema = this.schemas[type.schema];
81-
return new PythonClass(type.name, schema, attributeEntries);
136+
return new PythonBaseModel(type.name, schema, attributeEntries);
137+
}
138+
139+
columnsToClassAttrs(table_id: number) : PythonBaseModelAttr[] {
140+
const attrs = this.columns[table_id] ?? [];
141+
return attrs.map((col) => {
142+
const type = this.parsePgType(col.format);
143+
return new PythonBaseModelAttr(col.name, type, col.is_nullable);
144+
});
82145
}
83146

84-
columnsToClassAttrs(table_id: number) : PythonClassAttribute[] {
147+
columnsToDictAttrs(table_id: number, not_required: boolean) : PythonTypedDictAttr[] {
85148
const attrs = this.columns[table_id] ?? [];
86149
return attrs.map((col) => {
87150
const type = this.parsePgType(col.format);
88-
return new PythonClassAttribute(col.name, type,
89-
col.is_nullable,
90-
col.is_updatable,
91-
col.is_generated || !!col.default_value,
92-
col.is_identity);
151+
return new PythonTypedDictAttr(col.name, type, col.is_nullable, not_required || col.is_nullable || col.is_identity || (col.default_value !== null));
93152
});
94153
}
95154

96-
tableToClass(table: PostgresTable) : PythonClass {
97-
const attributes = this.columnsToClassAttrs(table.id);
98-
return new PythonClass(table.name, this.schemas[table.schema], attributes)
155+
tableToClass(table: PostgresTable) : [PythonBaseModel, PythonTypedDict, PythonTypedDict] {
156+
const schema = this.schemas[table.schema];
157+
const select = new PythonBaseModel(table.name, schema, this.columnsToClassAttrs(table.id));
158+
const insert = new PythonTypedDict(table.name, "Insert", schema, this.columnsToDictAttrs(table.id, false));
159+
const update = new PythonTypedDict(table.name, "Update", schema, this.columnsToDictAttrs(table.id, true));
160+
return [select, insert, update];
99161
}
100162

101163

102-
viewToClass(view: PostgresView) : PythonClass {
164+
viewToClass(view: PostgresView) : PythonBaseModel {
103165
const attributes = this.columnsToClassAttrs(view.id);
104-
return new PythonClass(view.name, this.schemas[view.schema], attributes)
166+
return new PythonBaseModel(view.name, this.schemas[view.schema], attributes)
105167
}
106168

107-
matViewToClass(matview: PostgresMaterializedView) : PythonClass {
169+
matViewToClass(matview: PostgresMaterializedView) : PythonBaseModel {
108170
const attributes = this.columnsToClassAttrs(matview.id);
109-
return new PythonClass(matview.name, this.schemas[matview.schema], attributes)
171+
return new PythonBaseModel(matview.name, this.schemas[matview.schema], attributes)
110172
}
111173
}
112174

@@ -116,7 +178,7 @@ class PythonEnum implements Serializable {
116178
variants: string[];
117179
constructor(type: PostgresType) {
118180
this.name = `${formatForPyClassName(type.schema)}${formatForPyClassName(type.name)}`;
119-
this.variants = type.enums.map(formatForPyAttributeName);
181+
this.variants = type.enums;
120182
}
121183
serialize(): string {
122184
const variants = this.variants.map((item) => `"${item}"`).join(', ');
@@ -146,71 +208,93 @@ class PythonListType implements Serializable {
146208
}
147209
}
148210

149-
class PythonClassAttribute implements Serializable {
211+
class PythonBaseModelAttr implements Serializable {
150212
name: string;
151213
pg_name: string;
152214
py_type: PythonType;
153215
nullable: boolean;
154-
mutable: boolean;
155-
has_default: boolean;
156-
is_identity: boolean;
157-
158216

159-
constructor(name: string, py_type: PythonType, nullable: boolean, mutable: boolean, has_default: boolean, is_identity: boolean) {
217+
constructor(name: string, py_type: PythonType, nullable: boolean) {
160218
this.name = formatForPyAttributeName(name);
161219
this.pg_name = name;
162220
this.py_type = py_type;
163221
this.nullable = nullable;
164-
this.mutable = mutable;
165-
this.has_default = has_default;
166-
this.is_identity = is_identity;
167222
}
168-
223+
169224
serialize(): string {
170225
const py_type = this.nullable
171226
? `Optional[${this.py_type.serialize()}]`
172227
: this.py_type.serialize();
173-
return ` ${this.name}: Annotated[${py_type}, Field(alias="${this.pg_name}")]`
228+
return ` ${this.name}: ${py_type} = Field(alias="${this.pg_name}")`
174229
}
175-
176230
}
177231

178-
class PythonClass implements Serializable {
232+
class PythonBaseModel implements Serializable {
179233
name: string;
180234
table_name: string;
181-
parent_class: string;
182235
schema: PostgresSchema;
183-
class_attributes: PythonClassAttribute[];
236+
class_attributes: PythonBaseModelAttr[];
184237

185-
constructor(name: string, schema: PostgresSchema, class_attributes: PythonClassAttribute[], parent_class: string="BaseModel") {
238+
constructor(name: string, schema: PostgresSchema, class_attributes: PythonBaseModelAttr[]) {
186239
this.schema = schema;
187240
this.class_attributes = class_attributes;
188241
this.table_name = name;
189242
this.name = `${formatForPyClassName(schema.name)}${formatForPyClassName(name)}`;
190-
this.parent_class = parent_class;
191243
}
192244
serialize(): string {
193245
const attributes = this.class_attributes.length > 0
194246
? this.class_attributes.map((attr) => attr.serialize()).join('\n')
195247
: " pass";
196-
return `class ${this.name}(${this.parent_class}):\n${attributes}`;
248+
return `class ${this.name}(BaseModel):\n${attributes}`;
197249
}
250+
}
251+
252+
class PythonTypedDictAttr implements Serializable {
253+
name: string;
254+
pg_name: string;
255+
py_type: PythonType;
256+
nullable: boolean;
257+
not_required: boolean;
198258

199-
update() : PythonClass {
200-
// Converts all attributes to nullable
201-
const attrs = this.class_attributes
202-
.filter((attr) => attr.mutable || attr.is_identity)
203-
.map((attr) => new PythonClassAttribute(attr.name, attr.py_type, true, attr.mutable, attr.has_default, attr.is_identity))
204-
return new PythonClass(`${this.table_name}_update`, this.schema, attrs, "TypedDict")
259+
constructor(name: string, py_type: PythonType, nullable: boolean, required: boolean) {
260+
this.name = formatForPyAttributeName(name);
261+
this.pg_name = name;
262+
this.py_type = py_type;
263+
this.nullable = nullable;
264+
this.not_required = required;
205265
}
206266

207-
insert() : PythonClass {
208-
// Converts all attributes that have a default to nullable.
209-
const attrs = this.class_attributes
210-
.map((attr) => new PythonClassAttribute(attr.name, attr.py_type, attr.has_default || attr.nullable, attr.mutable, attr.has_default, attr.is_identity));
211-
return new PythonClass(`${this.table_name}_insert`, this.schema, attrs, "TypedDict")
267+
serialize(): string {
268+
const annotation = `Annotated[${this.py_type.serialize()}, Field(alias="${this.pg_name}")]`;
269+
const rhs = this.not_required
270+
? `NotRequired[${annotation}]`
271+
: annotation;
272+
return ` ${this.name}: ${rhs}`;
212273
}
274+
}
275+
276+
class PythonTypedDict implements Serializable {
277+
name: string;
278+
table_name: string;
279+
parent_class: string;
280+
schema: PostgresSchema;
281+
dict_attributes: PythonTypedDictAttr[];
282+
operation: "Insert" | "Update";
213283

284+
constructor(name: string, operation: "Insert" | "Update", schema: PostgresSchema, dict_attributes: PythonTypedDictAttr[], parent_class: string="BaseModel") {
285+
this.schema = schema;
286+
this.dict_attributes = dict_attributes;
287+
this.table_name = name;
288+
this.name = `${formatForPyClassName(schema.name)}${formatForPyClassName(name)}`;
289+
this.parent_class = parent_class;
290+
this.operation = operation;
291+
}
292+
serialize(): string {
293+
const attributes = this.dict_attributes.length > 0
294+
? this.dict_attributes.map((attr) => attr.serialize()).join('\n')
295+
: " pass";
296+
return `class ${this.name}${this.operation}(TypedDict):\n${attributes}`;
297+
}
214298
}
215299

216300
function concatLines(items: Serializable[]): string {
@@ -267,52 +351,6 @@ const PY_TYPE_MAP: Record<string, string> = {
267351
record: 'dict[str, Any]',
268352
} as const
269353

270-
export const apply = ({
271-
schemas,
272-
tables,
273-
views,
274-
materializedViews,
275-
columns,
276-
types,
277-
}: GeneratorMetadata): string => {
278-
const ctx = new PythonContext(types, columns, schemas);
279-
const py_tables = tables
280-
.filter((table) => schemas.some((schema) => schema.name === table.schema))
281-
.flatMap((table) => {
282-
const py_class = ctx.tableToClass(table);
283-
return [py_class, py_class.insert(), py_class.update()];
284-
});
285-
286-
const composite_types = types
287-
.filter((type) => type.attributes.length > 0)
288-
.map((type) => ctx.typeToClass(type));
289-
290-
const py_views = views.map((view) => ctx.viewToClass(view));
291-
const py_matviews = materializedViews.map((matview) => ctx.matViewToClass(matview));
292-
293-
let output = `
294-
from __future__ import annotations
295-
296-
import datetime
297-
from typing import Annotated, Any, List, Literal, Optional, TypeAlias, TypedDict
298-
299-
from pydantic import BaseModel, Field, Json
300-
301-
${concatLines(Object.values(ctx.user_enums))}
302-
303-
${concatLines(py_tables)}
304-
305-
${concatLines(py_views)}
306-
307-
${concatLines(py_matviews)}
308-
309-
${concatLines(composite_types)}
310-
311-
`.trim()
312-
313-
return output
314-
}
315-
316354
/**
317355
* Converts a Postgres name to PascalCase.
318356
*
@@ -349,24 +387,3 @@ function formatForPyAttributeName(name: string): string {
349387
.join('_'); // Join with underscores
350388
}
351389

352-
function pgTypeToPythonType(pgType: string, nullable: boolean, types: PostgresType[] = []): string {
353-
let pythonType: string | undefined = undefined
354-
355-
if (pgType in PY_TYPE_MAP) {
356-
pythonType = PY_TYPE_MAP[pgType as keyof typeof PY_TYPE_MAP]
357-
}
358-
359-
// Enums
360-
const enumType = types.find((type) => type.name === pgType && type.enums.length > 0)
361-
if (enumType) {
362-
pythonType = formatForPyClassName(String(pgType))
363-
}
364-
365-
if (pythonType) {
366-
// If the type is nullable, append "| None" to the type
367-
return nullable ? `${pythonType} | None` : pythonType
368-
}
369-
370-
// Fallback
371-
return nullable ? String(pgType)+' | None' : String(pgType)
372-
}

0 commit comments

Comments
 (0)