@@ -119,15 +119,13 @@ trait StagedWasmEvaluator extends SAIOps {
119119 }
120120 }
121121
122- type MCont [A ] = Unit => A
122+ class MCont [A ]
123123 type Cont [A ] = (MCont [A ]) => A
124124 type Trail [A ] = List [Context => Rep [Cont [A ]]]
125125
126126 // a cache storing the compiled code for each function, to reduce re-compilation
127127 val compileCache = new HashMap [Int , Rep [(MCont [Unit ]) => Unit ]]
128128
129- def makeDummy : Rep [Unit ] = " dummy" .reflectCtrlWith[Unit ]()
130-
131129 def funHere [A : Manifest ,B : Manifest ](f : Rep [A ] => Rep [B ], dummy : Rep [Unit ]): Rep [A => B ] = {
132130 // to avoid LMS lifting a function, we create a dummy node and read it inside function
133131 fun((x : Rep [A ]) => {
@@ -136,6 +134,20 @@ trait StagedWasmEvaluator extends SAIOps {
136134 })
137135 }
138136
137+ def makeInitMCont [A : Manifest ](f : Rep [Unit => A ]): Rep [MCont [A ]] = {
138+ " make-init-mcont" .reflectCtrlWith[MCont [A ]](f)
139+ }
140+
141+ implicit class MContOps [A : Manifest ](mk : Rep [MCont [A ]]) {
142+ def prependCont (k : Rep [Cont [A ]]): Rep [MCont [A ]] = {
143+ " mcont-prepend" .reflectCtrlWith[MCont [A ]](mk, k)
144+ }
145+
146+ def enter (): Rep [A ] = {
147+ " mcont-enter" .reflectCtrlWith[A ](mk)
148+ }
149+ }
150+
139151 trait Control
140152
141153 // Save the current control information into a structure Control
@@ -308,7 +320,6 @@ trait StagedWasmEvaluator extends SAIOps {
308320 // the type system guarantees that we will never take more than the input size from the stack
309321 val funcTy = ty.funcType
310322 val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size
311- val dummy = makeDummy
312323 def restK (restCtx : Context ): Rep [Cont [Unit ]] = topFun((mk : Rep [MCont [Unit ]]) => {
313324 info(s " Exiting the block, stackSize = " , Stack .size)
314325 val offset = restCtx.stackTypes.size - exitSize
@@ -321,7 +332,6 @@ trait StagedWasmEvaluator extends SAIOps {
321332 case Loop (ty, inner) =>
322333 val funcTy = ty.funcType
323334 val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size
324- val dummy = makeDummy
325335 def restK (restCtx : Context ): Rep [Cont [Unit ]] = topFun((mk : Rep [MCont [Unit ]]) => {
326336 info(s " Exiting the loop, stackSize = " , Stack .size)
327337 val offset = restCtx.stackTypes.size - exitSize
@@ -454,7 +464,7 @@ trait StagedWasmEvaluator extends SAIOps {
454464 }
455465 }
456466
457- def forwardKont : Rep [Cont [Unit ]] = topFun((mk : Rep [MCont [Unit ]]) => mk(() ))
467+ def forwardKont : Rep [Cont [Unit ]] = topFun((mk : Rep [MCont [Unit ]]) => mk.enter( ))
458468
459469
460470 def evalCall (rest : List [Instr ],
@@ -481,7 +491,7 @@ trait StagedWasmEvaluator extends SAIOps {
481491 val offset = ctx.stackTypes.size - ty.out.size
482492 Stack .shiftC(offset, ty.out.size)
483493 Stack .shiftS(offset, ty.out.size)
484- mk(() )
494+ mk.enter( )
485495 })
486496 eval(body, retK _, mk, retK _:: Nil )(Context (Nil , locals))
487497 })
@@ -510,10 +520,7 @@ trait StagedWasmEvaluator extends SAIOps {
510520 Frames .popFrameS(locals.size)
511521 eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size)))
512522 })
513- val dummy = makeDummy
514- val newMKont : Rep [MCont [Unit ]] = funHere((_u : Rep [Unit ]) => {
515- restK(mkont)
516- }, dummy)
523+ val newMKont : Rep [MCont [Unit ]] = mkont.prependCont(restK)
517524 Frames .pushFrameC(locals)
518525 Frames .pushFrameS(locals)
519526 Frames .putAllC(argsC)
@@ -672,7 +679,7 @@ trait StagedWasmEvaluator extends SAIOps {
672679 ExploreTree .fillWithFinished()
673680 " no-op" .reflectCtrlWith[Unit ]()
674681 }
675- val temp : Rep [MCont [Unit ]] = topFun(haltK)
682+ val temp : Rep [MCont [Unit ]] = makeInitMCont( topFun(haltK) )
676683 evalTop(temp, main)
677684 }
678685
@@ -1356,6 +1363,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
13561363 else if (m.toString.endsWith(" I64V" )) " I64V"
13571364 else if (m.toString.endsWith(" SymVal" )) " SymVal"
13581365 else if (m.toString.endsWith(" Snapshot" )) " Snapshot_t"
1366+ else if (m.toString.endsWith(" MCont[Unit]" )) " MCont_t"
13591367 else super .remap(m)
13601368 }
13611369
@@ -1547,6 +1555,12 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
15471555 emit(" ExploreTree.dump_graphviz(" ); shallow(f); emit(" )" )
15481556 case Node (_, " sym-not" , List (s), _) =>
15491557 shallow(s); emit(" .negate()" )
1558+ case Node (_, " make-init-mcont" , List (haltK), _) =>
1559+ emit(" MCont_t(" ); shallow(haltK); emit(" )" )
1560+ case Node (_, " mcont-prepend" , List (mkont, kont), _) =>
1561+ emit(" prependCont(" ); shallow(kont); emit(" , " ); shallow(mkont); emit(" )" )
1562+ case Node (_, " mcont-enter" , List (mkont), _) =>
1563+ shallow(mkont); emit(" .enter()" )
15501564 case Node (_, " dummy" , _, _) => emit(" std::monostate()" )
15511565 case Node (_, " dummy-op" , _, _) => emit(" std::monostate()" )
15521566 case Node (_, " no-op" , _, _) =>
0 commit comments