Skip to content

Commit 94fe901

Browse files
avoid some usage of fun
1 parent 832af8b commit 94fe901

File tree

2 files changed

+75
-13
lines changed

2 files changed

+75
-13
lines changed

headers/wasm/controls.hpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,58 @@
44

55
#include <functional>
66

7+
#include <iostream>
8+
#include <memory>
79
#include <variant>
810

9-
using MCont_t = std::function<std::monostate(std::monostate)>;
11+
class MContRepr;
12+
struct MCont_t {
13+
std::shared_ptr<MContRepr> ptr;
14+
MCont_t() : ptr(nullptr) {}
15+
MCont_t(const MCont_t &p) : ptr(p.ptr) {}
16+
MCont_t(std::shared_ptr<MContRepr> p) : ptr(p) {}
17+
MCont_t(std::function<std::monostate(std::monostate)> haltK)
18+
: ptr(std::make_shared<MContRepr>(haltK)) {}
19+
bool is_null() const { return ptr == nullptr; }
20+
21+
std::monostate enter();
22+
};
1023
using Cont_t = std::function<std::monostate(MCont_t)>;
24+
class MContRepr {
25+
public:
26+
MContRepr(Cont_t cont, MCont_t mcont) : cont(cont), mcont(mcont) {}
27+
28+
MContRepr(std::function<std::monostate(std::monostate)> haltK)
29+
: cont([=](MCont_t) {
30+
// std::cout << "Halting the program..." << std::endl;
31+
32+
return haltK(std::monostate{});
33+
}),
34+
mcont() {}
35+
36+
MContRepr() : cont(nullptr), mcont() {}
37+
38+
std::monostate enter() {
39+
// std::cout << "Entering MCont\n";
40+
// std::cout << "Cont cont: " << (cont ? "valid" : "null") << "\n";
41+
// std::cout << "MCont mcont: " << (mcont ? "valid" : "null") << "\n";
42+
if (mcont.is_null()) {
43+
return cont(std::make_shared<MContRepr>(
44+
MContRepr())); // when mcont is null, we pass a dummy MContRepr
45+
}
46+
return cont(mcont);
47+
}
48+
49+
private:
50+
Cont_t cont;
51+
MCont_t mcont;
52+
};
53+
54+
inline MCont_t prependCont(Cont_t k, MCont_t mcont) {
55+
return std::make_shared<MContRepr>(k, mcont);
56+
}
57+
58+
inline std::monostate MCont_t::enter() { return ptr->enter(); }
1159

1260
struct Control {
1361
Cont_t cont;

src/main/scala/wasm/StagedConcolicMiniWasm.scala

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,13 @@ trait StagedWasmEvaluator extends SAIOps {
119119
}
120120
}
121121

122-
type MCont[A] = Unit => A
122+
class MCont[A]
123123
type Cont[A] = (MCont[A]) => A
124124
type Trail[A] = List[Context => Rep[Cont[A]]]
125125

126126
// a cache storing the compiled code for each function, to reduce re-compilation
127127
val compileCache = new HashMap[Int, Rep[(MCont[Unit]) => Unit]]
128128

129-
def makeDummy: Rep[Unit] = "dummy".reflectCtrlWith[Unit]()
130-
131129
def funHere[A:Manifest,B:Manifest](f: Rep[A] => Rep[B], dummy: Rep[Unit]): Rep[A => B] = {
132130
// to avoid LMS lifting a function, we create a dummy node and read it inside function
133131
fun((x: Rep[A]) => {
@@ -136,6 +134,20 @@ trait StagedWasmEvaluator extends SAIOps {
136134
})
137135
}
138136

137+
def makeInitMCont[A:Manifest](f: Rep[Unit => A]): Rep[MCont[A]] = {
138+
"make-init-mcont".reflectCtrlWith[MCont[A]](f)
139+
}
140+
141+
implicit class MContOps[A:Manifest](mk: Rep[MCont[A]]) {
142+
def prependCont(k: Rep[Cont[A]]): Rep[MCont[A]] = {
143+
"mcont-prepend".reflectCtrlWith[MCont[A]](mk, k)
144+
}
145+
146+
def enter(): Rep[A] = {
147+
"mcont-enter".reflectCtrlWith[A](mk)
148+
}
149+
}
150+
139151
trait Control
140152

141153
// Save the current control information into a structure Control
@@ -308,7 +320,6 @@ trait StagedWasmEvaluator extends SAIOps {
308320
// the type system guarantees that we will never take more than the input size from the stack
309321
val funcTy = ty.funcType
310322
val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size
311-
val dummy = makeDummy
312323
def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => {
313324
info(s"Exiting the block, stackSize =", Stack.size)
314325
val offset = restCtx.stackTypes.size - exitSize
@@ -321,7 +332,6 @@ trait StagedWasmEvaluator extends SAIOps {
321332
case Loop(ty, inner) =>
322333
val funcTy = ty.funcType
323334
val exitSize = ctx.stackTypes.size - funcTy.inps.size + funcTy.out.size
324-
val dummy = makeDummy
325335
def restK(restCtx: Context): Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => {
326336
info(s"Exiting the loop, stackSize =", Stack.size)
327337
val offset = restCtx.stackTypes.size - exitSize
@@ -454,7 +464,7 @@ trait StagedWasmEvaluator extends SAIOps {
454464
}
455465
}
456466

457-
def forwardKont: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => mk(()))
467+
def forwardKont: Rep[Cont[Unit]] = topFun((mk: Rep[MCont[Unit]]) => mk.enter())
458468

459469

460470
def evalCall(rest: List[Instr],
@@ -481,7 +491,7 @@ trait StagedWasmEvaluator extends SAIOps {
481491
val offset = ctx.stackTypes.size - ty.out.size
482492
Stack.shiftC(offset, ty.out.size)
483493
Stack.shiftS(offset, ty.out.size)
484-
mk(())
494+
mk.enter()
485495
})
486496
eval(body, retK _, mk, retK _::Nil)(Context(Nil, locals))
487497
})
@@ -510,10 +520,7 @@ trait StagedWasmEvaluator extends SAIOps {
510520
Frames.popFrameS(locals.size)
511521
eval(rest, kont, mk, trail)(newCtx.copy(stackTypes = ty.out.reverse ++ ctx.stackTypes.drop(ty.inps.size)))
512522
})
513-
val dummy = makeDummy
514-
val newMKont: Rep[MCont[Unit]] = funHere((_u: Rep[Unit]) => {
515-
restK(mkont)
516-
}, dummy)
523+
val newMKont: Rep[MCont[Unit]] = mkont.prependCont(restK)
517524
Frames.pushFrameC(locals)
518525
Frames.pushFrameS(locals)
519526
Frames.putAllC(argsC)
@@ -672,7 +679,7 @@ trait StagedWasmEvaluator extends SAIOps {
672679
ExploreTree.fillWithFinished()
673680
"no-op".reflectCtrlWith[Unit]()
674681
}
675-
val temp: Rep[MCont[Unit]] = topFun(haltK)
682+
val temp: Rep[MCont[Unit]] = makeInitMCont(topFun(haltK))
676683
evalTop(temp, main)
677684
}
678685

@@ -1356,6 +1363,7 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
13561363
else if (m.toString.endsWith("I64V")) "I64V"
13571364
else if (m.toString.endsWith("SymVal")) "SymVal"
13581365
else if (m.toString.endsWith("Snapshot")) "Snapshot_t"
1366+
else if (m.toString.endsWith("MCont[Unit]")) "MCont_t"
13591367
else super.remap(m)
13601368
}
13611369

@@ -1547,6 +1555,12 @@ trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase {
15471555
emit("ExploreTree.dump_graphviz("); shallow(f); emit(")")
15481556
case Node(_, "sym-not", List(s), _) =>
15491557
shallow(s); emit(".negate()")
1558+
case Node(_, "make-init-mcont", List(haltK), _) =>
1559+
emit("MCont_t("); shallow(haltK); emit(")")
1560+
case Node(_, "mcont-prepend", List(mkont, kont), _) =>
1561+
emit("prependCont("); shallow(kont); emit(", "); shallow(mkont); emit(")")
1562+
case Node(_, "mcont-enter", List(mkont), _) =>
1563+
shallow(mkont); emit(".enter()")
15501564
case Node(_, "dummy", _, _) => emit("std::monostate()")
15511565
case Node(_, "dummy-op", _, _) => emit("std::monostate()")
15521566
case Node(_, "no-op", _, _) =>

0 commit comments

Comments
 (0)