11import {
2+ ExpressionContext ,
23 PluginError ,
34 PluginGlobalOptions ,
45 PluginOptions ,
56 RUNTIME_PACKAGE ,
7+ TypeScriptExpressionTransformer ,
8+ TypeScriptExpressionTransformerError ,
69 ensureEmptyDir ,
10+ getAttributeArg ,
11+ getAttributeArgLiteral ,
712 getDataModels ,
13+ getLiteralArray ,
814 hasAttribute ,
15+ isDataModelFieldReference ,
916 isDiscriminatorField ,
1017 isEnumFieldReference ,
1118 isForeignKeyField ,
@@ -15,7 +22,7 @@ import {
1522 resolvePath ,
1623 saveSourceFile ,
1724} from '@zenstackhq/sdk' ;
18- import { DataModel , EnumField , Model , TypeDef , isDataModel , isEnum , isTypeDef } from '@zenstackhq/sdk/ast' ;
25+ import { DataModel , EnumField , Model , TypeDef , isArrayExpr , isDataModel , isEnum , isTypeDef } from '@zenstackhq/sdk/ast' ;
1926import { addMissingInputObjectTypes , resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers' ;
2027import { getPrismaClientImportSpec , supportCreateMany , type DMMF } from '@zenstackhq/sdk/prisma' ;
2128import { streamAllContents } from 'langium' ;
@@ -26,7 +33,7 @@ import { name } from '.';
2633import { getDefaultOutputFolder } from '../plugin-utils' ;
2734import Transformer from './transformer' ;
2835import { ObjectMode } from './types' ;
29- import { makeFieldSchema , makeValidationRefinements } from './utils/schema-gen' ;
36+ import { makeFieldSchema } from './utils/schema-gen' ;
3037
3138export class ZodSchemaGenerator {
3239 private readonly sourceFiles : SourceFile [ ] = [ ] ;
@@ -294,7 +301,7 @@ export class ZodSchemaGenerator {
294301 sf . replaceWithText ( ( writer ) => {
295302 this . addPreludeAndImports ( typeDef , writer , output ) ;
296303
297- writer . write ( `export const ${ typeDef . name } Schema = z.object(` ) ;
304+ writer . write ( `const baseSchema = z.object(` ) ;
298305 writer . inlineBlock ( ( ) => {
299306 typeDef . fields . forEach ( ( field ) => {
300307 writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
@@ -313,9 +320,24 @@ export class ZodSchemaGenerator {
313320 writer . writeLine ( ').strict();' ) ;
314321 break ;
315322 }
316- } ) ;
317323
318- // TODO: "@@validate" refinements
324+ // compile "@@validate" to a function calling zod's `.refine()`
325+ const refineFuncName = this . createRefineFunction ( typeDef , writer ) ;
326+
327+ if ( refineFuncName ) {
328+ // export a schema without refinement for extensibility: `[Model]WithoutRefineSchema`
329+ const noRefineSchema = `${ upperCaseFirst ( typeDef . name ) } WithoutRefineSchema` ;
330+ writer . writeLine ( `
331+ /**
332+ * \`${ typeDef . name } \` schema prior to calling \`.refine()\` for extensibility.
333+ */
334+ export const ${ noRefineSchema } = baseSchema;
335+ export const ${ typeDef . name } Schema = ${ refineFuncName } (${ noRefineSchema } );
336+ ` ) ;
337+ } else {
338+ writer . writeLine ( `export const ${ typeDef . name } Schema = baseSchema;` ) ;
339+ }
340+ } ) ;
319341
320342 return schemaName ;
321343 }
@@ -436,22 +458,7 @@ export class ZodSchemaGenerator {
436458 }
437459
438460 // compile "@@validate" to ".refine"
439- const refinements = makeValidationRefinements ( model ) ;
440- let refineFuncName : string | undefined ;
441- if ( refinements . length > 0 ) {
442- refineFuncName = `refine${ upperCaseFirst ( model . name ) } ` ;
443- writer . writeLine (
444- `
445- /**
446- * Schema refinement function for applying \`@@validate\` rules.
447- */
448- export function ${ refineFuncName } <T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${ refinements . join (
449- '\n'
450- ) } ;
451- }
452- `
453- ) ;
454- }
461+ const refineFuncName = this . createRefineFunction ( model , writer ) ;
455462
456463 // delegate discriminator fields are to be excluded from mutation schemas
457464 const delegateDiscriminatorFields = model . fields . filter ( ( field ) => isDiscriminatorField ( field ) ) ;
@@ -658,6 +665,74 @@ export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};
658665 return schemaName ;
659666 }
660667
668+ private createRefineFunction ( decl : DataModel | TypeDef , writer : CodeBlockWriter ) {
669+ const refinements = this . makeValidationRefinements ( decl ) ;
670+ let refineFuncName : string | undefined ;
671+ if ( refinements . length > 0 ) {
672+ refineFuncName = `refine${ upperCaseFirst ( decl . name ) } ` ;
673+ writer . writeLine (
674+ `
675+ /**
676+ * Schema refinement function for applying \`@@validate\` rules.
677+ */
678+ export function ${ refineFuncName } <T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${ refinements . join (
679+ '\n'
680+ ) } ;
681+ }
682+ `
683+ ) ;
684+ return refineFuncName ;
685+ } else {
686+ return undefined ;
687+ }
688+ }
689+
690+ private makeValidationRefinements ( decl : DataModel | TypeDef ) {
691+ const attrs = decl . attributes . filter ( ( attr ) => attr . decl . ref ?. name === '@@validate' ) ;
692+ const refinements = attrs
693+ . map ( ( attr ) => {
694+ const valueArg = getAttributeArg ( attr , 'value' ) ;
695+ if ( ! valueArg ) {
696+ return undefined ;
697+ }
698+
699+ const messageArg = getAttributeArgLiteral < string > ( attr , 'message' ) ;
700+ const message = messageArg ? `message: ${ JSON . stringify ( messageArg ) } ,` : '' ;
701+
702+ const pathArg = getAttributeArg ( attr , 'path' ) ;
703+ const path =
704+ pathArg && isArrayExpr ( pathArg )
705+ ? `path: ['${ getLiteralArray < string > ( pathArg ) ?. join ( `', '` ) } '],`
706+ : '' ;
707+
708+ const options = `, { ${ message } ${ path } }` ;
709+
710+ try {
711+ let expr = new TypeScriptExpressionTransformer ( {
712+ context : ExpressionContext . ValidationRule ,
713+ fieldReferenceContext : 'value' ,
714+ } ) . transform ( valueArg ) ;
715+
716+ if ( isDataModelFieldReference ( valueArg ) ) {
717+ // if the expression is a simple field reference, treat undefined
718+ // as true since the all fields are optional in validation context
719+ expr = `${ expr } ?? true` ;
720+ }
721+
722+ return `.refine((value: any) => ${ expr } ${ options } )` ;
723+ } catch ( err ) {
724+ if ( err instanceof TypeScriptExpressionTransformerError ) {
725+ throw new PluginError ( name , err . message ) ;
726+ } else {
727+ throw err ;
728+ }
729+ }
730+ } )
731+ . filter ( ( r ) => ! ! r ) ;
732+
733+ return refinements ;
734+ }
735+
661736 private makePartial ( schema : string , fields ?: string [ ] ) {
662737 if ( fields ) {
663738 if ( fields . length === 0 ) {
0 commit comments