Skip to content

Commit 2fd0f16

Browse files
authored
Proper handling of local variables capturing inside a loop body (#351)
1 parent 8814bc3 commit 2fd0f16

File tree

13 files changed

+276
-137
lines changed

13 files changed

+276
-137
lines changed

hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ class CompileTestRunner
4848

4949
val preludePath = mainTestDir/"mlscript"/"decls"/"Prelude.mls"
5050

51-
given Config = Config.default
51+
// Stack safety relies on the fact that runtime uses while loops for resumption
52+
// and does not create extra stack depth. Hence we disable while loop rewriting here.
53+
given Config = Config.default.copy(rewriteWhileLoops = false)
5254

5355
val compiler = MLsCompiler(
5456
preludePath,

hkmc2/shared/src/main/scala/hkmc2/Config.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ case class Config(
2222
liftDefns: Opt[LiftDefns],
2323
stageCode: Bool,
2424
target: CompilationTarget,
25+
rewriteWhileLoops: Bool,
2526
):
2627

2728
def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
@@ -36,8 +37,9 @@ object Config:
3637
// sanityChecks = S(SanityChecks(light = true)),
3738
effectHandlers = N,
3839
liftDefns = N,
40+
target = CompilationTarget.JS,
41+
rewriteWhileLoops = true,
3942
stageCode = false,
40-
target = CompilationTarget.JS
4143
)
4244

4345
case class SanityChecks(light: Bool)

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ object Thrw extends TailOp:
3636

3737

3838
// * No longer in meaningful use and could be removed if we don't find a use for it:
39-
class Subst(initMap: Map[Local, Value]):
39+
class LoweringCtx(initMap: Map[Local, Value], val mayRet: Bool):
4040
val map = initMap
4141
/*
4242
def +(kv: (Local, Value)): Subst =
@@ -49,12 +49,13 @@ class Subst(initMap: Map[Local, Value]):
4949
def apply(v: Value): Value = v match
5050
case Value.Ref(l) => map.getOrElse(l, v)
5151
case _ => v
52-
object Subst:
53-
val empty = Subst(Map.empty)
54-
def subst(using sub: Subst): Subst = sub
55-
end Subst
52+
object LoweringCtx:
53+
val empty = LoweringCtx(Map.empty, false)
54+
def subst(using sub: LoweringCtx): LoweringCtx = sub
55+
def nestFunc(using sub: LoweringCtx): LoweringCtx = LoweringCtx(sub.map, true)
56+
end LoweringCtx
5657

57-
import Subst.subst
58+
import LoweringCtx.subst
5859

5960

6061
class Lowering()(using Config, TL, Raise, State, Ctx):
@@ -75,7 +76,6 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
7576
def unit: Path =
7677
Select(Value.Ref(State.runtimeSymbol), Tree.Ident("Unit"))(S(State.unitSymbol))
7778

78-
7979
def fail(err: ErrorReport): Block =
8080
raise(err)
8181
End("error")
@@ -84,9 +84,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
8484
// type Rcd = (mut: Bool, args: List[RcdArg]) // * Better, but Scala's patmat exhaustiveness chokes on it
8585
type Rcd = (Bool, List[RcdArg])
8686

87-
def returnedTerm(t: st)(using Subst): Block = term(t)(Ret)
87+
def returnedTerm(t: st)(using LoweringCtx): Block = term(t)(Ret)(using LoweringCtx.nestFunc)
8888

89-
def parentConstructor(cls: Term, args: Ls[Term])(using Subst) =
89+
def parentConstructor(cls: Term, args: Ls[Term])(using LoweringCtx) =
9090
if args.length > 1 then
9191
raise:
9292
ErrorReport(
@@ -101,7 +101,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
101101
)(c => Return(c, implct = true))
102102

103103
// * Used to work around Scala's @tailrec annotation for those few calls that are not in tail position.
104-
final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using Subst): Block =
104+
final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block =
105105
term(t: st, inStmtPos: Bool)(k)
106106

107107

@@ -120,12 +120,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
120120
(imps.reverse, funs.reverse, rest.reverse)
121121

122122

123-
def block(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using Subst): Block =
123+
def block(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using LoweringCtx): Block =
124124
// TODO we should also isolate and reorder classes by inheritance topological sort
125125
val (imps, funs, rest) = splitBlock(stats, Nil, Nil, Nil)
126126
blockImpl(imps ::: funs ::: rest, res)(k)
127127

128-
def blockImpl(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using Subst): Block =
128+
def blockImpl(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using LoweringCtx): Block =
129129
stats match
130130
case (t: sem.Term) :: stats =>
131131
subTerm(t, inStmtPos = true): r =>
@@ -178,7 +178,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
178178
// Assign(td.sym, r,
179179
// term(st.Blk(stats, res))(k)))
180180
Define(ValDefn(td.tsym, td.sym, r),
181-
blockImpl(stats, res)(k)))
181+
blockImpl(stats, res)(k)))(using LoweringCtx.nestFunc)
182182
case syntax.Fun =>
183183
val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme))
184184
Define(FunDefn(td.owner, td.sym, paramLists, bodyBlock),
@@ -302,17 +302,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
302302
blockImpl(stats, res)(k)
303303

304304

305-
def lowerCall(fr: Path, isMlsFun: Bool, arg: Opt[Term], loc: Opt[Loc])(k: Result => Block)(using Subst): Block =
305+
def lowerCall(fr: Path, isMlsFun: Bool, arg: Opt[Term], loc: Opt[Loc])(k: Result => Block)(using LoweringCtx): Block =
306306
arg match
307307
case S(a) =>
308308
lowerCall(fr, isMlsFun, a, loc)(k)
309309
case N =>
310310
// * No arguments means a nullary call, e.g., `f()`
311311
k(Call(fr, Nil)(isMlsFun, true).withLoc(loc))
312-
def lowerCall(fr: Path, isMlsFun: Bool, arg: Term, loc: Opt[Loc])(k: Result => Block)(using Subst): Block =
312+
def lowerCall(fr: Path, isMlsFun: Bool, arg: Term, loc: Opt[Loc])(k: Result => Block)(using LoweringCtx): Block =
313313
lowerArg(arg)(as => k(Call(fr, as)(isMlsFun, true).withLoc(loc)))
314314

315-
def lowerArg(arg: Term)(k: Ls[Arg] => Block)(using Subst): Block =
315+
def lowerArg(arg: Term)(k: Ls[Arg] => Block)(using LoweringCtx): Block =
316316
arg match
317317
case Tup(fs) =>
318318
if fs.exists(e => e match
@@ -329,7 +329,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
329329
k(Arg(spread = S(true), ar) :: Nil)
330330

331331
@tailrec
332-
final def term(t: st, inStmtPos: Bool = false)(k: Result => Block)(using Subst): Block =
332+
final def term(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block =
333333
tl.log(s"Lowering.term ${t.showDbg.truncate(100, "[...]")}${
334334
if inStmtPos then " (in stmt)" else ""}${
335335
t.resolvedSym.fold("")(" – symbol " + _)}")
@@ -680,17 +680,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
680680
// case _ =>
681681
// subTerm(t)(k)
682682

683-
def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using Subst): Block =
683+
def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using LoweringCtx): Block =
684684
k(Instantiate(mut = false, Value.Ref(State.termSymbol).selSN(name), args.map(_.asArg)))
685685

686686
def setupQuotedKeyword(kw: Str): Path =
687687
Value.Ref(State.termSymbol).selSN("Keyword").selSN(kw)
688688

689-
def setupSymbol(symbol: Local)(k: Result => Block)(using Subst): Block =
689+
def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block =
690690
k(Instantiate(mut = false, Value.Ref(State.termSymbol).selSN("Symbol"),
691691
Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil))
692692

693-
def quotePattern(p: FlatPattern)(k: Result => Block)(using Subst): Block = p match
693+
def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match
694694
case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit) :: Nil)(k)
695695
case _ => // TODO
696696
fail:
@@ -700,7 +700,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
700700
source = Diagnostic.Source.Compilation
701701
)
702702

703-
def quoteSplit(split: Split)(k: Result => Block)(using Subst): Block = split match
703+
def quoteSplit(split: Split)(k: Result => Block)(using LoweringCtx): Block = split match
704704
case Split.Cons(Branch(scrutinee, pattern, continuation), tail) => quote(scrutinee): r1 =>
705705
val l1, l2, l3, l4, l5 = new TempSymbol(N)
706706
blockBuilder.assign(l1, r1)
@@ -725,7 +725,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
725725
val state = summon[State]
726726
Value.Ref(state.importSymbol).selSN("meta").selSN("url")
727727

728-
def quote(t: st)(k: Result => Block)(using Subst): Block = t match
728+
def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match
729729
case Lit(lit) =>
730730
setupTerm("Lit", Value.Lit(lit) :: Nil)(k)
731731
case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols
@@ -756,7 +756,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
756756
source = Diagnostic.Source.Compilation
757757
)
758758
case Lam(params, body) =>
759-
def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using Subst): Block = ps match
759+
def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = ps match
760760
case Nil => quote(body): r =>
761761
val l = new TempSymbol(N)
762762
val arr = new TempSymbol(N, "arr")
@@ -818,7 +818,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
818818
source = Diagnostic.Source.Compilation
819819
)
820820

821-
def gatherMembers(clsBody: ObjBody)(using Subst)
821+
def gatherMembers(clsBody: ObjBody)(using LoweringCtx)
822822
: (Ls[FunDefn], Ls[BlockMemberSymbol -> TermSymbol], Ls[TermSymbol], Block) =
823823
val mtds = clsBody.methods
824824
.flatMap: td =>
@@ -838,7 +838,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
838838
case t => t
839839
(mtds, publicFlds, privateFlds, ctor)
840840

841-
def args(elems: Ls[Elem])(k: Ls[Arg] => Block)(using Subst): Block =
841+
def args(elems: Ls[Elem])(k: Ls[Arg] => Block)(using LoweringCtx): Block =
842842
val as = elems.map:
843843
case sem.Fld(sem.FldFlags.benign(), value, N) => R(N -> value)
844844
case sem.Fld(sem.FldFlags.benign(), idx, S(rhs)) => L(idx -> rhs)
@@ -880,10 +880,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
880880
k((Arg(N, Value.Ref(rcdSym)) :: asr).reverse)))
881881

882882

883-
inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using Subst): Block =
883+
inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using LoweringCtx): Block =
884884
subTerms(ts)(asr => k(asr.map(Arg(N, _))))
885885

886-
inline def subTerms(ts: Ls[st])(k: Ls[Path] => Block)(using Subst): Block =
886+
inline def subTerms(ts: Ls[st])(k: Ls[Path] => Block)(using LoweringCtx): Block =
887887
// @tailrec // TODO
888888
def rec(as: Ls[st], asr: Ls[Path]): Block = as match
889889
case Nil => k(asr.reverse)
@@ -892,10 +892,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
892892
rec(as, ar :: asr)
893893
rec(ts, Nil)
894894

895-
def subTerm_nonTail(t: st, inStmtPos: Bool = false)(k: Path => Block)(using Subst): Block =
895+
def subTerm_nonTail(t: st, inStmtPos: Bool = false)(k: Path => Block)(using LoweringCtx): Block =
896896
subTerm(t: st, inStmtPos: Bool)(k)
897897

898-
inline def subTerm(t: st, inStmtPos: Bool = false)(k: Path => Block)(using Subst): Block =
898+
inline def subTerm(t: st, inStmtPos: Bool = false)(k: Path => Block)(using LoweringCtx): Block =
899899
term(t, inStmtPos = inStmtPos):
900900
case v: Value => k(v)
901901
case p: Path => k(p)
@@ -912,7 +912,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
912912

913913
val (imps, funs, rest) = splitBlock(main.stats, Nil, Nil, Nil)
914914

915-
val blk = block(funs ::: rest, R(main.res))(ImplctRet)(using Subst.empty)
915+
val blk = block(funs ::: rest, R(main.res))(ImplctRet)(using LoweringCtx.empty)
916916

917917
val desug = LambdaRewriter.desugar(blk)
918918

@@ -945,20 +945,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
945945
)
946946

947947

948-
def setupSelection(prefix: Term, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using Subst): Block =
948+
def setupSelection(prefix: Term, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using LoweringCtx): Block =
949949
subTerm(prefix): p =>
950950
val selRes = TempSymbol(N, "selRes")
951951
k(Select(p, nme)(sym))
952952

953953
final def setupFunctionOrByNameDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str])
954-
(using Subst): (List[ParamList], Block) =
954+
(using LoweringCtx): (List[ParamList], Block) =
955955
val physicalParams = paramLists match
956956
case Nil => ParamList(ParamListFlags.empty, Nil, N) :: Nil
957957
case ps => ps
958958
setupFunctionDef(physicalParams, bodyTerm, name)
959959

960960
def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str])
961-
(using Subst): (List[ParamList], Block) =
961+
(using LoweringCtx): (List[ParamList], Block) =
962962
(paramLists, returnedTerm(bodyTerm))
963963

964964
def reportAnnotations(target: Statement, annotations: Ls[Annot]): Unit =
@@ -974,7 +974,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State)
974974

975975
private val instrument: Bool = config.sanityChecks.isDefined
976976

977-
override def setupSelection(prefix: st, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using Subst): Block =
977+
override def setupSelection(prefix: st, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using LoweringCtx): Block =
978978
if !instrument then return super.setupSelection(prefix, nme, sym)(k)
979979
subTerm(prefix): p =>
980980
val selRes = TempSymbol(N, "selRes")
@@ -1021,7 +1021,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)
10211021

10221022

10231023
override def setupFunctionDef(paramLists: List[ParamList], bodyTerm: st, name: Option[Str])
1024-
(using Subst): (List[ParamList], Block) =
1024+
(using LoweringCtx): (List[ParamList], Block) =
10251025
if instrument then
10261026
val (ps, bod) = handleMultipleParamLists(paramLists, bodyTerm)
10271027
val instrumentedBody = setupFunctionBody(ps, bod, name)
@@ -1037,7 +1037,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)
10371037
case h :: t => go(t, Term.Lam(h, bod))
10381038
go(paramLists.reverse, bod)
10391039

1040-
def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using Subst): Block =
1040+
def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using LoweringCtx): Block =
10411041
val enterMsgSym = TempSymbol(N, dbgNme = "traceLogEnterMsg")
10421042
val prevIndentLvlSym = TempSymbol(N, dbgNme = "traceLogPrevIndent")
10431043
val resSym = TempSymbol(N, dbgNme = "traceLogRes")
@@ -1073,7 +1073,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)
10731073
TempSymbol(N) -> pureCall(traceLogFn, Arg(N, Value.Ref(retMsgSym)) :: Nil)
10741074
) |>:
10751075
Ret(Value.Ref(resSym))
1076-
)
1076+
)(using LoweringCtx.nestFunc)
10771077

10781078

10791079
object TrivialStatementsAndMatch:

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ object Elaborator:
227227
given State = this
228228
val globalThisSymbol = TopLevelSymbol("globalThis")
229229
val unitSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Obj), Ident("Unit"))
230+
val loopEndSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Obj), Ident("LoopEnd"))
230231
// In JavaScript, `import` can be used for getting current file path, as `import.meta`
231232
val importSymbol = new VarSymbol(Ident("import"))
232233
val runtimeSymbol = TempSymbol(N, "runtime")

0 commit comments

Comments
 (0)