55 "sort"
66 "strings"
77
8- "github.com/tabbed/ sqlc- go/codegen "
8+ "buf.build/gen/go/ sqlc/sqlc/protocolbuffers/ go/protos/plugin "
99)
1010
1111type importSpec struct {
@@ -14,15 +14,8 @@ type importSpec struct {
1414 Alias string
1515}
1616
17- func pyTypeIsSet (t * codegen.PythonType ) bool {
18- return t .Module != "" || t .Name != ""
19- }
20-
21- func pyTypeString (t * codegen.PythonType ) string {
22- if t .Name != "" && t .Module == "" {
23- return t .Name
24- }
25- return t .Module + "." + t .Name
17+ func pyTypeIsSet (o * plugin.Override ) bool {
18+ return o .CodeType != ""
2619}
2720
2821func (i importSpec ) String () string {
@@ -39,7 +32,7 @@ func (i importSpec) String() string {
3932}
4033
4134type importer struct {
42- Settings * codegen .Settings
35+ Settings * plugin .Settings
4336 Models []Struct
4437 Queries []Query
4538 Enums []Enum
@@ -112,12 +105,17 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
112105 pkg := make (map [string ]importSpec )
113106
114107 for _ , o := range i .Settings .Overrides {
115- if pyTypeIsSet (o .PythonType ) && o .PythonType .Module != "" {
116- if modelUses (pyTypeString (o .PythonType )) {
117- pkg [o .PythonType .Module ] = importSpec {Module : o .PythonType .Module }
108+ if pyTypeIsSet (o ) {
109+ mod , _ , found := strings .Cut (o .CodeType , "." )
110+ if ! found {
111+ continue
112+ }
113+ if modelUses (o .CodeType ) {
114+ pkg [mod ] = importSpec {Module : mod }
118115 }
119116 }
120117 }
118+
121119 return std , pkg
122120}
123121
@@ -158,9 +156,13 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map
158156 }
159157
160158 for _ , o := range i .Settings .Overrides {
161- if pyTypeIsSet (o .PythonType ) && o .PythonType .Module != "" {
162- if queryUses (pyTypeString (o .PythonType )) {
163- pkg [o .PythonType .Module ] = importSpec {Module : o .PythonType .Module }
159+ if pyTypeIsSet (o ) {
160+ mod , _ , found := strings .Cut (o .CodeType , "." )
161+ if ! found {
162+ continue
163+ }
164+ if queryUses (o .CodeType ) {
165+ pkg [mod ] = importSpec {Module : mod }
164166 }
165167 }
166168 }
0 commit comments