Skip to content

Commit c94e582

Browse files
committed
fix: import handling for struct embedding with promoted methods
When embedding structs like bytes.Buffer, promoted methods reference types from packages (e.g., io.Reader, io.Writer) that may not be imported in the source Go file. This caused TypeScript compilation errors. The fix adds import detection during the analysis phase by scanning embedded struct types and recursively collecting all packages referenced by promoted method signatures. These imports are then written during code generation. Signed-off-by: Christian Stewart <christian@aperture.us>
1 parent a7768f2 commit c94e582

File tree

64 files changed

+1425
-122
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+1425
-122
lines changed

compiler/analysis.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,9 +1207,145 @@ func AnalyzePackageFiles(pkg *packages.Package, allPackages map[string]*packages
12071207
// Interface implementation async status is now updated on-demand in IsInterfaceMethodAsync
12081208
visitor.analyzeAllMethodsAsync()
12091209

1210+
// Fourth pass: collect imports needed by promoted methods from embedded structs
1211+
analysis.addImportsForPromotedMethods(pkg)
1212+
12101213
return analysis
12111214
}
12121215

1216+
// addImportsForPromotedMethods scans all struct types in the package for embedded fields
1217+
// and adds imports for any packages referenced by the promoted methods' parameter/return types.
1218+
func (a *Analysis) addImportsForPromotedMethods(pkg *packages.Package) {
1219+
// Collect all package names we need to add
1220+
packagesToAdd := make(map[string]*types.Package)
1221+
1222+
// Iterate through all type definitions in the package
1223+
scope := pkg.Types.Scope()
1224+
for _, name := range scope.Names() {
1225+
obj := scope.Lookup(name)
1226+
if obj == nil {
1227+
continue
1228+
}
1229+
1230+
// Check if it's a type definition
1231+
typeName, ok := obj.(*types.TypeName)
1232+
if !ok {
1233+
continue
1234+
}
1235+
1236+
// Get the underlying type
1237+
namedType, ok := typeName.Type().(*types.Named)
1238+
if !ok {
1239+
continue
1240+
}
1241+
1242+
// Check if it's a struct
1243+
structType, ok := namedType.Underlying().(*types.Struct)
1244+
if !ok {
1245+
continue
1246+
}
1247+
1248+
// Look for embedded fields
1249+
for i := 0; i < structType.NumFields(); i++ {
1250+
field := structType.Field(i)
1251+
if !field.Embedded() {
1252+
continue
1253+
}
1254+
1255+
// Get the type of the embedded field
1256+
embeddedType := field.Type()
1257+
1258+
// Handle pointer to embedded type
1259+
if ptr, ok := embeddedType.(*types.Pointer); ok {
1260+
embeddedType = ptr.Elem()
1261+
}
1262+
1263+
// Get named type to access methods
1264+
embeddedNamed, ok := embeddedType.(*types.Named)
1265+
if !ok {
1266+
continue
1267+
}
1268+
1269+
// Scan all methods of the embedded type
1270+
for j := 0; j < embeddedNamed.NumMethods(); j++ {
1271+
method := embeddedNamed.Method(j)
1272+
sig, ok := method.Type().(*types.Signature)
1273+
if !ok {
1274+
continue
1275+
}
1276+
1277+
// Scan parameters
1278+
if sig.Params() != nil {
1279+
for k := 0; k < sig.Params().Len(); k++ {
1280+
param := sig.Params().At(k)
1281+
a.collectPackageFromType(param.Type(), pkg.Types, packagesToAdd)
1282+
}
1283+
}
1284+
1285+
// Scan results
1286+
if sig.Results() != nil {
1287+
for k := 0; k < sig.Results().Len(); k++ {
1288+
result := sig.Results().At(k)
1289+
a.collectPackageFromType(result.Type(), pkg.Types, packagesToAdd)
1290+
}
1291+
}
1292+
}
1293+
}
1294+
}
1295+
1296+
// Add collected packages to imports
1297+
for pkgName, pkgObj := range packagesToAdd {
1298+
if _, exists := a.Imports[pkgName]; !exists {
1299+
tsImportPath := "@goscript/" + pkgObj.Path()
1300+
a.Imports[pkgName] = &fileImport{
1301+
importPath: tsImportPath,
1302+
importVars: make(map[string]struct{}),
1303+
}
1304+
}
1305+
}
1306+
}
1307+
1308+
// collectPackageFromType recursively collects packages referenced by a type.
1309+
func (a *Analysis) collectPackageFromType(t types.Type, currentPkg *types.Package, packagesToAdd map[string]*types.Package) {
1310+
switch typ := t.(type) {
1311+
case *types.Named:
1312+
pkg := typ.Obj().Pkg()
1313+
if pkg != nil && pkg != currentPkg {
1314+
packagesToAdd[pkg.Name()] = pkg
1315+
}
1316+
// Check type arguments for generics
1317+
if typ.TypeArgs() != nil {
1318+
for i := 0; i < typ.TypeArgs().Len(); i++ {
1319+
a.collectPackageFromType(typ.TypeArgs().At(i), currentPkg, packagesToAdd)
1320+
}
1321+
}
1322+
case *types.Pointer:
1323+
a.collectPackageFromType(typ.Elem(), currentPkg, packagesToAdd)
1324+
case *types.Slice:
1325+
a.collectPackageFromType(typ.Elem(), currentPkg, packagesToAdd)
1326+
case *types.Array:
1327+
a.collectPackageFromType(typ.Elem(), currentPkg, packagesToAdd)
1328+
case *types.Map:
1329+
a.collectPackageFromType(typ.Key(), currentPkg, packagesToAdd)
1330+
a.collectPackageFromType(typ.Elem(), currentPkg, packagesToAdd)
1331+
case *types.Chan:
1332+
a.collectPackageFromType(typ.Elem(), currentPkg, packagesToAdd)
1333+
case *types.Signature:
1334+
// Collect from parameters
1335+
if typ.Params() != nil {
1336+
for i := 0; i < typ.Params().Len(); i++ {
1337+
a.collectPackageFromType(typ.Params().At(i).Type(), currentPkg, packagesToAdd)
1338+
}
1339+
}
1340+
// Collect from results
1341+
if typ.Results() != nil {
1342+
for i := 0; i < typ.Results().Len(); i++ {
1343+
a.collectPackageFromType(typ.Results().At(i).Type(), currentPkg, packagesToAdd)
1344+
}
1345+
}
1346+
}
1347+
}
1348+
12131349
// AnalyzePackageImports performs package-level analysis to collect function definitions
12141350
// and calls across all files in the package for auto-import generation
12151351
func AnalyzePackageImports(pkg *packages.Package) *PackageAnalysis {

compiler/compiler.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,36 @@ func (c *FileCompiler) Compile(ctx context.Context) error {
676676
}
677677
}
678678

679+
// Write any imports that were added during analysis (e.g., for promoted methods)
680+
// but don't appear in the source AST
681+
writtenImports := make(map[string]bool)
682+
// Track imports that will be written from the AST
683+
for _, imp := range f.Imports {
684+
if imp.Path != nil {
685+
path := imp.Path.Value[1 : len(imp.Path.Value)-1] // Remove quotes
686+
if imp.Name != nil && imp.Name.Name != "" {
687+
writtenImports[imp.Name.Name] = true
688+
} else {
689+
// Use the actual package name
690+
if actualName, err := getActualPackageName(path, c.pkg.Imports); err == nil {
691+
writtenImports[actualName] = true
692+
}
693+
}
694+
}
695+
}
696+
// Write imports from analysis that aren't in the AST
697+
var additionalImports []string
698+
for pkgName := range c.Analysis.Imports {
699+
if !writtenImports[pkgName] && pkgName != "$" {
700+
additionalImports = append(additionalImports, pkgName)
701+
}
702+
}
703+
sort.Strings(additionalImports)
704+
for _, pkgName := range additionalImports {
705+
fileImp := c.Analysis.Imports[pkgName]
706+
c.codeWriter.WriteImport(pkgName, fileImp.importPath+"/index.js")
707+
}
708+
679709
c.codeWriter.WriteLine("") // Add a newline after imports
680710

681711
if err := goWriter.WriteDecls(f.Decls); err != nil {

compiler/spec-struct.go

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -345,15 +345,21 @@ func (c *GoToTSCompiler) WriteStructTypeSpec(a *ast.TypeSpec, t *ast.StructType)
345345
}
346346
}
347347

348-
// Promoted methods
349-
embeddedMethodSet := types.NewMethodSet(embeddedFieldType) // Use original field type for method set
350-
for k := range embeddedMethodSet.Len() {
351-
methodSelection := embeddedMethodSet.At(k)
352-
method := methodSelection.Obj().(*types.Func)
353-
methodName := method.Name()
354-
355-
// Skip if it's not a promoted method (indirect) or if it's shadowed by a direct method or an already processed promoted method
356-
if len(methodSelection.Index()) == 1 && !directMethods[methodName] && !seenPromotedFields[methodName] {
348+
// Promoted methods
349+
// Use pointer to embedded type to get both value and pointer receiver methods
350+
// This matches Go's behavior where embedding T promotes both T and *T methods
351+
methodSetType := embeddedFieldType
352+
if _, isPtr := embeddedFieldType.(*types.Pointer); !isPtr {
353+
methodSetType = types.NewPointer(embeddedFieldType)
354+
}
355+
embeddedMethodSet := types.NewMethodSet(methodSetType)
356+
for k := range embeddedMethodSet.Len() {
357+
methodSelection := embeddedMethodSet.At(k)
358+
method := methodSelection.Obj().(*types.Func)
359+
methodName := method.Name()
360+
361+
// Skip if it's not a promoted method (indirect) or if it's shadowed by a direct method or an already processed promoted method
362+
if len(methodSelection.Index()) == 1 && !directMethods[methodName] && !seenPromotedFields[methodName] {
357363
// Check for conflict with outer struct's own fields
358364
conflictWithField := false
359365
for k_idx := 0; k_idx < underlyingStruct.NumFields(); k_idx++ {
@@ -543,15 +549,16 @@ func (c *GoToTSCompiler) generateFlattenedInitTypeString(structType *types.Named
543549
fieldType = ptr.Elem()
544550
}
545551

546-
if named, ok := fieldType.(*types.Named); ok {
547-
embeddedName := named.Obj().Name()
552+
if _, ok := fieldType.(*types.Named); ok {
548553
// Check if the embedded type is an interface
549554
if _, isInterface := fieldType.Underlying().(*types.Interface); isInterface {
550555
// For embedded interfaces, use the full qualified interface type
551556
embeddedTypeMap[c.getEmbeddedFieldKeyName(field.Type())] = c.getTypeString(field.Type())
552557
} else {
553558
// For embedded structs, use ConstructorParameters for field-based initialization
554-
embeddedTypeMap[c.getEmbeddedFieldKeyName(field.Type())] = fmt.Sprintf("Partial<ConstructorParameters<typeof %s>[0]>", embeddedName)
559+
// Use getTypeString to get the qualified type name (e.g., bytes.Buffer not just Buffer)
560+
qualifiedTypeName := c.getTypeString(fieldType)
561+
embeddedTypeMap[c.getEmbeddedFieldKeyName(field.Type())] = fmt.Sprintf("Partial<ConstructorParameters<typeof %s>[0]>", qualifiedTypeName)
555562
}
556563
}
557564
continue

compiler/stmt.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,15 @@ func (c *GoToTSCompiler) WriteStmtExpr(exp *ast.ExprStmt) error {
368368
return nil
369369
}
370370

371+
// Defensive semicolon: if the expression will start with '(' in TypeScript,
372+
// prepend a semicolon to prevent JavaScript from treating the previous line
373+
// as a function call. This happens when:
374+
// 1. CallExpr where Fun itself will be parenthesized (e.g., (await fn())())
375+
// 2. Array/slice literals starting with '['
376+
if c.needsDefensiveSemicolon(exp.X) {
377+
c.tsw.WriteLiterally(";")
378+
}
379+
371380
// Handle other expression statements
372381
if err := c.WriteValueExpr(exp.X); err != nil { // Expression statement evaluates a value
373382
return err
@@ -1129,3 +1138,38 @@ func (c *GoToTSCompiler) substituteExprForShadowing(expr ast.Expr, shadowingInfo
11291138
func (c *GoToTSCompiler) isBuiltinFunction(name string) bool {
11301139
return builtinFunctions[name]
11311140
}
1141+
1142+
// needsDefensiveSemicolon determines if an expression will generate TypeScript
1143+
// code starting with '(' or '[', which would require a defensive semicolon to
1144+
// prevent JavaScript from treating the previous line as a function call.
1145+
func (c *GoToTSCompiler) needsDefensiveSemicolon(expr ast.Expr) bool {
1146+
switch e := expr.(type) {
1147+
case *ast.CallExpr:
1148+
// Check if the function being called will be parenthesized
1149+
// This happens when Fun is itself a CallExpr, TypeAssertExpr, or other complex expression
1150+
switch e.Fun.(type) {
1151+
case *ast.CallExpr:
1152+
// (fn())() - needs defensive semicolon
1153+
return true
1154+
case *ast.TypeAssertExpr:
1155+
// (x.(T))() - needs defensive semicolon
1156+
return true
1157+
case *ast.IndexExpr:
1158+
// Could generate (arr[i])() if indexed result is called
1159+
// But typically doesn't need defensive semicolon as arr[i]() is fine
1160+
return false
1161+
case *ast.ParenExpr:
1162+
// Already parenthesized - needs defensive semicolon
1163+
return true
1164+
}
1165+
case *ast.CompositeLit:
1166+
// Array/slice literals start with '['
1167+
if _, isArray := e.Type.(*ast.ArrayType); isArray {
1168+
return true
1169+
}
1170+
case *ast.ParenExpr:
1171+
// Parenthesized expressions start with '('
1172+
return true
1173+
}
1174+
return false
1175+
}

compiler/type.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,8 @@ func (c *GoToTSCompiler) WriteNamedType(t *types.Named) {
249249
typePkg := t.Obj().Pkg()
250250
if typePkg != nil && typePkg != c.pkg.Types {
251251
// This type is from an imported package, find the import alias
252-
if alias, found := c.resolveImportAlias(typePkg); found {
252+
alias, found := c.resolveImportAlias(typePkg)
253+
if found && alias != "" {
253254
// Write the qualified name: importAlias.TypeName
254255
c.tsw.WriteLiterally(alias)
255256
c.tsw.WriteLiterally(".")

compliance/WIP.md

Lines changed: 0 additions & 51 deletions
This file was deleted.

compliance/deps/encoding/json/decode.gs.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@ import { foldName } from "./fold.gs.js";
44
import { checkValid, stateEndValue } from "./scanner.gs.js";
55
import { structFields } from "./encode.gs.js";
66
import { scanner } from "./scanner.gs.js";
7+
import * as bytes from "bytes/index.js"
8+
import * as cmp from "cmp/index.js"
9+
import * as errors from "errors/index.js"
10+
import * as io from "io/index.js"
11+
import * as math from "math/index.js"
12+
import * as slices from "slices/index.js"
13+
import * as sync from "sync/index.js"
714

815
import * as encoding from "@goscript/encoding/index.js"
916

0 commit comments

Comments
 (0)