Skip to content

Commit c286f04

Browse files
author
Guannan Wei
committed
refactor module out of frame
1 parent c9d4b40 commit c286f04

File tree

3 files changed

+43
-46
lines changed

3 files changed

+43
-46
lines changed

src/main/scala/wasm/MiniWasm.scala

Lines changed: 36 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ import Console.{GREEN, RED, RESET, YELLOW_B, UNDERLINED}
1111
case class Trap() extends Exception
1212

1313
case class ModuleInstance(
14-
types: List[FuncType],
15-
funcs: HashMap[Int, WIR],
16-
memory: List[RTMemory] = List(RTMemory()),
17-
globals: List[RTGlobal] = List(),
18-
exports: List[Export] = List()
14+
defs: List[Definition],
15+
types: List[FuncType],
16+
funcs: HashMap[Int, WIR],
17+
memory: List[RTMemory] = List(RTMemory()),
18+
globals: List[RTGlobal] = List(),
19+
exports: List[Export] = List()
1920
)
2021

2122
object ModuleInstance {
@@ -52,7 +53,7 @@ object ModuleInstance {
5253
})
5354
.toList
5455

55-
ModuleInstance(types, module.funcEnv, memory, globals, exports)
56+
ModuleInstance(module.definitions, types, module.funcEnv, memory, globals, exports)
5657
}
5758
}
5859

@@ -199,8 +200,8 @@ object Primtives {
199200
}
200201
}
201202

202-
def memOutOfBound(frame: Frame, memoryIndex: Int, offset: Int, size: Int) = {
203-
val memory = frame.module.memory(memoryIndex)
203+
def memOutOfBound(module: ModuleInstance, memoryIndex: Int, offset: Int, size: Int) = {
204+
val memory = module.memory(memoryIndex)
204205
offset + size > memory.size
205206
}
206207

@@ -217,22 +218,21 @@ object Primtives {
217218
}
218219
}
219220

220-
case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value])
221+
case class Frame(locals: ArrayBuffer[Value])
221222

222-
object Evaluator {
223+
case class Evaluator(module: ModuleInstance) {
223224
import Primtives._
224225

225226
type Cont[A] = List[Value] => A
226227

227-
def getFuncType(module: ModuleInstance, ty: BlockType): FuncType = {
228+
def getFuncType(ty: BlockType): FuncType =
228229
ty match {
229230
case VarBlockType(_, None) =>
230231
??? // TODO: fill this branch until we handle type index correctly
231232
case VarBlockType(_, Some(tipe)) => tipe
232233
case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe))
233234
case ValBlockType(None) => FuncType(List(), List(), List())
234235
}
235-
}
236236

237237
def evalCall[Ans](rest: List[Instr],
238238
stack: List[Value],
@@ -241,12 +241,12 @@ object Evaluator {
241241
trail: List[Cont[Ans]],
242242
funcIndex: Int,
243243
isTail: Boolean): Ans = {
244-
frame.module.funcs(funcIndex) match {
244+
module.funcs(funcIndex) match {
245245
case FuncDef(_, FuncBodyDef(ty, _, locals, body)) =>
246246
val args = stack.take(ty.inps.size).reverse
247247
val newStack = stack.drop(ty.inps.size)
248248
val frameLocals = args ++ locals.map(zero(_))
249-
val newFrame = Frame(frame.module, ArrayBuffer(frameLocals: _*))
249+
val newFrame = Frame(ArrayBuffer(frameLocals: _*))
250250
if (isTail)
251251
// when tail call, share the continuation for returning with the callee
252252
eval(body, List(), newFrame, kont, List(kont))
@@ -296,21 +296,21 @@ object Evaluator {
296296
frame.locals(i) = value
297297
eval(rest, stack, frame, kont, trail)
298298
case GlobalGet(i) =>
299-
eval(rest, frame.module.globals(i).value :: stack, frame, kont, trail)
299+
eval(rest, module.globals(i).value :: stack, frame, kont, trail)
300300
case GlobalSet(i) =>
301301
val value :: newStack = stack
302-
frame.module.globals(i).ty match {
302+
module.globals(i).ty match {
303303
case GlobalType(tipe, true) if value.tipe == tipe =>
304-
frame.module.globals(i).value = value
304+
module.globals(i).value = value
305305
case GlobalType(_, true) => throw new Exception("Invalid type")
306306
case _ => throw new Exception("Cannot set immutable global")
307307
}
308308
eval(rest, newStack, frame, kont, trail)
309309
case MemorySize =>
310-
eval(rest, I32V(frame.module.memory.head.size) :: stack, frame, kont, trail)
310+
eval(rest, I32V(module.memory.head.size) :: stack, frame, kont, trail)
311311
case MemoryGrow =>
312312
val I32V(delta) :: newStack = stack
313-
val mem = frame.module.memory.head
313+
val mem = module.memory.head
314314
val oldSize = mem.size
315315
mem.grow(delta) match {
316316
case Some(e) =>
@@ -320,18 +320,18 @@ object Evaluator {
320320
}
321321
case MemoryFill =>
322322
val I32V(value) :: I32V(offset) :: I32V(size) :: newStack = stack
323-
if (memOutOfBound(frame, 0, offset, size))
323+
if (memOutOfBound(module, 0, offset, size))
324324
throw new Exception("Out of bounds memory access") // GW: turn this into a `trap`?
325325
else {
326-
frame.module.memory.head.fill(offset, size, value.toByte)
326+
module.memory.head.fill(offset, size, value.toByte)
327327
eval(rest, newStack, frame, kont, trail)
328328
}
329329
case MemoryCopy =>
330330
val I32V(n) :: I32V(src) :: I32V(dest) :: newStack = stack
331-
if (memOutOfBound(frame, 0, src, n) || memOutOfBound(frame, 0, dest, n))
331+
if (memOutOfBound(module, 0, src, n) || memOutOfBound(module, 0, dest, n))
332332
throw new Exception("Out of bounds memory access")
333333
else {
334-
frame.module.memory.head.copy(dest, src, n)
334+
module.memory.head.copy(dest, src, n)
335335
eval(rest, newStack, frame, kont, trail)
336336
}
337337
case Const(n) => eval(rest, n :: stack, frame, kont, trail)
@@ -349,17 +349,17 @@ object Evaluator {
349349
eval(rest, evalTestOp(op, v) :: newStack, frame, kont, trail)
350350
case Store(StoreOp(align, offset, ty, None)) =>
351351
val I32V(v) :: I32V(addr) :: newStack = stack
352-
frame.module.memory(0).storeInt(addr + offset, v)
352+
module.memory(0).storeInt(addr + offset, v)
353353
eval(rest, newStack, frame, kont, trail)
354354
case Load(LoadOp(align, offset, ty, None, None)) =>
355355
val I32V(addr) :: newStack = stack
356-
val value = frame.module.memory(0).loadInt(addr + offset)
356+
val value = module.memory(0).loadInt(addr + offset)
357357
eval(rest, I32V(value) :: newStack, frame, kont, trail)
358358
case Nop =>
359359
eval(rest, stack, frame, kont, trail)
360360
case Unreachable => throw Trap()
361361
case Block(ty, inner) =>
362-
val funcTy = getFuncType(frame.module, ty)
362+
val funcTy = getFuncType(ty)
363363
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
364364
val restK: Cont[Ans] = (retStack) =>
365365
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail)
@@ -368,15 +368,15 @@ object Evaluator {
368368
// We construct two continuations, one for the break (to the begining of the loop),
369369
// and one for fall-through to the next instruction following the syntactic structure
370370
// of the program.
371-
val funcTy = getFuncType(frame.module, ty)
371+
val funcTy = getFuncType(ty)
372372
val (inputs, restStack) = stack.splitAt(funcTy.inps.size)
373373
val restK: Cont[Ans] = (retStack) =>
374374
eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail)
375375
def loop(retStack: List[Value]): Ans =
376376
eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail)
377377
loop(inputs)
378378
case If(ty, thn, els) =>
379-
val funcTy = getFuncType(frame.module, ty)
379+
val funcTy = getFuncType(ty)
380380
val I32V(cond) :: newStack = stack
381381
val inner = if (cond != 0) thn else els
382382
val (inputs, restStack) = newStack.splitAt(funcTy.inps.size)
@@ -422,38 +422,33 @@ object Evaluator {
422422

423423
// If `main` is given, then we use that function as the entry point of the program;
424424
// otherwise, we look up the top-level `start` instruction to locate the entry point.
425-
def evalTop[Ans](module: Module, halt: Cont[Ans], main: Option[String] = None): Ans = {
425+
def evalTop[Ans](halt: Cont[Ans], main: Option[String] = None): Ans = {
426426
val instrs = main match {
427427
case Some(func_name) =>
428-
module.definitions.flatMap({
428+
module.defs.flatMap({
429429
case Export(`func_name`, ExportFunc(fid)) =>
430430
println(s"Entering function $main")
431-
module.funcEnv(fid) match {
431+
module.funcs(fid) match {
432432
case FuncDef(_, FuncBodyDef(_, _, _, body)) => body
433-
case _ =>
434-
throw new Exception("Entry function has no concrete body")
433+
case _ => throw new Exception("Entry function has no concrete body")
435434
}
436435
case _ => List()
437436
})
438437
case None =>
439-
module.definitions.flatMap({
438+
module.defs.flatMap({
440439
case Start(id) =>
441440
println(s"Entering unnamed function $id")
442-
module.funcEnv(id) match {
441+
module.funcs(id) match {
443442
case FuncDef(_, FuncBodyDef(_, _, _, body)) => body
444443
case _ =>
445444
throw new Exception("Entry function has no concrete body")
446445
}
447446
case _ => List()
448447
})
449448
}
450-
451449
if (instrs.isEmpty) println("Warning: nothing is executed")
452-
453-
val moduleInst = ModuleInstance(module)
454-
455-
Evaluator.eval(instrs, List(), Frame(moduleInst, ArrayBuffer(I32V(0))), halt, List(halt))
450+
eval(instrs, List(), Frame(ArrayBuffer(I32V(0))), halt, List(halt))
456451
}
457452

458-
def evalTop(m: Module): Unit = evalTop(m, stack => ())
453+
def evalTop(m: ModuleInstance): Unit = evalTop(stack => ())
459454
}

src/main/scala/wasm/MiniWasmScript.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ sealed class ScriptRunner {
1010

1111
def getInstance(instName: Option[String]): ModuleInstance = {
1212
instName match {
13-
case Some(name) => instanceMap(name)
13+
case Some(name) => instanceMap(name)
1414
case None => instances.head
1515
}
1616
}
@@ -28,7 +28,8 @@ sealed class ScriptRunner {
2828
case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => body
2929
}
3030
val k = (retStack: List[Value]) => retStack
31-
val actual = Evaluator.eval(instrs, List(), Frame(module, ArrayBuffer(args: _*)), k, List(k))
31+
val evaluator = Evaluator(module)
32+
val actual = evaluator.eval(instrs, List(), Frame(ArrayBuffer(args: _*)), k, List(k))
3233
assert(actual == expect)
3334
}
3435
}

src/test/scala/genwasym/TestEval.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ class TestEval extends FunSuite {
2424
def testFile(filename: String, main: Option[String] = None, expected: ExpResult = Ignore) = {
2525
val module = Parser.parseFile(filename)
2626
//println(module)
27-
val haltK: Evaluator.Cont[Unit] = stack => {
27+
val evaluator = Evaluator(ModuleInstance(module))
28+
val haltK: evaluator.Cont[Unit] = stack => {
2829
println(s"halt cont: $stack")
2930
expected match {
3031
case ExpInt(e) => assert(stack(0) == I32V(e))
3132
case ExpStack(e) => assert(stack == e)
3233
case Ignore => ()
3334
}
3435
}
35-
Evaluator.evalTop(module, haltK, main)
36+
evaluator.evalTop(haltK, main)
3637
}
3738

3839
// TODO: the power test can be used to test the stack
@@ -79,7 +80,7 @@ class TestEval extends FunSuite {
7980
testFile("./benchmarks/wasm/wasmfx/cont1-stripped.wat")
8081
}
8182

82-
// can parse this file,
83+
// can parse this file,
8384
// but there's no support for ref.func, cont.new, suspend, resume to run it yet
8485
// test("gen") {
8586
// testFile("./benchmarks/wasm/wasmfx/gen-stripped.wat")

0 commit comments

Comments
 (0)