@@ -53,6 +53,8 @@ type Field struct {
5353 Name string
5454 Type pyType
5555 Comment string
56+ // EmbedFields contains the embedded fields that require scanning.
57+ EmbedFields []Field
5658}
5759
5860type Struct struct {
@@ -105,14 +107,42 @@ func (v QueryValue) RowNode(rowVar string) *pyast.Node {
105107 call := & pyast.Call {
106108 Func : v .Annotation (),
107109 }
108- for i , f := range v .Struct .Fields {
109- call .Keywords = append (call .Keywords , & pyast.Keyword {
110- Arg : f .Name ,
111- Value : subscriptNode (
110+ rowIndex := 0 // We need to keep track of the index in the row variable.
111+ for _ , f := range v .Struct .Fields {
112+
113+ var valueNode * pyast.Node
114+ // Check if we are using sqlc.embed, if so we need to create a new object.
115+ if len (f .EmbedFields ) > 0 {
116+ // We keep this separate so we can easily add all arguments.
117+ embed_call := & pyast.Call {Func : f .Type .Annotation ()}
118+
119+ // Now add all field Initializers for the embedded model that index into the original row.
120+ for i , embedField := range f .EmbedFields {
121+ embed_call .Keywords = append (embed_call .Keywords , & pyast.Keyword {
122+ Arg : embedField .Name ,
123+ Value : subscriptNode (
124+ rowVar ,
125+ constantInt (rowIndex + i ),
126+ ),
127+ })
128+ }
129+
130+ valueNode = & pyast.Node {
131+ Node : & pyast.Node_Call {
132+ Call : embed_call ,
133+ },
134+ }
135+
136+ rowIndex += len (f .EmbedFields )
137+ } else {
138+ valueNode = subscriptNode (
112139 rowVar ,
113- constantInt (i ),
114- ),
115- })
140+ constantInt (rowIndex ),
141+ )
142+ rowIndex ++
143+ }
144+
145+ call .Keywords = append (call .Keywords , & pyast.Keyword {Arg : f .Name , Value : valueNode })
116146 }
117147 return & pyast.Node {
118148 Node : & pyast.Node_Call {
@@ -336,6 +366,47 @@ func paramName(p *plugin.Parameter) string {
336366type pyColumn struct {
337367 id int32
338368 * plugin.Column
369+ embed * pyEmbed
370+ }
371+
372+ type pyEmbed struct {
373+ modelType string
374+ modelName string
375+ fields []Field
376+ }
377+
378+ // Taken from https://github.com/sqlc-dev/sqlc/blob/8c59fbb9938a0bad3d9971fc2c10ea1f83cc1d0b/internal/codegen/golang/result.go#L123-L126
379+ // look through all the structs and attempt to find a matching one to embed
380+ // We need the name of the struct and its field names.
381+ func newGoEmbed (embed * plugin.Identifier , structs []Struct , defaultSchema string ) * pyEmbed {
382+ if embed == nil {
383+ return nil
384+ }
385+
386+ for _ , s := range structs {
387+ embedSchema := defaultSchema
388+ if embed .Schema != "" {
389+ embedSchema = embed .Schema
390+ }
391+
392+ // compare the other attributes
393+ if embed .Catalog != s .Table .Catalog || embed .Name != s .Table .Name || embedSchema != s .Table .Schema {
394+ continue
395+ }
396+
397+ fields := make ([]Field , len (s .Fields ))
398+ for i , f := range s .Fields {
399+ fields [i ] = f
400+ }
401+
402+ return & pyEmbed {
403+ modelType : s .Name ,
404+ modelName : s .Name ,
405+ fields : fields ,
406+ }
407+ }
408+
409+ return nil
339410}
340411
341412func columnsToStruct (req * plugin.CodeGenRequest , name string , columns []pyColumn ) * Struct {
@@ -359,10 +430,22 @@ func columnsToStruct(req *plugin.CodeGenRequest, name string, columns []pyColumn
359430 if suffix > 0 {
360431 fieldName = fmt .Sprintf ("%s_%d" , fieldName , suffix )
361432 }
362- gs .Fields = append (gs .Fields , Field {
433+
434+ f := Field {
363435 Name : fieldName ,
364436 Type : makePyType (req , c .Column ),
365- })
437+ }
438+
439+ if c .embed != nil {
440+ f .Type = pyType {
441+ InnerType : "models." + modelName (c .embed .modelType , req .Settings ),
442+ IsArray : false ,
443+ IsNull : false ,
444+ }
445+ f .EmbedFields = c .embed .fields
446+ }
447+
448+ gs .Fields = append (gs .Fields , f )
366449 seen [colName ]++
367450 }
368451 return & gs
@@ -476,6 +559,7 @@ func buildQueries(conf Config, req *plugin.CodeGenRequest, structs []Struct) ([]
476559 columns = append (columns , pyColumn {
477560 id : int32 (i ),
478561 Column : c ,
562+ embed : newGoEmbed (c .EmbedTable , structs , req .Catalog .DefaultSchema ),
479563 })
480564 }
481565 gs = columnsToStruct (req , query .Name + "Row" , columns )
0 commit comments