From 9da35523a2dd5876999fc5d3436e8f2850286508 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 2 Jun 2024 14:26:05 -0400 Subject: [PATCH 01/31] CPS --- cps.py | 323 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 cps.py diff --git a/cps.py b/cps.py new file mode 100644 index 00000000..73275862 --- /dev/null +++ b/cps.py @@ -0,0 +1,323 @@ +import dataclasses +import itertools +import unittest +from scrapscript import parse, tokenize, Assign, Int, Var as ScrapVar, Object, Binop, BinopKind, Where + + +@dataclasses.dataclass +class CPSExpr: + pass + + +@dataclasses.dataclass +class Atom(CPSExpr): + value: object + + def __repr__(self) -> str: + return repr(self.value) + + +@dataclasses.dataclass +class Var(CPSExpr): + name: str + + def __repr__(self) -> str: + return self.name + + +@dataclasses.dataclass +class Prim(CPSExpr): + op: str + args: list[CPSExpr] + + def __repr__(self) -> str: + return f"({self.op} {' '.join(map(repr, self.args))})" + + +@dataclasses.dataclass +class Fun(CPSExpr): + args: list[CPSExpr] + body: CPSExpr + + def __repr__(self) -> str: + args = " ".join(map(repr, self.args)) + return f"(fun ({args}) {self.body!r})" + + +@dataclasses.dataclass +class App(CPSExpr): + fun: CPSExpr + args: list[CPSExpr] + + def __repr__(self) -> str: + return f"({self.fun!r} {' '.join(map(repr, self.args))})" + + +cps_counter = itertools.count() + + +def gensym() -> str: + return f"v{next(cps_counter)}" + + +def cont(arg: Var, body: CPSExpr) -> CPSExpr: + return Fun([arg], body) + + +def cps(exp: Object, k: CPSExpr) -> CPSExpr: + if isinstance(exp, Int): + return App(k, [Atom(exp.value)]) + if isinstance(exp, ScrapVar): + return App(k, [Var(exp.name)]) + if isinstance(exp, Binop): + left = Var(gensym()) + right = Var(gensym()) + return cps(exp.left, cont(left, cps(exp.right, cont(right, Prim(BinopKind.to_str(exp.op), [left, right, k]))))) + if isinstance(exp, Where): + assert isinstance(exp.binding, Assign) + assert isinstance(exp.binding.name, ScrapVar) + name = exp.binding.name.name + value = exp.binding.value + body = exp.body + return cps(value, cont(Var(name), cps(body, k))) + raise NotImplementedError(f"cps: {exp}") + + +class CPSTests(unittest.TestCase): + def setUp(self) -> None: + global cps_counter + cps_counter = itertools.count() + + def test_atom(self) -> None: + self.assertEqual(cps(Int(42), Var("k")), App(Var("k"), [Atom(42)])) + + def test_var(self) -> None: + self.assertEqual(cps(ScrapVar("x"), Var("k")), App(Var("k"), [Var("x")])) + + def test_binop(self) -> None: + self.assertEqual( + cps(parse(tokenize("1 + 2")), Var("k")), + # ((fun (v0) ((fun (v1) (+ v0 v1 k)) 2)) 1) + App( + Fun( + [Var("v0")], + App( + Fun( + [Var("v1")], + Prim("+", [Var("v0"), Var("v1"), Var("k")]), + ), + [Atom(2)], + ), + ), + [Atom(1)], + ), + ) + + def test_where(self) -> None: + exp = parse(tokenize("a + b . a = 1 . b = 2")) + self.assertEqual( + cps(exp, Var("k")), + # ((fun (b) ((fun (a) ((fun (v0) ((fun (v1) (+ v0 v1 k)) b)) a)) 1)) 2) + App( + Fun( + [Var("b")], + App( + Fun( + [Var("a")], + App( + Fun( + [Var("v0")], + App( + Fun( + [Var("v1")], + Prim("+", [Var("v0"), Var("v1"), Var("k")]), + ), + [Var("b")], + ), + ), + [Var("a")], + ), + ), + [Atom(1)], + ), + ), + [Atom(2)], + ), + ) + + +def arg_name(arg: CPSExpr) -> str: + assert isinstance(arg, Var) + return arg.name + + +def alphatise_(exp: CPSExpr, env: dict[str, str]) -> CPSExpr: + if isinstance(exp, Atom): + return exp + if isinstance(exp, Var): + return Var(env.get(exp.name, exp.name)) + if isinstance(exp, Prim): + return Prim(exp.op, [alphatise_(arg, env) for arg in exp.args]) + if isinstance(exp, Fun): + new_env = {arg_name(arg): gensym() for arg in exp.args} + new_body = alphatise_(exp.body, {**env, **new_env}) + return Fun([Var(new_env[arg_name(arg)]) for arg in exp.args], new_body) + if isinstance(exp, App): + return App(alphatise_(exp.fun, env), [alphatise_(arg, env) for arg in exp.args]) + raise NotImplementedError(f"alphatise: {exp}") + + +def alphatise(exp: CPSExpr) -> CPSExpr: + return alphatise_(exp, {}) + + +class AlphatiseTests(unittest.TestCase): + def setUp(self) -> None: + global cps_counter + cps_counter = itertools.count() + + def test_atom(self) -> None: + self.assertEqual(alphatise(Atom(42)), Atom(42)) + + def test_var(self) -> None: + self.assertEqual(alphatise(Var("x")), Var("x")) + + def test_prim(self) -> None: + exp = Prim("+", [Var("x"), Var("y"), Var("z")]) + self.assertEqual( + alphatise_(exp, {"x": "v0", "y": "v1"}), + Prim("+", [Var("v0"), Var("v1"), Var("z")]), + ) + + def test_fun(self) -> None: + exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) + self.assertEqual( + alphatise(exp), + Fun( + [Var("v0"), Var("v1")], + Prim("+", [Var("v0"), Var("v1"), Var("z")]), + ), + ) + + def test_app(self) -> None: + exp = App(Var("f"), [Var("x"), Var("y")]) + self.assertEqual(alphatise_(exp, {"x": "v0", "y": "v1"}), App(Var("f"), [Var("v0"), Var("v1")])) + + +def subst(exp: CPSExpr, env: dict[str, CPSExpr]) -> CPSExpr: + if isinstance(exp, Atom): + return exp + if isinstance(exp, Var): + return env.get(exp.name, exp) + if isinstance(exp, Prim): + return Prim(exp.op, [subst(arg, env) for arg in exp.args]) + if isinstance(exp, Fun): + new_env = {arg_name(arg): Var(gensym()) for arg in exp.args} + new_body = subst(exp.body, {**env, **new_env}) + return Fun([Var(new_env[arg_name(arg)].name) for arg in exp.args], new_body) + if isinstance(exp, App): + return App(subst(exp.fun, env), [subst(arg, env) for arg in exp.args]) + raise NotImplementedError(f"subst: {exp}") + + +class SubstTests(unittest.TestCase): + def setUp(self) -> None: + global cps_counter + cps_counter = itertools.count() + + def test_atom(self) -> None: + self.assertEqual(subst(Atom(42), {}), Atom(42)) + + def test_var(self) -> None: + self.assertEqual(subst(Var("x"), {}), Var("x")) + self.assertEqual(subst(Var("x"), {"x": Atom(42)}), Atom(42)) + + def test_prim(self) -> None: + exp = Prim("+", [Var("x"), Var("y"), Var("z")]) + self.assertEqual( + subst(exp, {"x": Atom(1), "z": Atom(3)}), + Prim("+", [Atom(1), Var("y"), Atom(3)]), + ) + + def test_fun(self) -> None: + exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) + self.assertEqual( + subst(exp, {"z": Atom(3)}), + Fun( + [Var("v0"), Var("v1")], + Prim("+", [Var("v0"), Var("v1"), Atom(3)]), + ), + ) + + def test_app(self) -> None: + exp = App(Var("f"), [Var("x"), Var("y")]) + self.assertEqual(subst(exp, {"x": Atom(1), "y": Atom(2)}), App(Var("f"), [Atom(1), Atom(2)])) + + +def is_simple(exp: CPSExpr) -> bool: + return isinstance(exp, (Atom, Var)) + + +def opt(exp: CPSExpr) -> CPSExpr: + if isinstance(exp, Atom): + return exp + if isinstance(exp, Var): + return exp + if isinstance(exp, Prim): + args = [opt(arg) for arg in exp.args] + consts = [arg for arg in args if isinstance(arg, Atom)] + vars = [arg for arg in args if not isinstance(arg, Atom)] + if exp.op == "+": + consts = [Atom(sum(c.value for c in consts))] # type: ignore + args = consts + vars + if len(args) == 1: + return args[0] + return Prim(exp.op, args) + if isinstance(exp, App) and isinstance(exp.fun, Fun): + fun = opt(exp.fun) + assert isinstance(fun, Fun) + formals = exp.fun.args + actuals = [opt(arg) for arg in exp.args] + if len(formals) != len(actuals): + return App(fun, actuals) + if all(is_simple(arg) for arg in actuals): + new_env = {arg_name(formal): actual for formal, actual in zip(formals, actuals)} + return subst(fun.body, new_env) + return exp + + +def spin_opt(exp: CPSExpr) -> CPSExpr: + while True: + new_exp = opt(exp) + if new_exp == exp: + return exp + exp = new_exp + + +class OptTests(unittest.TestCase): + def setUp(self) -> None: + global cps_counter + cps_counter = itertools.count() + + def test_prim(self) -> None: + exp = Prim("+", [Atom(1), Atom(2), Atom(3)]) + self.assertEqual(opt(exp), Atom(6)) + + def test_prim_var(self) -> None: + exp = Prim("+", [Atom(1), Var("x"), Atom(3)]) + self.assertEqual(opt(exp), Prim("+", [Atom(4), Var("x")])) + + def test_subst(self) -> None: + exp = App(Fun([Var("x")], Prim("+", [Atom(1), Var("x"), Atom(2)])), [Atom(3)]) + self.assertEqual(spin_opt(exp), Atom(6)) + + def test_add(self) -> None: + exp = parse(tokenize("1 + 2 + c")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + Prim("+", [Atom(2), Var("c"), Fun([Var("v9")], Prim("+", [Atom(1), Var("v9"), Var("k")]))]), + ) + + +if __name__ == "__main__": + unittest.main() From 41ac07db6768bad19c3ce8fd3737a33cae0f1316 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 3 Jun 2024 10:54:02 -0400 Subject: [PATCH 02/31] More! --- cps.py | 79 ++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 11 deletions(-) diff --git a/cps.py b/cps.py index 73275862..f073e2c2 100644 --- a/cps.py +++ b/cps.py @@ -1,7 +1,7 @@ import dataclasses import itertools import unittest -from scrapscript import parse, tokenize, Assign, Int, Var as ScrapVar, Object, Binop, BinopKind, Where +from scrapscript import parse, tokenize, Assign, Int, Var as ScrapVar, Object, Binop, BinopKind, Where, Apply, Function @dataclasses.dataclass @@ -80,6 +80,15 @@ def cps(exp: Object, k: CPSExpr) -> CPSExpr: value = exp.binding.value body = exp.body return cps(value, cont(Var(name), cps(body, k))) + if isinstance(exp, Apply): + fun = Var(gensym()) + arg = Var(gensym()) + return cps(exp.func, cont(fun, cps(exp.arg, cont(arg, App(fun, [arg, k]))))) + if isinstance(exp, Function): + assert isinstance(exp.arg, ScrapVar) + arg = Var(exp.arg.name) + subk = Var(gensym()) + return App(k, [Fun([arg, subk], cps(exp.body, subk))]) raise NotImplementedError(f"cps: {exp}") @@ -255,7 +264,7 @@ def test_app(self) -> None: def is_simple(exp: CPSExpr) -> bool: - return isinstance(exp, (Atom, Var)) + return isinstance(exp, (Atom, Var, Fun)) def opt(exp: CPSExpr) -> CPSExpr: @@ -270,8 +279,9 @@ def opt(exp: CPSExpr) -> CPSExpr: if exp.op == "+": consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars - if len(args) == 1: - return args[0] + if len(args) == 2: + # Last argument is a cont + return App(args[1], [args[0]]) return Prim(exp.op, args) if isinstance(exp, App) and isinstance(exp.fun, Fun): fun = opt(exp.fun) @@ -283,6 +293,14 @@ def opt(exp: CPSExpr) -> CPSExpr: if all(is_simple(arg) for arg in actuals): new_env = {arg_name(formal): actual for formal, actual in zip(formals, actuals)} return subst(fun.body, new_env) + return App(fun, actuals) + if isinstance(exp, App): + fun = opt(exp.fun) + args = [opt(arg) for arg in exp.args] + return App(fun, args) + if isinstance(exp, Fun): + body = opt(exp.body) + return Fun(exp.args, body) return exp @@ -300,22 +318,61 @@ def setUp(self) -> None: cps_counter = itertools.count() def test_prim(self) -> None: - exp = Prim("+", [Atom(1), Atom(2), Atom(3)]) - self.assertEqual(opt(exp), Atom(6)) + exp = Prim("+", [Atom(1), Atom(2), Atom(3), Var("k")]) + self.assertEqual(opt(exp), App(Var("k"), [Atom(6)])) def test_prim_var(self) -> None: - exp = Prim("+", [Atom(1), Var("x"), Atom(3)]) - self.assertEqual(opt(exp), Prim("+", [Atom(4), Var("x")])) + exp = Prim("+", [Atom(1), Var("x"), Atom(3), Var("k")]) + self.assertEqual(opt(exp), Prim("+", [Atom(4), Var("x"), Var("k")])) def test_subst(self) -> None: - exp = App(Fun([Var("x")], Prim("+", [Atom(1), Var("x"), Atom(2)])), [Atom(3)]) - self.assertEqual(spin_opt(exp), Atom(6)) + exp = App(Fun([Var("x")], Prim("+", [Atom(1), Var("x"), Atom(2), Var("k")])), [Atom(3)]) + self.assertEqual(spin_opt(exp), App(Var("k"), [Atom(6)])) def test_add(self) -> None: exp = parse(tokenize("1 + 2 + c")) self.assertEqual( spin_opt(cps(exp, Var("k"))), - Prim("+", [Atom(2), Var("c"), Fun([Var("v9")], Prim("+", [Atom(1), Var("v9"), Var("k")]))]), + Prim("+", [Atom(2), Var("c"), Fun([Var("v6")], Prim("+", [Atom(1), Var("v6"), Var("k")]))]), + ) + + def test_simple_fun(self) -> None: + exp = cps(parse(tokenize("_ -> 1")), Var("k")) + self.assertEqual( + spin_opt(exp), + # (k (fun (_ v0) (v0 1))) + App( + Var("k"), + [ + Fun( + [Var("_"), Var("v0")], + App(Var("v0"), [Atom(1)]), + ) + ], + ), + ) + + def test_fun(self) -> None: + exp = cps(parse(tokenize("_ -> 1 + 2 + 3")), Var("k")) + self.assertEqual( + spin_opt(exp), + # (k (fun (_ v0) (v0 6))) + App( + Var("k"), + [ + Fun( + [Var("_"), Var("v0")], + App(Var("v0"), [Atom(6)]), + ) + ], + ), + ) + + def test_add_function(self) -> None: + exp = parse(tokenize("add a b . add = x -> y -> x + y . a = 3 . b = 4")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + App(Var("k"), [Atom(7)]), ) From c1383df257b7f1ef47985304afbf07ef7ea0ee65 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 3 Jun 2024 11:28:54 -0400 Subject: [PATCH 03/31] Add list cons --- cps.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/cps.py b/cps.py index f073e2c2..60276491 100644 --- a/cps.py +++ b/cps.py @@ -1,7 +1,20 @@ import dataclasses import itertools import unittest -from scrapscript import parse, tokenize, Assign, Int, Var as ScrapVar, Object, Binop, BinopKind, Where, Apply, Function +from scrapscript import ( + parse, + tokenize, + Assign, + Int, + Var as ScrapVar, + Object, + Binop, + BinopKind, + Where, + Apply, + Function, + List, +) @dataclasses.dataclass @@ -31,7 +44,7 @@ class Prim(CPSExpr): args: list[CPSExpr] def __repr__(self) -> str: - return f"({self.op} {' '.join(map(repr, self.args))})" + return f"(${self.op} {' '.join(map(repr, self.args))})" @dataclasses.dataclass @@ -89,6 +102,13 @@ def cps(exp: Object, k: CPSExpr) -> CPSExpr: arg = Var(exp.arg.name) subk = Var(gensym()) return App(k, [Fun([arg, subk], cps(exp.body, subk))]) + if isinstance(exp, List) and not exp.items: + return App(k, [Atom([])]) + if isinstance(exp, List): + items = exp.items + head = Var(gensym()) + tail = Var(gensym()) + return cps(items[0], cont(head, cps(List(items[1:]), cont(tail, Prim("cons", [head, tail, k]))))) raise NotImplementedError(f"cps: {exp}") @@ -154,6 +174,9 @@ def test_where(self) -> None: ), ) + def test_empty_list(self) -> None: + self.assertEqual(cps(List([]), Var("k")), App(Var("k"), [Atom([])])) + def arg_name(arg: CPSExpr) -> str: assert isinstance(arg, Var) @@ -273,16 +296,20 @@ def opt(exp: CPSExpr) -> CPSExpr: if isinstance(exp, Var): return exp if isinstance(exp, Prim): - args = [opt(arg) for arg in exp.args] - consts = [arg for arg in args if isinstance(arg, Atom)] - vars = [arg for arg in args if not isinstance(arg, Atom)] + args = [opt(arg) for arg in exp.args[:-1]] + cont = exp.args[-1] + if exp.op == "cons": + assert len(args) == 2 + if all(isinstance(arg, Atom) for arg in args): + return App(cont, [Atom(args)]) if exp.op == "+": + consts = [arg for arg in args if isinstance(arg, Atom)] + vars = [arg for arg in args if not isinstance(arg, Atom)] consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars - if len(args) == 2: - # Last argument is a cont - return App(args[1], [args[0]]) - return Prim(exp.op, args) + if len(args) == 1: + return App(cont, args) + return Prim(exp.op, args + [cont]) if isinstance(exp, App) and isinstance(exp.fun, Fun): fun = opt(exp.fun) assert isinstance(fun, Fun) @@ -375,6 +402,33 @@ def test_add_function(self) -> None: App(Var("k"), [Atom(7)]), ) + def test_make_empty_list(self) -> None: + exp = parse(tokenize("[]")) + self.assertEqual(spin_opt(cps(exp, Var("k"))), App(Var("k"), [Atom([])])) + + def test_make_const_list(self) -> None: + exp = parse(tokenize("[1+2, 2+3, 3+4]")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + App(Var("k"), [Atom([Atom(3), Atom([Atom(5), Atom([Atom(7), Atom([])])])])]), + ) + + def test_make_list(self) -> None: + exp = parse(tokenize("[1+2, x, 3+4]")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + # ($cons x [7, []] (fun (v46) ($cons 3 v46 k))) + Prim( + "cons", + [ + Var("x"), + Atom([Atom(7), Atom([])]), + Fun([Var("v46")], Prim("cons", [Atom(3), Var("v46"), Var("k")])), + ], + ), + ) + if __name__ == "__main__": + __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() From 5c813385e9446c6928c964babaa2ae5592e38e4a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 3 Jun 2024 11:36:40 -0400 Subject: [PATCH 04/31] . --- cps.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/cps.py b/cps.py index 60276491..3170b962 100644 --- a/cps.py +++ b/cps.py @@ -305,8 +305,9 @@ def opt(exp: CPSExpr) -> CPSExpr: if exp.op == "+": consts = [arg for arg in args if isinstance(arg, Atom)] vars = [arg for arg in args if not isinstance(arg, Atom)] - consts = [Atom(sum(c.value for c in consts))] # type: ignore - args = consts + vars + if consts: + consts = [Atom(sum(c.value for c in consts))] # type: ignore + args = consts + vars if len(args) == 1: return App(cont, args) return Prim(exp.op, args + [cont]) @@ -396,6 +397,14 @@ def test_fun(self) -> None: ) def test_add_function(self) -> None: + exp = parse(tokenize("(x -> y -> x + y) a b")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + # ($+ a b k) + Prim("+", [Var("a"), Var("b"), Var("k")]), + ) + + def test_fold_add_function(self) -> None: exp = parse(tokenize("add a b . add = x -> y -> x + y . a = 3 . b = 4")) self.assertEqual( spin_opt(cps(exp, Var("k"))), From 94f005b693e2fa4956d078f8bf9d9ed26c9ebdeb Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 3 Jun 2024 11:40:26 -0400 Subject: [PATCH 05/31] . --- cps.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/cps.py b/cps.py index 3170b962..d0525fa4 100644 --- a/cps.py +++ b/cps.py @@ -306,6 +306,7 @@ def opt(exp: CPSExpr) -> CPSExpr: consts = [arg for arg in args if isinstance(arg, Atom)] vars = [arg for arg in args if not isinstance(arg, Atom)] if consts: + # TODO(max): Only sum ints consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars if len(args) == 1: @@ -397,6 +398,30 @@ def test_fun(self) -> None: ) def test_add_function(self) -> None: + exp = parse(tokenize("x -> y -> x + y")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + # (k (fun (x v0) (v0 (fun (y v1) ($+ x y v1))))) + App( + Var("k"), + [ + Fun( + [Var("x"), Var("v0")], + App( + Var("v0"), + [ + Fun( + [Var("y"), Var("v1")], + Prim("+", [Var("x"), Var("y"), Var("v1")]), + ) + ], + ), + ) + ], + ), + ) + + def test_fold_add_function_var(self) -> None: exp = parse(tokenize("(x -> y -> x + y) a b")) self.assertEqual( spin_opt(cps(exp, Var("k"))), @@ -404,7 +429,7 @@ def test_add_function(self) -> None: Prim("+", [Var("a"), Var("b"), Var("k")]), ) - def test_fold_add_function(self) -> None: + def test_fold_add_function_int(self) -> None: exp = parse(tokenize("add a b . add = x -> y -> x + y . a = 3 . b = 4")) self.assertEqual( spin_opt(cps(exp, Var("k"))), From 1144886e708ac545867732701b78906f011b8df1 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 3 Jun 2024 11:41:40 -0400 Subject: [PATCH 06/31] . --- cps.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/cps.py b/cps.py index d0525fa4..dc09fa3e 100644 --- a/cps.py +++ b/cps.py @@ -421,6 +421,22 @@ def test_add_function(self) -> None: ), ) + def test_fold_add_function_curried(self) -> None: + exp = parse(tokenize("(x -> y -> x + y) 3")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + # (k (fun (v6 v7) ($+ 3 v6 v7))) + App( + Var("k"), + [ + Fun( + [Var("v6"), Var("v7")], + Prim("+", [Atom(3), Var("v6"), Var("v7")]), + ) + ], + ), + ) + def test_fold_add_function_var(self) -> None: exp = parse(tokenize("(x -> y -> x + y) a b")) self.assertEqual( From 464ec1712a9ad892e2f04b9e9389c85c6b5d70d9 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 3 Jun 2024 12:35:47 -0400 Subject: [PATCH 07/31] Support variants --- cps.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/cps.py b/cps.py index dc09fa3e..38189239 100644 --- a/cps.py +++ b/cps.py @@ -14,6 +14,7 @@ Apply, Function, List, + Variant, ) @@ -109,6 +110,9 @@ def cps(exp: Object, k: CPSExpr) -> CPSExpr: head = Var(gensym()) tail = Var(gensym()) return cps(items[0], cont(head, cps(List(items[1:]), cont(tail, Prim("cons", [head, tail, k]))))) + if isinstance(exp, Variant): + tag_value = Var(gensym()) + return cps(exp.value, cont(tag_value, Prim("tag", [Atom(exp.tag), tag_value, k]))) raise NotImplementedError(f"cps: {exp}") @@ -177,6 +181,13 @@ def test_where(self) -> None: def test_empty_list(self) -> None: self.assertEqual(cps(List([]), Var("k")), App(Var("k"), [Atom([])])) + def test_variant(self) -> None: + self.assertEqual( + cps(parse(tokenize("# a_tag 123")), Var("k")), + # ((fun (v0) ($tag 'a_tag' v0 k)) 123) + App(Fun([Var("v0")], Prim("tag", [Atom("a_tag"), Var("v0"), Var("k")])), [Atom(123)]), + ) + def arg_name(arg: CPSExpr) -> str: assert isinstance(arg, Var) @@ -478,6 +489,14 @@ def test_make_list(self) -> None: ), ) + def test_variant(self) -> None: + exp = parse(tokenize("# a_tag 123")) + self.assertEqual( + spin_opt(cps(exp, Var("k"))), + # ($tag 'a_tag' 123 k) + Prim("tag", [Atom("a_tag"), Atom(123), Var("k")]), + ) + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 From 1cde7098b64342220daa8e91fde5358271f295cc Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 16:50:44 -0400 Subject: [PATCH 08/31] wip ssa --- cps.py | 316 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 313 insertions(+), 3 deletions(-) diff --git a/cps.py b/cps.py index 38189239..21f88f84 100644 --- a/cps.py +++ b/cps.py @@ -20,7 +20,7 @@ @dataclasses.dataclass class CPSExpr: - pass + annotations: dict[str, object] = dataclasses.field(default_factory=dict, init=False) @dataclasses.dataclass @@ -48,10 +48,17 @@ def __repr__(self) -> str: return f"(${self.op} {' '.join(map(repr, self.args))})" +fun_counter = itertools.count() + + @dataclasses.dataclass class Fun(CPSExpr): - args: list[CPSExpr] + args: list[Var] body: CPSExpr + id: int = dataclasses.field(default_factory=lambda: next(fun_counter), compare=False) + + def name(self) -> str: + return f"fun{self.id}" def __repr__(self) -> str: args = " ".join(map(repr, self.args)) @@ -298,7 +305,7 @@ def test_app(self) -> None: def is_simple(exp: CPSExpr) -> bool: - return isinstance(exp, (Atom, Var, Fun)) + return isinstance(exp, (Atom, Var, Fun)) or (isinstance(exp, Prim) and exp.op in {"clo", "tag"}) def opt(exp: CPSExpr) -> CPSExpr: @@ -498,6 +505,309 @@ def test_variant(self) -> None: ) +def free_in(exp: CPSExpr) -> set[str]: + match exp: + case Atom(_): + return set() + case Var(name): + return {name} + case Prim(_, args): + return {name for arg in args for name in free_in(arg)} + case Fun(args, body): + return free_in(body) - {arg_name(arg) for arg in args} + case App(fun, args): + return free_in(fun) | {name for arg in args for name in free_in(arg)} + raise NotImplementedError(f"free_in: {exp}") + + +class FreeInTests(unittest.TestCase): + def test_atom(self) -> None: + self.assertEqual(free_in(Atom(42)), set()) + + def test_var(self) -> None: + self.assertEqual(free_in(Var("x")), {"x"}) + + def test_prim(self) -> None: + exp = Prim("+", [Var("x"), Var("y"), Var("z")]) + self.assertEqual(free_in(exp), {"x", "y", "z"}) + + def test_fun(self) -> None: + exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) + self.assertEqual(free_in(exp), {"z"}) + + def test_app(self) -> None: + exp = App(Var("f"), [Var("x"), Var("y")]) + self.assertEqual(free_in(exp), {"f", "x", "y"}) + + +def make_closures_explicit(exp: CPSExpr, replacements: dict[str, CPSExpr]) -> CPSExpr: + def rec(exp: CPSExpr) -> CPSExpr: + return make_closures_explicit(exp, replacements) + + match exp: + case Atom(_): + return exp + case Var(name): + if name in replacements: + return replacements[name] + return exp + case Prim(op, args): + return Prim(op, [rec(arg) for arg in args]) + case Fun(args, body): + freevars = sorted(free_in(exp)) + this = Var("this") + new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} + body = make_closures_explicit(body, {**replacements, **new_replacements}) + return Fun([this] + args, body) + case App(fun, args): + return App(rec(fun), [rec(arg) for arg in args]) + raise NotImplementedError(f"make_closures_explicit: {exp}") + + +class ClosureTests(unittest.TestCase): + def test_no_freevars(self) -> None: + exp = Fun([Var("x")], Var("x")) + # (fun (this x) x) + self.assertEqual(make_closures_explicit(exp, {}), Fun([Var("this"), Var("x")], Var("x"))) + + def test_freevars(self) -> None: + exp = Fun([Var("k")], Prim("+", [Var("x"), Var("y"), Var("k")])) + # (fun (this k) ($+ ($clo this 0) ($clo this 1) k)) + self.assertEqual( + make_closures_explicit(exp, {}), + Fun( + [Var("this"), Var("k")], + Prim( + "+", + [ + Prim("clo", [Var("this"), Atom(0)]), + Prim("clo", [Var("this"), Atom(1)]), + Var("k"), + ], + ), + ), + ) + + def test_app_fun(self) -> None: + exp = App(Fun([Var("x")], Var("x")), [Atom(42)]) + # ((fun (this x) x) 42) + self.assertEqual( + make_closures_explicit(exp, {}), + App(Fun([Var("this"), Var("x")], Var("x")), [Atom(42)]), + ) + + def test_app(self) -> None: + exp = App(Var("f"), [Atom(42)]) + # (f 42) + self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) + + +@dataclasses.dataclass +class CFun: + name: str + code: list[str] = dataclasses.field(default_factory=list) + + +class Compiler: + def __init__(self, main: CFun) -> None: + self.funs = [main] + self.fun: CFun = main + + def _emit(self, code: str) -> None: + self.fun.code.append(code) + + def _mktemp(self, expr: str) -> str: + name = gensym() + self._emit(f"struct object* {name} = {expr};") + return name + + def compile(self, exp: CPSExpr) -> str: + match exp: + case Var(name): + return name + case Atom(int(value)): + return self._mktemp(f"mknum({value})") + case Prim("clo", [Var("this"), Atom(idx)]): + return self._mktemp(f"closure_get(this, {idx})") + case Prim("+", [left, right, cont]): + left = self.compile(left) + right = self.compile(right) + result = self._mktemp(f"num_add({left}, {right})") + self._emit(f"return {cont}({result});") + return "" + # case Atom([]): + # return self._mktemp("empty_list()") + # case Atom(list(value_exprs)): + # values = [self.compile(value) for value in value_exprs] + # num_values = len(values) + # result = self._mktemp(f"list_cons({num_values})") + # for i, value in enumerate(values): + # self._emit(f"{result}[{i}] = {value};") + # return result + case App(Var("halt"), [arg]): + self._emit(f"return {self.compile(arg)};") + return "" + case App(fun, args): + assert isinstance(fun, Var), "((fun ...) ...) should be optimized out" + fun_name = fun.name + arg_names = [self.compile(arg) for arg in args] + result = self._mktemp(f"clo call {fun_name}({', '.join(arg_names)})") + self._emit(f"return {result};") + return "" + case Fun(_, _): + prev = self.fun + self.fun = CFun(exp.name()) + self.funs.append(self.fun) + result = self.compile_proc(exp) + self.fun = prev + return result + case _: + raise NotImplementedError(f"compile: {exp}") + + def compile_proc(self, exp: Fun) -> str: + args = [arg.name for arg in exp.args] + self._emit(f"object {exp.name()}({', '.join(args)}) {{") + self.compile(exp.body) + self._emit("}") + return exp.name() + + +class C: + def __init__(self) -> None: + self.funs = [] + + def G(self, exp: CPSExpr) -> str: + match exp: + case Atom(int(value)): + return str(value) + case Var(name): + return name + case App(k, [Fun(_, _)]): + assert isinstance(k, Var) + fun, name = self.G_proc(exp.args[0]) + self.funs.append(fun) + return f"return mkclosure({name});" + case App(k, [E]): + assert is_simple(E) + return f"return {E};" + case App(E, [*args, k]): + assert isinstance(E, Var) + assert all(is_simple(arg) for arg in args) + return self.G_cont(f"{E.name}({', '.join(str(arg) for arg in args)})", k) + case Prim("+", [x, y, k]): + assert is_simple(x) + assert is_simple(y) + return self.G_cont(f"{x} + {y}", k) + # TODO(max): j case + # TODO(max): Split cont and fun or annotate + case Prim("if", [cond, tk, fk]): + return f"if ({cond}) {{ {self.G(tk)} }} else {{ {self.G(fk)} }}" + case _: + raise NotImplementedError(f"G: {exp}") + + def G_cont(self, val: str, exp: CPSExpr) -> str: + match exp: + case Fun([res], M1): + return f"{res} <- {val}; {self.G(M1)}" + case Var(_): + return f"return {val};" + case _: + raise NotImplementedError(f"G_cont: {exp}") + + def G_proc(self, exp: Fun) -> str: + match exp: + case Fun([*args, _], M1): + return f"proc fun{exp.id}({', '.join(arg.name for arg in args)}) {{ {self.G(M1)} " + "}", f"fun{exp.id}" + case _: + raise NotImplementedError(f"G_proc: {exp}") + + def code(self) -> str: + return "\n\n".join(self.funs) + + +class GTests(unittest.TestCase): + def setUp(self) -> None: + global cps_counter + cps_counter = itertools.count() + + def test_app_cont(self) -> None: + # (E ... (fun (x) M1)) + exp = App(Var("f"), [Atom(1), Fun([Var("x")], App(Var("k"), [Var("x")]))]) + self.assertEqual(C().G(exp), "x <- f(1); return x;") + + def test_tailcall(self) -> None: + # (E ... k) + exp = App(Var("f"), [Atom(1), Var("k")]) + self.assertEqual(C().G(exp), "return f(1);") + + def test_return(self) -> None: + # (k E) + exp = App(Var("k"), [Atom(1)]) + self.assertEqual(C().G(exp), "return 1;") + + def test_if(self) -> None: + # ($if cond t f) + exp = Prim( + "if", + [ + Atom(1), + App(Var("k"), [Atom(2)]), + App(Var("k"), [Atom(3)]), + ], + ) + self.assertEqual(C().G(exp), "if (1) { return 2; } else { return 3; }") + + def test_add_cont(self) -> None: + # ($+ x y (fun (res) M1)) + exp = Prim("+", [Atom(1), Atom(2), Fun([Var("res")], App(Var("k"), [Var("res")]))]) + self.assertEqual(C().G(exp), "res <- 1 + 2; return res;") + + def test_add_cont_var(self) -> None: + # ($+ x y k) + exp = Prim("+", [Atom(1), Atom(2), Var("k")]) + self.assertEqual(C().G(exp), "return 1 + 2;") + + def test_proc(self) -> None: + exp = App(Var("k"), [Fun([Var("x"), Var("j")], Prim("+", [Var("x"), Atom(1), Var("j")]))]) + c = C() + code = c.G(exp) + self.assertEqual(c.code(), "proc fun45(x) { return x + 1; }") + self.assertEqual(code, "return mkclosure(fun45);") + + def test_add_fn(self) -> None: + exp = parse(tokenize("x -> y -> x + y")) + exp = cps(exp, Var("k")) + exp = alphatise(exp) + exp = spin_opt(exp) + exp = make_closures_explicit(exp, {}) + c = C() + code = c.G(exp) + self.assertEqual( + c.code(), + """ + +""", + ) + self.assertEqual(code, "") + + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() + + c = Compiler(CFun("main")) + exp = parse(tokenize("x -> y -> x + y")) + print(exp) + cps_exp = cps(exp, Var("halt")) + print(cps_exp) + alphaed = alphatise(cps_exp) + optimized = spin_opt(alphaed) + print(optimized) + closurized = make_closures_explicit(optimized, {}) + print(closurized) + print(c.G(closurized)) + # c.compile(closurized) + # for fun in c.funs: + # for line in fun.code: + # print(line) + # print() From d0815d7e9fb1c056a905dae7ab8d1c1448912e1d Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 16:54:28 -0400 Subject: [PATCH 09/31] . --- cps.py | 106 +++++++++++++-------------------------------------------- 1 file changed, 23 insertions(+), 83 deletions(-) diff --git a/cps.py b/cps.py index 21f88f84..1fa3c771 100644 --- a/cps.py +++ b/cps.py @@ -608,70 +608,6 @@ class CFun: code: list[str] = dataclasses.field(default_factory=list) -class Compiler: - def __init__(self, main: CFun) -> None: - self.funs = [main] - self.fun: CFun = main - - def _emit(self, code: str) -> None: - self.fun.code.append(code) - - def _mktemp(self, expr: str) -> str: - name = gensym() - self._emit(f"struct object* {name} = {expr};") - return name - - def compile(self, exp: CPSExpr) -> str: - match exp: - case Var(name): - return name - case Atom(int(value)): - return self._mktemp(f"mknum({value})") - case Prim("clo", [Var("this"), Atom(idx)]): - return self._mktemp(f"closure_get(this, {idx})") - case Prim("+", [left, right, cont]): - left = self.compile(left) - right = self.compile(right) - result = self._mktemp(f"num_add({left}, {right})") - self._emit(f"return {cont}({result});") - return "" - # case Atom([]): - # return self._mktemp("empty_list()") - # case Atom(list(value_exprs)): - # values = [self.compile(value) for value in value_exprs] - # num_values = len(values) - # result = self._mktemp(f"list_cons({num_values})") - # for i, value in enumerate(values): - # self._emit(f"{result}[{i}] = {value};") - # return result - case App(Var("halt"), [arg]): - self._emit(f"return {self.compile(arg)};") - return "" - case App(fun, args): - assert isinstance(fun, Var), "((fun ...) ...) should be optimized out" - fun_name = fun.name - arg_names = [self.compile(arg) for arg in args] - result = self._mktemp(f"clo call {fun_name}({', '.join(arg_names)})") - self._emit(f"return {result};") - return "" - case Fun(_, _): - prev = self.fun - self.fun = CFun(exp.name()) - self.funs.append(self.fun) - result = self.compile_proc(exp) - self.fun = prev - return result - case _: - raise NotImplementedError(f"compile: {exp}") - - def compile_proc(self, exp: Fun) -> str: - args = [arg.name for arg in exp.args] - self._emit(f"object {exp.name()}({', '.join(args)}) {{") - self.compile(exp.body) - self._emit("}") - return exp.name() - - class C: def __init__(self) -> None: self.funs = [] @@ -730,6 +666,9 @@ def setUp(self) -> None: global cps_counter cps_counter = itertools.count() + global fun_counter + fun_counter = itertools.count() + def test_app_cont(self) -> None: # (E ... (fun (x) M1)) exp = App(Var("f"), [Atom(1), Fun([Var("x")], App(Var("k"), [Var("x")]))]) @@ -771,31 +710,32 @@ def test_proc(self) -> None: exp = App(Var("k"), [Fun([Var("x"), Var("j")], Prim("+", [Var("x"), Atom(1), Var("j")]))]) c = C() code = c.G(exp) - self.assertEqual(c.code(), "proc fun45(x) { return x + 1; }") - self.assertEqual(code, "return mkclosure(fun45);") - - def test_add_fn(self) -> None: - exp = parse(tokenize("x -> y -> x + y")) - exp = cps(exp, Var("k")) - exp = alphatise(exp) - exp = spin_opt(exp) - exp = make_closures_explicit(exp, {}) - c = C() - code = c.G(exp) - self.assertEqual( - c.code(), - """ - -""", - ) - self.assertEqual(code, "") + self.assertEqual(c.code(), "proc fun0(x) { return x + 1; }") + self.assertEqual(code, "return mkclosure(fun0);") + + +# def test_add_fn(self) -> None: +# exp = parse(tokenize("x -> y -> x + y")) +# exp = cps(exp, Var("k")) +# exp = alphatise(exp) +# exp = spin_opt(exp) +# exp = make_closures_explicit(exp, {}) +# c = C() +# code = c.G(exp) +# self.assertEqual( +# c.code(), +# """ +# +# """, +# ) +# self.assertEqual(code, "") if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() - c = Compiler(CFun("main")) + c = C(CFun("main")) exp = parse(tokenize("x -> y -> x + y")) print(exp) cps_exp = cps(exp, Var("halt")) From 5b0dc7be8cf8e28e77281c9cf780fa1d81b2ca60 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 16:54:50 -0400 Subject: [PATCH 10/31] . --- cps.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) diff --git a/cps.py b/cps.py index 1fa3c771..58bf511a 100644 --- a/cps.py +++ b/cps.py @@ -714,40 +714,6 @@ def test_proc(self) -> None: self.assertEqual(code, "return mkclosure(fun0);") -# def test_add_fn(self) -> None: -# exp = parse(tokenize("x -> y -> x + y")) -# exp = cps(exp, Var("k")) -# exp = alphatise(exp) -# exp = spin_opt(exp) -# exp = make_closures_explicit(exp, {}) -# c = C() -# code = c.G(exp) -# self.assertEqual( -# c.code(), -# """ -# -# """, -# ) -# self.assertEqual(code, "") - - if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() - - c = C(CFun("main")) - exp = parse(tokenize("x -> y -> x + y")) - print(exp) - cps_exp = cps(exp, Var("halt")) - print(cps_exp) - alphaed = alphatise(cps_exp) - optimized = spin_opt(alphaed) - print(optimized) - closurized = make_closures_explicit(optimized, {}) - print(closurized) - print(c.G(closurized)) - # c.compile(closurized) - # for fun in c.funs: - # for line in fun.code: - # print(line) - # print() From 605cf2e6ccdf77832413ca41114f39aee36582c6 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 16:57:46 -0400 Subject: [PATCH 11/31] . --- cps.py | 33 +++++++++++++++++++++++++++------ 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/cps.py b/cps.py index 58bf511a..6e0f7563 100644 --- a/cps.py +++ b/cps.py @@ -601,16 +601,37 @@ def test_app(self) -> None: # (f 42) self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) - -@dataclasses.dataclass -class CFun: - name: str - code: list[str] = dataclasses.field(default_factory=list) + def test_add_function(self) -> None: + exp = cps(parse(tokenize("x -> y -> x + y")), Var("k")) + exp = spin_opt(exp) + # (k (fun (this x v2) + # (v2 (fun (this y v3) + # ($+ ($clo this 0) y v3))))) + self.assertEqual( + make_closures_explicit(exp, {}), + App( + Var("k"), + [ + Fun( + [Var("this"), Var("x"), Var("v2")], + App( + Var("v2"), + [ + Fun( + [Var("this"), Var("y"), Var("v3")], + Prim("+", [Prim("clo", [Var("this"), Atom(0)]), Var("y"), Var("v3")]), + ) + ], + ), + ) + ], + ), + ) class C: def __init__(self) -> None: - self.funs = [] + self.funs: list[CPSExpr] = [] def G(self, exp: CPSExpr) -> str: match exp: From aaacefbdb532c3cfceb667ce0310f40fc38fe027 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 17:02:34 -0400 Subject: [PATCH 12/31] clean up types --- cps.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cps.py b/cps.py index 6e0f7563..c5d1fc3a 100644 --- a/cps.py +++ b/cps.py @@ -631,7 +631,7 @@ def test_add_function(self) -> None: class C: def __init__(self) -> None: - self.funs: list[CPSExpr] = [] + self.funs: list[str] = [] def G(self, exp: CPSExpr) -> str: match exp: @@ -641,6 +641,7 @@ def G(self, exp: CPSExpr) -> str: return name case App(k, [Fun(_, _)]): assert isinstance(k, Var) + assert isinstance(exp.args[0], Fun) fun, name = self.G_proc(exp.args[0]) self.funs.append(fun) return f"return mkclosure({name});" @@ -671,7 +672,7 @@ def G_cont(self, val: str, exp: CPSExpr) -> str: case _: raise NotImplementedError(f"G_cont: {exp}") - def G_proc(self, exp: Fun) -> str: + def G_proc(self, exp: Fun) -> tuple[str, str]: match exp: case Fun([*args, _], M1): return f"proc fun{exp.id}({', '.join(arg.name for arg in args)}) {{ {self.G(M1)} " + "}", f"fun{exp.id}" From 116428e07f479b335be57c8d0b6c9856b8c0fb4f Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 17:31:58 -0400 Subject: [PATCH 13/31] Start doing open/closed analysis --- cps.py | 477 ++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 285 insertions(+), 192 deletions(-) diff --git a/cps.py b/cps.py index c5d1fc3a..31a10022 100644 --- a/cps.py +++ b/cps.py @@ -20,7 +20,7 @@ @dataclasses.dataclass class CPSExpr: - annotations: dict[str, object] = dataclasses.field(default_factory=dict, init=False) + pass @dataclasses.dataclass @@ -55,11 +55,22 @@ def __repr__(self) -> str: class Fun(CPSExpr): args: list[Var] body: CPSExpr + annotations: dict[str, object] = dataclasses.field(default_factory=dict) id: int = dataclasses.field(default_factory=lambda: next(fun_counter), compare=False) def name(self) -> str: return f"fun{self.id}" + def freevars(self) -> set[str]: + result = self.annotations["freevars"] + assert isinstance(result, set) + return result + + def kind(self) -> str: + result = self.annotations["kind"] + assert isinstance(result, str) + return result + def __repr__(self) -> str: args = " ".join(map(repr, self.args)) return f"(fun ({args}) {self.body!r})" @@ -520,6 +531,27 @@ def free_in(exp: CPSExpr) -> set[str]: raise NotImplementedError(f"free_in: {exp}") +def annotate_free_in(exp: CPSExpr) -> None: + match exp: + case Atom(_): + return + case Var(_): + return + case Prim(_, args): + for arg in args: + annotate_free_in(arg) + case Fun(args, body): + freevars = free_in(exp) + exp.annotations["freevars"] = freevars + for arg in args: + annotate_free_in(arg) + annotate_free_in(body) + case App(fun, args): + for arg in args: + annotate_free_in(arg) + annotate_free_in(fun) + + class FreeInTests(unittest.TestCase): def test_atom(self) -> None: self.assertEqual(free_in(Atom(42)), set()) @@ -535,205 +567,266 @@ def test_fun(self) -> None: exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) self.assertEqual(free_in(exp), {"z"}) + def test_fun_annotate(self) -> None: + exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) + annotate_free_in(exp) + self.assertEqual(exp.freevars(), {"z"}) + def test_app(self) -> None: exp = App(Var("f"), [Var("x"), Var("y")]) self.assertEqual(free_in(exp), {"f", "x", "y"}) -def make_closures_explicit(exp: CPSExpr, replacements: dict[str, CPSExpr]) -> CPSExpr: - def rec(exp: CPSExpr) -> CPSExpr: - return make_closures_explicit(exp, replacements) - +def classify_lambdas(exp: CPSExpr) -> None: match exp: case Atom(_): - return exp - case Var(name): - if name in replacements: - return replacements[name] - return exp - case Prim(op, args): - return Prim(op, [rec(arg) for arg in args]) - case Fun(args, body): - freevars = sorted(free_in(exp)) - this = Var("this") - new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} - body = make_closures_explicit(body, {**replacements, **new_replacements}) - return Fun([this] + args, body) - case App(fun, args): - return App(rec(fun), [rec(arg) for arg in args]) - raise NotImplementedError(f"make_closures_explicit: {exp}") - - -class ClosureTests(unittest.TestCase): - def test_no_freevars(self) -> None: + return + case Var(_): + return + case App(Fun(_, body) as lam, args): + lam.annotations["kind"] = "open" + classify_lambdas(body) + for arg in args: + classify_lambdas(arg) + case Prim(_, [*args, Fun(_, _) as lam]): + lam.annotations["kind"] = "open" + for arg in args: + classify_lambdas(arg) + case App(f, args): + classify_lambdas(f) + for arg in args: + classify_lambdas(arg) + case Fun(_, body) as lam: + lam.annotations["kind"] = "closed" + classify_lambdas(body) + case Prim(_, args): + for arg in args: + classify_lambdas(arg) + case _: + raise NotImplementedError(f"classify_lambdas: {exp}") + + +class ClassificationTests(unittest.TestCase): + def test_open(self) -> None: + lam = Fun([Var("x")], Var("x")) + exp = App(lam, [Atom(42)]) + classify_lambdas(exp) + self.assertEqual(lam.kind(), "open") + + def test_open_prim(self) -> None: + lam = Fun([Var("x")], Var("x")) + exp = Prim("+", [Var("x"), Var("y"), lam]) + classify_lambdas(exp) + self.assertEqual(lam.kind(), "open") + + def test_closed_arg(self) -> None: + lam = Fun([Var("x")], Var("x")) + exp = App(Var("f"), [lam]) + classify_lambdas(exp) + self.assertEqual(lam.kind(), "closed") + + def test_closed(self) -> None: exp = Fun([Var("x")], Var("x")) - # (fun (this x) x) - self.assertEqual(make_closures_explicit(exp, {}), Fun([Var("this"), Var("x")], Var("x"))) - - def test_freevars(self) -> None: - exp = Fun([Var("k")], Prim("+", [Var("x"), Var("y"), Var("k")])) - # (fun (this k) ($+ ($clo this 0) ($clo this 1) k)) - self.assertEqual( - make_closures_explicit(exp, {}), - Fun( - [Var("this"), Var("k")], - Prim( - "+", - [ - Prim("clo", [Var("this"), Atom(0)]), - Prim("clo", [Var("this"), Atom(1)]), - Var("k"), - ], - ), - ), - ) - - def test_app_fun(self) -> None: - exp = App(Fun([Var("x")], Var("x")), [Atom(42)]) - # ((fun (this x) x) 42) - self.assertEqual( - make_closures_explicit(exp, {}), - App(Fun([Var("this"), Var("x")], Var("x")), [Atom(42)]), - ) - - def test_app(self) -> None: - exp = App(Var("f"), [Atom(42)]) - # (f 42) - self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) - - def test_add_function(self) -> None: - exp = cps(parse(tokenize("x -> y -> x + y")), Var("k")) - exp = spin_opt(exp) - # (k (fun (this x v2) - # (v2 (fun (this y v3) - # ($+ ($clo this 0) y v3))))) - self.assertEqual( - make_closures_explicit(exp, {}), - App( - Var("k"), - [ - Fun( - [Var("this"), Var("x"), Var("v2")], - App( - Var("v2"), - [ - Fun( - [Var("this"), Var("y"), Var("v3")], - Prim("+", [Prim("clo", [Var("this"), Atom(0)]), Var("y"), Var("v3")]), - ) - ], - ), - ) - ], - ), - ) - - -class C: - def __init__(self) -> None: - self.funs: list[str] = [] - - def G(self, exp: CPSExpr) -> str: - match exp: - case Atom(int(value)): - return str(value) - case Var(name): - return name - case App(k, [Fun(_, _)]): - assert isinstance(k, Var) - assert isinstance(exp.args[0], Fun) - fun, name = self.G_proc(exp.args[0]) - self.funs.append(fun) - return f"return mkclosure({name});" - case App(k, [E]): - assert is_simple(E) - return f"return {E};" - case App(E, [*args, k]): - assert isinstance(E, Var) - assert all(is_simple(arg) for arg in args) - return self.G_cont(f"{E.name}({', '.join(str(arg) for arg in args)})", k) - case Prim("+", [x, y, k]): - assert is_simple(x) - assert is_simple(y) - return self.G_cont(f"{x} + {y}", k) - # TODO(max): j case - # TODO(max): Split cont and fun or annotate - case Prim("if", [cond, tk, fk]): - return f"if ({cond}) {{ {self.G(tk)} }} else {{ {self.G(fk)} }}" - case _: - raise NotImplementedError(f"G: {exp}") - - def G_cont(self, val: str, exp: CPSExpr) -> str: - match exp: - case Fun([res], M1): - return f"{res} <- {val}; {self.G(M1)}" - case Var(_): - return f"return {val};" - case _: - raise NotImplementedError(f"G_cont: {exp}") - - def G_proc(self, exp: Fun) -> tuple[str, str]: - match exp: - case Fun([*args, _], M1): - return f"proc fun{exp.id}({', '.join(arg.name for arg in args)}) {{ {self.G(M1)} " + "}", f"fun{exp.id}" - case _: - raise NotImplementedError(f"G_proc: {exp}") - - def code(self) -> str: - return "\n\n".join(self.funs) - - -class GTests(unittest.TestCase): - def setUp(self) -> None: - global cps_counter - cps_counter = itertools.count() - - global fun_counter - fun_counter = itertools.count() - - def test_app_cont(self) -> None: - # (E ... (fun (x) M1)) - exp = App(Var("f"), [Atom(1), Fun([Var("x")], App(Var("k"), [Var("x")]))]) - self.assertEqual(C().G(exp), "x <- f(1); return x;") - - def test_tailcall(self) -> None: - # (E ... k) - exp = App(Var("f"), [Atom(1), Var("k")]) - self.assertEqual(C().G(exp), "return f(1);") - - def test_return(self) -> None: - # (k E) - exp = App(Var("k"), [Atom(1)]) - self.assertEqual(C().G(exp), "return 1;") - - def test_if(self) -> None: - # ($if cond t f) - exp = Prim( - "if", - [ - Atom(1), - App(Var("k"), [Atom(2)]), - App(Var("k"), [Atom(3)]), - ], - ) - self.assertEqual(C().G(exp), "if (1) { return 2; } else { return 3; }") - - def test_add_cont(self) -> None: - # ($+ x y (fun (res) M1)) - exp = Prim("+", [Atom(1), Atom(2), Fun([Var("res")], App(Var("k"), [Var("res")]))]) - self.assertEqual(C().G(exp), "res <- 1 + 2; return res;") - - def test_add_cont_var(self) -> None: - # ($+ x y k) - exp = Prim("+", [Atom(1), Atom(2), Var("k")]) - self.assertEqual(C().G(exp), "return 1 + 2;") - - def test_proc(self) -> None: - exp = App(Var("k"), [Fun([Var("x"), Var("j")], Prim("+", [Var("x"), Atom(1), Var("j")]))]) - c = C() - code = c.G(exp) - self.assertEqual(c.code(), "proc fun0(x) { return x + 1; }") - self.assertEqual(code, "return mkclosure(fun0);") + classify_lambdas(exp) + self.assertEqual(exp.kind(), "closed") + + +# +# +# def make_closures_explicit(exp: CPSExpr, replacements: dict[str, CPSExpr]) -> CPSExpr: +# def rec(exp: CPSExpr) -> CPSExpr: +# return make_closures_explicit(exp, replacements) +# +# match exp: +# case Atom(_): +# return exp +# case Var(name): +# if name in replacements: +# return replacements[name] +# return exp +# case Prim(op, args): +# return Prim(op, [rec(arg) for arg in args]) +# case Fun(args, body): +# freevars = sorted(free_in(exp)) +# this = Var("this") +# new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} +# body = make_closures_explicit(body, {**replacements, **new_replacements}) +# return Fun([this] + args, body) +# case App(fun, args): +# return App(rec(fun), [rec(arg) for arg in args]) +# raise NotImplementedError(f"make_closures_explicit: {exp}") +# +# +# class ClosureTests(unittest.TestCase): +# def test_no_freevars(self) -> None: +# exp = Fun([Var("x")], Var("x")) +# # (fun (this x) x) +# self.assertEqual(make_closures_explicit(exp, {}), Fun([Var("this"), Var("x")], Var("x"))) +# +# def test_freevars(self) -> None: +# exp = Fun([Var("k")], Prim("+", [Var("x"), Var("y"), Var("k")])) +# # (fun (this k) ($+ ($clo this 0) ($clo this 1) k)) +# self.assertEqual( +# make_closures_explicit(exp, {}), +# Fun( +# [Var("this"), Var("k")], +# Prim( +# "+", +# [ +# Prim("clo", [Var("this"), Atom(0)]), +# Prim("clo", [Var("this"), Atom(1)]), +# Var("k"), +# ], +# ), +# ), +# ) +# +# def test_app_fun(self) -> None: +# exp = App(Fun([Var("x")], Var("x")), [Atom(42)]) +# # ((fun (this x) x) 42) +# self.assertEqual( +# make_closures_explicit(exp, {}), +# App(Fun([Var("this"), Var("x")], Var("x")), [Atom(42)]), +# ) +# +# def test_app(self) -> None: +# exp = App(Var("f"), [Atom(42)]) +# # (f 42) +# self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) +# +# def test_add_function(self) -> None: +# exp = cps(parse(tokenize("x -> y -> x + y")), Var("k")) +# exp = spin_opt(exp) +# # (k (fun (this x v2) +# # (v2 (fun (this y v3) +# # ($+ ($clo this 0) y v3))))) +# self.assertEqual( +# make_closures_explicit(exp, {}), +# App( +# Var("k"), +# [ +# Fun( +# [Var("this"), Var("x"), Var("v2")], +# App( +# Var("v2"), +# [ +# Fun( +# [Var("this"), Var("y"), Var("v3")], +# Prim("+", [Prim("clo", [Var("this"), Atom(0)]), Var("y"), Var("v3")]), +# ) +# ], +# ), +# ) +# ], +# ), +# ) +# +# +# class C: +# def __init__(self) -> None: +# self.funs: list[str] = [] +# +# def G(self, exp: CPSExpr) -> str: +# match exp: +# case Atom(int(value)): +# return str(value) +# case Var(name): +# return name +# case App(k, [Fun(_, _)]): +# assert isinstance(k, Var) +# assert isinstance(exp.args[0], Fun) +# fun, name = self.G_proc(exp.args[0]) +# self.funs.append(fun) +# return f"return mkclosure({name});" +# case App(k, [E]): +# assert is_simple(E) +# return f"return {E};" +# case App(E, [*args, k]): +# assert isinstance(E, Var) +# assert all(is_simple(arg) for arg in args) +# return self.G_cont(f"{E.name}({', '.join(str(arg) for arg in args)})", k) +# case Prim("+", [x, y, k]): +# assert is_simple(x) +# assert is_simple(y) +# return self.G_cont(f"{x} + {y}", k) +# # TODO(max): j case +# # TODO(max): Split cont and fun or annotate +# case Prim("if", [cond, tk, fk]): +# return f"if ({cond}) {{ {self.G(tk)} }} else {{ {self.G(fk)} }}" +# case _: +# raise NotImplementedError(f"G: {exp}") +# +# def G_cont(self, val: str, exp: CPSExpr) -> str: +# match exp: +# case Fun([res], M1): +# return f"{res} <- {val}; {self.G(M1)}" +# case Var(_): +# return f"return {val};" +# case _: +# raise NotImplementedError(f"G_cont: {exp}") +# +# def G_proc(self, exp: Fun) -> tuple[str, str]: +# match exp: +# case Fun([*args, _], M1): +# return f"proc fun{exp.id}({', '.join(arg.name for arg in args)}) {{ {self.G(M1)} " + "}", f"fun{exp.id}" +# case _: +# raise NotImplementedError(f"G_proc: {exp}") +# +# def code(self) -> str: +# return "\n\n".join(self.funs) +# +# +# class GTests(unittest.TestCase): +# def setUp(self) -> None: +# global cps_counter +# cps_counter = itertools.count() +# +# global fun_counter +# fun_counter = itertools.count() +# +# def test_app_cont(self) -> None: +# # (E ... (fun (x) M1)) +# exp = App(Var("f"), [Atom(1), Fun([Var("x")], App(Var("k"), [Var("x")]))]) +# self.assertEqual(C().G(exp), "x <- f(1); return x;") +# +# def test_tailcall(self) -> None: +# # (E ... k) +# exp = App(Var("f"), [Atom(1), Var("k")]) +# self.assertEqual(C().G(exp), "return f(1);") +# +# def test_return(self) -> None: +# # (k E) +# exp = App(Var("k"), [Atom(1)]) +# self.assertEqual(C().G(exp), "return 1;") +# +# def test_if(self) -> None: +# # ($if cond t f) +# exp = Prim( +# "if", +# [ +# Atom(1), +# App(Var("k"), [Atom(2)]), +# App(Var("k"), [Atom(3)]), +# ], +# ) +# self.assertEqual(C().G(exp), "if (1) { return 2; } else { return 3; }") +# +# def test_add_cont(self) -> None: +# # ($+ x y (fun (res) M1)) +# exp = Prim("+", [Atom(1), Atom(2), Fun([Var("res")], App(Var("k"), [Var("res")]))]) +# self.assertEqual(C().G(exp), "res <- 1 + 2; return res;") +# +# def test_add_cont_var(self) -> None: +# # ($+ x y k) +# exp = Prim("+", [Atom(1), Atom(2), Var("k")]) +# self.assertEqual(C().G(exp), "return 1 + 2;") +# +# def test_proc(self) -> None: +# exp = App(Var("k"), [Fun([Var("x"), Var("j")], Prim("+", [Var("x"), Atom(1), Var("j")]))]) +# c = C() +# code = c.G(exp) +# self.assertEqual(c.code(), "proc fun0(x) { return x + 1; }") +# self.assertEqual(code, "return mkclosure(fun0);") if __name__ == "__main__": From 85d34984d5d3d55b4266d239090e6439e0de0138 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:05:24 -0400 Subject: [PATCH 14/31] Add census --- cps.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/cps.py b/cps.py index 31a10022..9d1140b7 100644 --- a/cps.py +++ b/cps.py @@ -1,6 +1,7 @@ import dataclasses import itertools import unittest +from collections import Counter from scrapscript import ( parse, tokenize, @@ -319,6 +320,40 @@ def is_simple(exp: CPSExpr) -> bool: return isinstance(exp, (Atom, Var, Fun)) or (isinstance(exp, Prim) and exp.op in {"clo", "tag"}) +def census(exp: CPSExpr) -> Counter[str]: + if isinstance(exp, Atom): + return Counter() + if isinstance(exp, Var): + return Counter({exp.name: 1}) + if isinstance(exp, Prim): + return sum((census(arg) for arg in exp.args), Counter()) + if isinstance(exp, Fun): + return census(exp.body) + if isinstance(exp, App): + return sum((census(arg) for arg in exp.args), census(exp.fun)) + raise NotImplementedError(f"census: {exp}") + + +class CensusTests(unittest.TestCase): + def test_atom(self) -> None: + self.assertEqual(census(Atom(42)), {}) + + def test_var(self) -> None: + self.assertEqual(census(Var("x")), {"x": 1}) + + def test_prim(self) -> None: + exp = Prim("+", [Var("x"), Var("y"), Var("x")]) + self.assertEqual(census(exp), {"x": 2, "y": 1}) + + def test_fun(self) -> None: + exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("x")])) + self.assertEqual(census(exp), {"x": 2, "y": 1}) + + def test_app(self) -> None: + exp = App(Var("f"), [Var("x"), Var("y")]) + self.assertEqual(census(exp), {"f": 1, "x": 1, "y": 1}) + + def opt(exp: CPSExpr) -> CPSExpr: if isinstance(exp, Atom): return exp From 9b45b5cd178f3c5152a0df083437662b3a268632 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:05:32 -0400 Subject: [PATCH 15/31] Print annotations in function --- cps.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cps.py b/cps.py index 9d1140b7..14ed703b 100644 --- a/cps.py +++ b/cps.py @@ -74,7 +74,8 @@ def kind(self) -> str: def __repr__(self) -> str: args = " ".join(map(repr, self.args)) - return f"(fun ({args}) {self.body!r})" + annotations = f" {self.annotations}" if self.annotations else "" + return f"(fun ({args}){annotations} {self.body!r})" @dataclasses.dataclass From 69c91e511f4acb7183b93768e9243414ea2f13da Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:05:56 -0400 Subject: [PATCH 16/31] Only substitute small or used-once terms --- cps.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/cps.py b/cps.py index 14ed703b..bc81c321 100644 --- a/cps.py +++ b/cps.py @@ -321,6 +321,10 @@ def is_simple(exp: CPSExpr) -> bool: return isinstance(exp, (Atom, Var, Fun)) or (isinstance(exp, Prim) and exp.op in {"clo", "tag"}) +def is_small(exp: CPSExpr) -> bool: + return isinstance(exp, (Atom, Var)) + + def census(exp: CPSExpr) -> Counter[str]: if isinstance(exp, Atom): return Counter() @@ -384,7 +388,10 @@ def opt(exp: CPSExpr) -> CPSExpr: actuals = [opt(arg) for arg in exp.args] if len(formals) != len(actuals): return App(fun, actuals) - if all(is_simple(arg) for arg in actuals): + cen = census(fun.body) + # Idea: only substitute if the substituting would not blow up the size + # of the expression + if all(cen[arg_name(formal)] < 2 or is_small(actual) for formal, actual in zip(formals, actuals)): new_env = {arg_name(formal): actual for formal, actual in zip(formals, actuals)} return subst(fun.body, new_env) return App(fun, actuals) @@ -551,6 +558,28 @@ def test_variant(self) -> None: Prim("tag", [Atom("a_tag"), Atom(123), Var("k")]), ) + def test_beta_reduce_fun_with_zero_uses(self) -> None: + exp = App(Fun([Var("x")], Atom(1)), [Fun([Var("y")], Var("y"))]) + self.assertEqual( + spin_opt(exp), + Atom(1), + ) + + def test_beta_reduce_fun_with_one_use(self) -> None: + exp = App(Fun([Var("x")], Var("x")), [Fun([Var("y")], Var("y"))]) + self.assertEqual( + spin_opt(exp), + Fun([Var("y")], Var("y")), + ) + + def test_does_not_beta_reduce_fun_with_two_uses(self) -> None: + exp = App(Fun([Var("x")], Prim("+", [Var("x"), Var("x"), Var("k")])), [Fun([Var("y")], Var("y"))]) + self.assertEqual( + spin_opt(exp), + # ((fun (x) ($+ x x k)) (fun (y) y)) + App(Fun([Var("x")], Prim("+", [Var("x"), Var("x"), Var("k")])), [Fun([Var("y")], Var("y"))]), + ) + def free_in(exp: CPSExpr) -> set[str]: match exp: From 3b1d0ba972ce6e4f8d359c326e89336ea23a9830 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:08:57 -0400 Subject: [PATCH 17/31] Add TODO --- cps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cps.py b/cps.py index bc81c321..e9bed98c 100644 --- a/cps.py +++ b/cps.py @@ -267,6 +267,7 @@ def test_app(self) -> None: self.assertEqual(alphatise_(exp, {"x": "v0", "y": "v1"}), App(Var("f"), [Var("v0"), Var("v1")])) +# TODO(max): Freshen substituted terms in binders def subst(exp: CPSExpr, env: dict[str, CPSExpr]) -> CPSExpr: if isinstance(exp, Atom): return exp From a715619973a1ae4754e62fb9f1913618bf0126c7 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:10:14 -0400 Subject: [PATCH 18/31] Add TODO --- cps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cps.py b/cps.py index e9bed98c..af79e8d1 100644 --- a/cps.py +++ b/cps.py @@ -392,6 +392,7 @@ def opt(exp: CPSExpr) -> CPSExpr: cen = census(fun.body) # Idea: only substitute if the substituting would not blow up the size # of the expression + # TODO(max): Partial substitution for parameters that pass the guard if all(cen[arg_name(formal)] < 2 or is_small(actual) for formal, actual in zip(formals, actuals)): new_env = {arg_name(formal): actual for formal, actual in zip(formals, actuals)} return subst(fun.body, new_env) From 69c8cc0129787dc23d69f9d4009a92180bfa8a62 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:14:20 -0400 Subject: [PATCH 19/31] Add TODO --- cps.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cps.py b/cps.py index af79e8d1..cefc7ce9 100644 --- a/cps.py +++ b/cps.py @@ -645,6 +645,7 @@ def test_app(self) -> None: def classify_lambdas(exp: CPSExpr) -> None: + # TODO(max): Find first-order lambdas match exp: case Atom(_): return From 966cbfda419a8d53e9c386c9a8c2db59569f067b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:22:40 -0400 Subject: [PATCH 20/31] Oops that only applies to + --- cps.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cps.py b/cps.py index cefc7ce9..f16a1b38 100644 --- a/cps.py +++ b/cps.py @@ -379,8 +379,8 @@ def opt(exp: CPSExpr) -> CPSExpr: # TODO(max): Only sum ints consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars - if len(args) == 1: - return App(cont, args) + if len(args) == 1: + return App(cont, args) return Prim(exp.op, args + [cont]) if isinstance(exp, App) and isinstance(exp.fun, Fun): fun = opt(exp.fun) From f8166ed549c158c61745f5aba4e555f63c699262 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:24:00 -0400 Subject: [PATCH 21/31] Split cases --- cps.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cps.py b/cps.py index f16a1b38..05747c74 100644 --- a/cps.py +++ b/cps.py @@ -373,14 +373,14 @@ def opt(exp: CPSExpr) -> CPSExpr: if all(isinstance(arg, Atom) for arg in args): return App(cont, [Atom(args)]) if exp.op == "+": + if len(args) == 1: + return App(cont, args) consts = [arg for arg in args if isinstance(arg, Atom)] vars = [arg for arg in args if not isinstance(arg, Atom)] if consts: # TODO(max): Only sum ints consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars - if len(args) == 1: - return App(cont, args) return Prim(exp.op, args + [cont]) if isinstance(exp, App) and isinstance(exp.fun, Fun): fun = opt(exp.fun) @@ -422,7 +422,11 @@ def setUp(self) -> None: def test_prim(self) -> None: exp = Prim("+", [Atom(1), Atom(2), Atom(3), Var("k")]) - self.assertEqual(opt(exp), App(Var("k"), [Atom(6)])) + self.assertEqual(opt(exp), Prim("+", [Atom(6), Var("k")])) + + def test_prim_spin(self) -> None: + exp = Prim("+", [Atom(1), Atom(2), Atom(3), Var("k")]) + self.assertEqual(spin_opt(exp), App(Var("k"), [Atom(6)])) def test_prim_var(self) -> None: exp = Prim("+", [Atom(1), Var("x"), Atom(3), Var("k")]) From fd25166999963602d080e0c3225c8e7d9da7c7b4 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:31:55 -0400 Subject: [PATCH 22/31] Factor CPS --- cps.py | 70 ++++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/cps.py b/cps.py index 05747c74..32976fe1 100644 --- a/cps.py +++ b/cps.py @@ -78,6 +78,19 @@ def __repr__(self) -> str: return f"(fun ({args}){annotations} {self.body!r})" +@dataclasses.dataclass +class Cont(CPSExpr): + args: list[Var] + body: CPSExpr + annotations: dict[str, object] = dataclasses.field(default_factory=dict) + id: int = dataclasses.field(default_factory=lambda: next(fun_counter), compare=False) + + def __repr__(self) -> str: + args = " ".join(map(repr, self.args)) + annotations = f" {self.annotations}" if self.annotations else "" + return f"(cont ({args}){annotations} {self.body!r})" + + @dataclasses.dataclass class App(CPSExpr): fun: CPSExpr @@ -95,7 +108,7 @@ def gensym() -> str: def cont(arg: Var, body: CPSExpr) -> CPSExpr: - return Fun([arg], body) + return Cont([arg], body) def cps(exp: Object, k: CPSExpr) -> CPSExpr: @@ -150,12 +163,12 @@ def test_var(self) -> None: def test_binop(self) -> None: self.assertEqual( cps(parse(tokenize("1 + 2")), Var("k")), - # ((fun (v0) ((fun (v1) (+ v0 v1 k)) 2)) 1) + # ((cont (v0) ((cont (v1) (+ v0 v1 k)) 2)) 1) App( - Fun( + Cont( [Var("v0")], App( - Fun( + Cont( [Var("v1")], Prim("+", [Var("v0"), Var("v1"), Var("k")]), ), @@ -170,18 +183,18 @@ def test_where(self) -> None: exp = parse(tokenize("a + b . a = 1 . b = 2")) self.assertEqual( cps(exp, Var("k")), - # ((fun (b) ((fun (a) ((fun (v0) ((fun (v1) (+ v0 v1 k)) b)) a)) 1)) 2) + # ((cont (b) ((cont (a) ((cont (v0) ((cont (v1) (+ v0 v1 k)) b)) a)) 1)) 2) App( - Fun( + Cont( [Var("b")], App( - Fun( + Cont( [Var("a")], App( - Fun( + Cont( [Var("v0")], App( - Fun( + Cont( [Var("v1")], Prim("+", [Var("v0"), Var("v1"), Var("k")]), ), @@ -204,8 +217,8 @@ def test_empty_list(self) -> None: def test_variant(self) -> None: self.assertEqual( cps(parse(tokenize("# a_tag 123")), Var("k")), - # ((fun (v0) ($tag 'a_tag' v0 k)) 123) - App(Fun([Var("v0")], Prim("tag", [Atom("a_tag"), Var("v0"), Var("k")])), [Atom(123)]), + # ((cont (v0) ($tag 'a_tag' v0 k)) 123) + App(Cont([Var("v0")], Prim("tag", [Atom("a_tag"), Var("v0"), Var("k")])), [Atom(123)]), ) @@ -221,10 +234,11 @@ def alphatise_(exp: CPSExpr, env: dict[str, str]) -> CPSExpr: return Var(env.get(exp.name, exp.name)) if isinstance(exp, Prim): return Prim(exp.op, [alphatise_(arg, env) for arg in exp.args]) - if isinstance(exp, Fun): + if isinstance(exp, (Fun, Cont)): + ty = type(exp) new_env = {arg_name(arg): gensym() for arg in exp.args} new_body = alphatise_(exp.body, {**env, **new_env}) - return Fun([Var(new_env[arg_name(arg)]) for arg in exp.args], new_body) + return ty([Var(new_env[arg_name(arg)]) for arg in exp.args], new_body) if isinstance(exp, App): return App(alphatise_(exp.fun, env), [alphatise_(arg, env) for arg in exp.args]) raise NotImplementedError(f"alphatise: {exp}") @@ -262,6 +276,16 @@ def test_fun(self) -> None: ), ) + def test_cont(self) -> None: + exp = Cont([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) + self.assertEqual( + alphatise(exp), + Cont( + [Var("v0"), Var("v1")], + Prim("+", [Var("v0"), Var("v1"), Var("z")]), + ), + ) + def test_app(self) -> None: exp = App(Var("f"), [Var("x"), Var("y")]) self.assertEqual(alphatise_(exp, {"x": "v0", "y": "v1"}), App(Var("f"), [Var("v0"), Var("v1")])) @@ -275,10 +299,11 @@ def subst(exp: CPSExpr, env: dict[str, CPSExpr]) -> CPSExpr: return env.get(exp.name, exp) if isinstance(exp, Prim): return Prim(exp.op, [subst(arg, env) for arg in exp.args]) - if isinstance(exp, Fun): + if isinstance(exp, (Fun, Cont)): + ty = type(exp) new_env = {arg_name(arg): Var(gensym()) for arg in exp.args} new_body = subst(exp.body, {**env, **new_env}) - return Fun([Var(new_env[arg_name(arg)].name) for arg in exp.args], new_body) + return ty([Var(new_env[arg_name(arg)].name) for arg in exp.args], new_body) if isinstance(exp, App): return App(subst(exp.fun, env), [subst(arg, env) for arg in exp.args]) raise NotImplementedError(f"subst: {exp}") @@ -333,7 +358,7 @@ def census(exp: CPSExpr) -> Counter[str]: return Counter({exp.name: 1}) if isinstance(exp, Prim): return sum((census(arg) for arg in exp.args), Counter()) - if isinstance(exp, Fun): + if isinstance(exp, (Fun, Cont)): return census(exp.body) if isinstance(exp, App): return sum((census(arg) for arg in exp.args), census(exp.fun)) @@ -382,9 +407,9 @@ def opt(exp: CPSExpr) -> CPSExpr: consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars return Prim(exp.op, args + [cont]) - if isinstance(exp, App) and isinstance(exp.fun, Fun): + if isinstance(exp, App) and isinstance(exp.fun, (Fun, Cont)): fun = opt(exp.fun) - assert isinstance(fun, Fun) + assert isinstance(fun, (Fun, Cont)) formals = exp.fun.args actuals = [opt(arg) for arg in exp.args] if len(formals) != len(actuals): @@ -401,9 +426,10 @@ def opt(exp: CPSExpr) -> CPSExpr: fun = opt(exp.fun) args = [opt(arg) for arg in exp.args] return App(fun, args) - if isinstance(exp, Fun): + if isinstance(exp, (Fun, Cont)): body = opt(exp.body) return Fun(exp.args, body) + raise NotImplementedError(f"opt: {exp}") return exp @@ -440,7 +466,7 @@ def test_add(self) -> None: exp = parse(tokenize("1 + 2 + c")) self.assertEqual( spin_opt(cps(exp, Var("k"))), - Prim("+", [Atom(2), Var("c"), Fun([Var("v6")], Prim("+", [Atom(1), Var("v6"), Var("k")]))]), + Prim("+", [Atom(2), Var("c"), Cont([Var("v6")], Prim("+", [Atom(1), Var("v6"), Var("k")]))]), ) def test_simple_fun(self) -> None: @@ -545,13 +571,13 @@ def test_make_list(self) -> None: exp = parse(tokenize("[1+2, x, 3+4]")) self.assertEqual( spin_opt(cps(exp, Var("k"))), - # ($cons x [7, []] (fun (v46) ($cons 3 v46 k))) + # ($cons x [7, []] (cont (v46) ($cons 3 v46 k))) Prim( "cons", [ Var("x"), Atom([Atom(7), Atom([])]), - Fun([Var("v46")], Prim("cons", [Atom(3), Var("v46"), Var("k")])), + Cont([Var("v46")], Prim("cons", [Atom(3), Var("v46"), Var("k")])), ], ), ) From fe49353134fac1c93c4115f4c99ae588efb70983 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:43:44 -0400 Subject: [PATCH 23/31] Mul primitive --- cps.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cps.py b/cps.py index 32976fe1..b223a97d 100644 --- a/cps.py +++ b/cps.py @@ -1,6 +1,8 @@ import dataclasses import itertools import unittest +import operator +from functools import reduce from collections import Counter from scrapscript import ( parse, @@ -406,6 +408,14 @@ def opt(exp: CPSExpr) -> CPSExpr: # TODO(max): Only sum ints consts = [Atom(sum(c.value for c in consts))] # type: ignore args = consts + vars + if exp.op == "*": + if len(args) == 1: + return App(cont, args) + consts = [arg for arg in args if isinstance(arg, Atom)] + vars = [arg for arg in args if not isinstance(arg, Atom)] + if consts: + consts = [Atom(reduce(operator.mul, (c.value for c in consts), 1))] + args = consts + vars return Prim(exp.op, args + [cont]) if isinstance(exp, App) and isinstance(exp.fun, (Fun, Cont)): fun = opt(exp.fun) @@ -454,6 +464,10 @@ def test_prim_spin(self) -> None: exp = Prim("+", [Atom(1), Atom(2), Atom(3), Var("k")]) self.assertEqual(spin_opt(exp), App(Var("k"), [Atom(6)])) + def test_prim_mul_spin(self) -> None: + exp = Prim("*", [Atom(2), Atom(3), Atom(4), Var("k")]) + self.assertEqual(spin_opt(exp), App(Var("k"), [Atom(24)])) + def test_prim_var(self) -> None: exp = Prim("+", [Atom(1), Var("x"), Atom(3), Var("k")]) self.assertEqual(opt(exp), Prim("+", [Atom(4), Var("x"), Var("k")])) From 6c277f1844d72502f0657683c9a803ad093c1036 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:52:55 -0400 Subject: [PATCH 24/31] Simplify pattern matching and fix opt --- cps.py | 53 ++++++++++++++++++++--------------------------------- 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/cps.py b/cps.py index b223a97d..34796ce0 100644 --- a/cps.py +++ b/cps.py @@ -1,8 +1,6 @@ import dataclasses import itertools import unittest -import operator -from functools import reduce from collections import Counter from scrapscript import ( parse, @@ -392,31 +390,23 @@ def opt(exp: CPSExpr) -> CPSExpr: return exp if isinstance(exp, Var): return exp - if isinstance(exp, Prim): - args = [opt(arg) for arg in exp.args[:-1]] - cont = exp.args[-1] - if exp.op == "cons": - assert len(args) == 2 - if all(isinstance(arg, Atom) for arg in args): - return App(cont, [Atom(args)]) - if exp.op == "+": - if len(args) == 1: - return App(cont, args) - consts = [arg for arg in args if isinstance(arg, Atom)] - vars = [arg for arg in args if not isinstance(arg, Atom)] - if consts: - # TODO(max): Only sum ints - consts = [Atom(sum(c.value for c in consts))] # type: ignore - args = consts + vars - if exp.op == "*": - if len(args) == 1: - return App(cont, args) + match exp: + # TODO(max): Only sum/multiply ints + case Prim("+" | "*", [Atom(int(x)), k]): + return App(k, [Atom(x)]) + case Prim("+", [Atom(int(x)), Atom(int(y)), *args]): + return Prim("+", [Atom(x + y), *args]) + case Prim("*", [Atom(int(x)), Atom(int(y)), *args]): + return Prim("*", [Atom(x * y), *args]) + case Prim("+" | "*" as op, args): + # Move constants left consts = [arg for arg in args if isinstance(arg, Atom)] vars = [arg for arg in args if not isinstance(arg, Atom)] - if consts: - consts = [Atom(reduce(operator.mul, (c.value for c in consts), 1))] - args = consts + vars - return Prim(exp.op, args + [cont]) + return Prim(op, consts + vars) + case Prim("cons", [Atom(_) as x, Atom(_) as y, k]): + return App(k, [Atom([x, y])]) + case Prim(op, args): + return Prim(op, [opt(arg) for arg in args]) if isinstance(exp, App) and isinstance(exp.fun, (Fun, Cont)): fun = opt(exp.fun) assert isinstance(fun, (Fun, Cont)) @@ -437,8 +427,9 @@ def opt(exp: CPSExpr) -> CPSExpr: args = [opt(arg) for arg in exp.args] return App(fun, args) if isinstance(exp, (Fun, Cont)): + ty = type(exp) body = opt(exp.body) - return Fun(exp.args, body) + return ty(exp.args, body) raise NotImplementedError(f"opt: {exp}") return exp @@ -456,21 +447,17 @@ def setUp(self) -> None: global cps_counter cps_counter = itertools.count() - def test_prim(self) -> None: - exp = Prim("+", [Atom(1), Atom(2), Atom(3), Var("k")]) - self.assertEqual(opt(exp), Prim("+", [Atom(6), Var("k")])) - - def test_prim_spin(self) -> None: + def test_prim_add(self) -> None: exp = Prim("+", [Atom(1), Atom(2), Atom(3), Var("k")]) self.assertEqual(spin_opt(exp), App(Var("k"), [Atom(6)])) - def test_prim_mul_spin(self) -> None: + def test_prim_mul(self) -> None: exp = Prim("*", [Atom(2), Atom(3), Atom(4), Var("k")]) self.assertEqual(spin_opt(exp), App(Var("k"), [Atom(24)])) def test_prim_var(self) -> None: exp = Prim("+", [Atom(1), Var("x"), Atom(3), Var("k")]) - self.assertEqual(opt(exp), Prim("+", [Atom(4), Var("x"), Var("k")])) + self.assertEqual(spin_opt(exp), Prim("+", [Atom(4), Var("x"), Var("k")])) def test_subst(self) -> None: exp = App(Fun([Var("x")], Prim("+", [Atom(1), Var("x"), Atom(2), Var("k")])), [Atom(3)]) From 3315564f095472ae7fd0fd9174fc8f06cfba9903 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 6 Jun 2024 09:52:03 -0400 Subject: [PATCH 25/31] Add CPS web repl --- Dockerfile | 1 + cpsrepl.html | 224 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+) create mode 100644 cpsrepl.html diff --git a/Dockerfile b/Dockerfile index fc6c1a33..5c9d51de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ COPY . . RUN echo ":8000" > /etc/caddy/Caddyfile RUN echo "rewrite /repl /repl.html" >> /etc/caddy/Caddyfile RUN echo "rewrite /compilerepl /compilerepl.html" >> /etc/caddy/Caddyfile +RUN echo "rewrite /cpsrepl /cpsrepl.html" >> /etc/caddy/Caddyfile RUN echo "log" >> /etc/caddy/Caddyfile RUN echo "file_server" >> /etc/caddy/Caddyfile diff --git a/cpsrepl.html b/cpsrepl.html new file mode 100644 index 00000000..b21e52ba --- /dev/null +++ b/cpsrepl.html @@ -0,0 +1,224 @@ + + + + + +Scrapscript Web REPL + + + + + + + + + +
+
+

See scrapscript.org for a slightly +out of date language reference.

+

This REPL is completely client-side and works by running +scrapscript.py in the +browser using Pyodide.

+
+
+ + +
+
+Output: +
+
+ +
+ + +
+ + + From 4fed7b84d79635c2d29980606c215c5022eacb47 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 6 Jun 2024 09:56:57 -0400 Subject: [PATCH 26/31] Opt in web repl --- cpsrepl.html | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpsrepl.html b/cpsrepl.html index b21e52ba..bb64d056 100644 --- a/cpsrepl.html +++ b/cpsrepl.html @@ -68,8 +68,10 @@ async function sendRequest(exp) { const scrap = document.scrapscript; const prog = scrap.parse(scrap.tokenize(exp)); + const cps = document.cps.cps(prog, document.cps.Var("k")); + const opt = document.cps.spin_opt(cps); try { - return {result: document.cps.cps(prog, document.cps.Var("k")), ok: true}; + return {result: opt, ok: true}; } catch (e) { return {text: () => e.toString(), ok: false}; } From 143628e28e013a01f1c8d32b74d216d84130ca9b Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 6 Jun 2024 10:07:36 -0400 Subject: [PATCH 27/31] Push more inside try --- cpsrepl.html | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cpsrepl.html b/cpsrepl.html index bb64d056..47aee93c 100644 --- a/cpsrepl.html +++ b/cpsrepl.html @@ -66,11 +66,11 @@ } async function sendRequest(exp) { - const scrap = document.scrapscript; - const prog = scrap.parse(scrap.tokenize(exp)); - const cps = document.cps.cps(prog, document.cps.Var("k")); - const opt = document.cps.spin_opt(cps); try { + const scrap = document.scrapscript; + const prog = scrap.parse(scrap.tokenize(exp)); + const cps = document.cps.cps(prog, document.cps.Var("k")); + const opt = document.cps.spin_opt(cps); return {result: opt, ok: true}; } catch (e) { return {text: () => e.toString(), ok: false}; From 89af5a975456ca24a7d5da350013cf6d8f3868d0 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 5 Jun 2024 23:56:16 -0400 Subject: [PATCH 28/31] Support cont in classify --- cps.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cps.py b/cps.py index 34796ce0..0694b2b3 100644 --- a/cps.py +++ b/cps.py @@ -682,12 +682,12 @@ def classify_lambdas(exp: CPSExpr) -> None: return case Var(_): return - case App(Fun(_, body) as lam, args): + case App(Fun(_, body) | Cont(_, body) as lam, args): lam.annotations["kind"] = "open" classify_lambdas(body) for arg in args: classify_lambdas(arg) - case Prim(_, [*args, Fun(_, _) as lam]): + case Prim(_, [*args, Fun(_, _) | Cont(_, _) as lam]): lam.annotations["kind"] = "open" for arg in args: classify_lambdas(arg) @@ -695,7 +695,7 @@ def classify_lambdas(exp: CPSExpr) -> None: classify_lambdas(f) for arg in args: classify_lambdas(arg) - case Fun(_, body) as lam: + case Fun(_, body) | Cont(_, body) as lam: lam.annotations["kind"] = "closed" classify_lambdas(body) case Prim(_, args): From 62d7597fabaf2a2bdf7228e2969d7d059c167526 Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 6 Jun 2024 00:35:09 -0400 Subject: [PATCH 29/31] Support cont in free_in --- cps.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cps.py b/cps.py index 0694b2b3..9301d0a7 100644 --- a/cps.py +++ b/cps.py @@ -74,7 +74,8 @@ def kind(self) -> str: def __repr__(self) -> str: args = " ".join(map(repr, self.args)) - annotations = f" {self.annotations}" if self.annotations else "" + annotations = " ".join(f"[{k} {v}]" for k, v in self.annotations.items() if v) + annotations = f" {annotations}" # if self.annotations else "" return f"(fun ({args}){annotations} {self.body!r})" @@ -87,7 +88,8 @@ class Cont(CPSExpr): def __repr__(self) -> str: args = " ".join(map(repr, self.args)) - annotations = f" {self.annotations}" if self.annotations else "" + annotations = " ".join(f"[{k} {v}]" for k, v in self.annotations.items() if v) + annotations = f" {annotations}" # if self.annotations else "" return f"(cont ({args}){annotations} {self.body!r})" @@ -622,7 +624,7 @@ def free_in(exp: CPSExpr) -> set[str]: return {name} case Prim(_, args): return {name for arg in args for name in free_in(arg)} - case Fun(args, body): + case Fun(args, body) | Cont(args, body): return free_in(body) - {arg_name(arg) for arg in args} case App(fun, args): return free_in(fun) | {name for arg in args for name in free_in(arg)} @@ -665,6 +667,10 @@ def test_fun(self) -> None: exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) self.assertEqual(free_in(exp), {"z"}) + def test_cont(self) -> None: + exp = Cont([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) + self.assertEqual(free_in(exp), {"z"}) + def test_fun_annotate(self) -> None: exp = Fun([Var("x"), Var("y")], Prim("+", [Var("x"), Var("y"), Var("z")])) annotate_free_in(exp) From 78275d2b3a869266ba03eb693edfa0309507954e Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Thu, 6 Jun 2024 00:37:44 -0400 Subject: [PATCH 30/31] Bring explicit closure access back --- cps.py | 178 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 89 insertions(+), 89 deletions(-) diff --git a/cps.py b/cps.py index 9301d0a7..d66dae2b 100644 --- a/cps.py +++ b/cps.py @@ -736,95 +736,95 @@ def test_closed(self) -> None: self.assertEqual(exp.kind(), "closed") -# -# -# def make_closures_explicit(exp: CPSExpr, replacements: dict[str, CPSExpr]) -> CPSExpr: -# def rec(exp: CPSExpr) -> CPSExpr: -# return make_closures_explicit(exp, replacements) -# -# match exp: -# case Atom(_): -# return exp -# case Var(name): -# if name in replacements: -# return replacements[name] -# return exp -# case Prim(op, args): -# return Prim(op, [rec(arg) for arg in args]) -# case Fun(args, body): -# freevars = sorted(free_in(exp)) -# this = Var("this") -# new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} -# body = make_closures_explicit(body, {**replacements, **new_replacements}) -# return Fun([this] + args, body) -# case App(fun, args): -# return App(rec(fun), [rec(arg) for arg in args]) -# raise NotImplementedError(f"make_closures_explicit: {exp}") -# -# -# class ClosureTests(unittest.TestCase): -# def test_no_freevars(self) -> None: -# exp = Fun([Var("x")], Var("x")) -# # (fun (this x) x) -# self.assertEqual(make_closures_explicit(exp, {}), Fun([Var("this"), Var("x")], Var("x"))) -# -# def test_freevars(self) -> None: -# exp = Fun([Var("k")], Prim("+", [Var("x"), Var("y"), Var("k")])) -# # (fun (this k) ($+ ($clo this 0) ($clo this 1) k)) -# self.assertEqual( -# make_closures_explicit(exp, {}), -# Fun( -# [Var("this"), Var("k")], -# Prim( -# "+", -# [ -# Prim("clo", [Var("this"), Atom(0)]), -# Prim("clo", [Var("this"), Atom(1)]), -# Var("k"), -# ], -# ), -# ), -# ) -# -# def test_app_fun(self) -> None: -# exp = App(Fun([Var("x")], Var("x")), [Atom(42)]) -# # ((fun (this x) x) 42) -# self.assertEqual( -# make_closures_explicit(exp, {}), -# App(Fun([Var("this"), Var("x")], Var("x")), [Atom(42)]), -# ) -# -# def test_app(self) -> None: -# exp = App(Var("f"), [Atom(42)]) -# # (f 42) -# self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) -# -# def test_add_function(self) -> None: -# exp = cps(parse(tokenize("x -> y -> x + y")), Var("k")) -# exp = spin_opt(exp) -# # (k (fun (this x v2) -# # (v2 (fun (this y v3) -# # ($+ ($clo this 0) y v3))))) -# self.assertEqual( -# make_closures_explicit(exp, {}), -# App( -# Var("k"), -# [ -# Fun( -# [Var("this"), Var("x"), Var("v2")], -# App( -# Var("v2"), -# [ -# Fun( -# [Var("this"), Var("y"), Var("v3")], -# Prim("+", [Prim("clo", [Var("this"), Atom(0)]), Var("y"), Var("v3")]), -# ) -# ], -# ), -# ) -# ], -# ), -# ) +def make_closures_explicit(exp: CPSExpr, replacements: dict[str, CPSExpr]) -> CPSExpr: + def rec(exp: CPSExpr) -> CPSExpr: + return make_closures_explicit(exp, replacements) + + match exp: + case Atom(_): + return exp + case Var(name): + if name in replacements: + return replacements[name] + return exp + case Prim(op, args): + return Prim(op, [rec(arg) for arg in args]) + case Fun(args, body): + freevars = sorted(free_in(exp)) + this = Var("this") + new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} + body = make_closures_explicit(body, {**replacements, **new_replacements}) + return Fun([this] + args, body) + case App(fun, args): + return App(rec(fun), [rec(arg) for arg in args]) + raise NotImplementedError(f"make_closures_explicit: {exp}") + + +class ClosureTests(unittest.TestCase): + def test_no_freevars(self) -> None: + exp = Fun([Var("x")], Var("x")) + # (fun (this x) x) + self.assertEqual(make_closures_explicit(exp, {}), Fun([Var("this"), Var("x")], Var("x"))) + + def test_freevars(self) -> None: + exp = Fun([Var("k")], Prim("+", [Var("x"), Var("y"), Var("k")])) + # (fun (this k) ($+ ($clo this 0) ($clo this 1) k)) + self.assertEqual( + make_closures_explicit(exp, {}), + Fun( + [Var("this"), Var("k")], + Prim( + "+", + [ + Prim("clo", [Var("this"), Atom(0)]), + Prim("clo", [Var("this"), Atom(1)]), + Var("k"), + ], + ), + ), + ) + + def test_app_fun(self) -> None: + exp = App(Fun([Var("x")], Var("x")), [Atom(42)]) + # ((fun (this x) x) 42) + self.assertEqual( + make_closures_explicit(exp, {}), + App(Fun([Var("this"), Var("x")], Var("x")), [Atom(42)]), + ) + + def test_app(self) -> None: + exp = App(Var("f"), [Atom(42)]) + # (f 42) + self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) + + def test_add_function(self) -> None: + exp = cps(parse(tokenize("x -> y -> x + y")), Var("k")) + exp = spin_opt(exp) + # (k (fun (this x v2) + # (v2 (fun (this y v3) + # ($+ ($clo this 0) y v3))))) + self.assertEqual( + make_closures_explicit(exp, {}), + App( + Var("k"), + [ + Fun( + [Var("this"), Var("x"), Var("v2")], + App( + Var("v2"), + [ + Fun( + [Var("this"), Var("y"), Var("v3")], + Prim("+", [Prim("clo", [Var("this"), Atom(0)]), Var("y"), Var("v3")]), + ) + ], + ), + ) + ], + ), + ) + + # # # class C: From 9dbc1f7babdb68c836150fd79c9594264003659a Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Sun, 9 Jun 2024 17:16:00 -0400 Subject: [PATCH 31/31] wip-k --- cps.py | 92 +++++++++++++++++-------------- examples/0_home/combinators.scrap | 6 +- 2 files changed, 53 insertions(+), 45 deletions(-) diff --git a/cps.py b/cps.py index d66dae2b..dfd7d99c 100644 --- a/cps.py +++ b/cps.py @@ -105,8 +105,8 @@ def __repr__(self) -> str: cps_counter = itertools.count() -def gensym() -> str: - return f"v{next(cps_counter)}" +def gensym(stem="v") -> str: + return f"{stem}{next(cps_counter)}" def cont(arg: Var, body: CPSExpr) -> CPSExpr: @@ -136,7 +136,7 @@ def cps(exp: Object, k: CPSExpr) -> CPSExpr: if isinstance(exp, Function): assert isinstance(exp.arg, ScrapVar) arg = Var(exp.arg.name) - subk = Var(gensym()) + subk = Var(gensym("k")) return App(k, [Fun([arg, subk], cps(exp.body, subk))]) if isinstance(exp, List) and not exp.items: return App(k, [Atom([])]) @@ -476,13 +476,13 @@ def test_simple_fun(self) -> None: exp = cps(parse(tokenize("_ -> 1")), Var("k")) self.assertEqual( spin_opt(exp), - # (k (fun (_ v0) (v0 1))) + # (k (fun (_ k0) (k0 1))) App( Var("k"), [ Fun( - [Var("_"), Var("v0")], - App(Var("v0"), [Atom(1)]), + [Var("_"), Var("k0")], + App(Var("k0"), [Atom(1)]), ) ], ), @@ -492,13 +492,13 @@ def test_fun(self) -> None: exp = cps(parse(tokenize("_ -> 1 + 2 + 3")), Var("k")) self.assertEqual( spin_opt(exp), - # (k (fun (_ v0) (v0 6))) + # (k (fun (_ k0) (k0 6))) App( Var("k"), [ Fun( - [Var("_"), Var("v0")], - App(Var("v0"), [Atom(6)]), + [Var("_"), Var("k0")], + App(Var("k0"), [Atom(6)]), ) ], ), @@ -508,18 +508,18 @@ def test_add_function(self) -> None: exp = parse(tokenize("x -> y -> x + y")) self.assertEqual( spin_opt(cps(exp, Var("k"))), - # (k (fun (x v0) (v0 (fun (y v1) ($+ x y v1))))) + # (k (fun (x k0) (k0 (fun (y k1) ($+ x y k1))))) App( Var("k"), [ Fun( - [Var("x"), Var("v0")], + [Var("x"), Var("k0")], App( - Var("v0"), + Var("k0"), [ Fun( - [Var("y"), Var("v1")], - Prim("+", [Var("x"), Var("y"), Var("v1")]), + [Var("y"), Var("k1")], + Prim("+", [Var("x"), Var("y"), Var("k1")]), ) ], ), @@ -736,10 +736,18 @@ def test_closed(self) -> None: self.assertEqual(exp.kind(), "closed") +def is_cont_var(name: str) -> bool: + return name.startswith("k") + + def make_closures_explicit(exp: CPSExpr, replacements: dict[str, CPSExpr]) -> CPSExpr: def rec(exp: CPSExpr) -> CPSExpr: return make_closures_explicit(exp, replacements) + # def process_arg(arg: CPSExpr) -> CPSExpr: + # match arg: + # case Fun + match exp: case Atom(_): return exp @@ -749,12 +757,17 @@ def rec(exp: CPSExpr) -> CPSExpr: return exp case Prim(op, args): return Prim(op, [rec(arg) for arg in args]) - case Fun(args, body): - freevars = sorted(free_in(exp)) - this = Var("this") - new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} - body = make_closures_explicit(body, {**replacements, **new_replacements}) - return Fun([this] + args, body) + # case Fun(args, body): + # freevars = sorted(free_in(exp)) + # this = Var("this") + # new_replacements = {fv: Prim("clo", [this, Atom(idx)]) for idx, fv in enumerate(freevars)} + # body = make_closures_explicit(body, {**replacements, **new_replacements}) + # return Fun([this] + args, body) + case App(Var(cont), args) if is_cont_var(cont): + return Prim("return", [Prim("return-address", [Var(cont)])]+[rec(arg) for arg in args]) + case App(Var(_) as clo, args): + return App(Prim("clo-fn", [clo]), [clo] + [rec(arg) for arg in args]) + return App(rec(fun), [rec(arg) for arg in args]) case App(fun, args): return App(rec(fun), [rec(arg) for arg in args]) raise NotImplementedError(f"make_closures_explicit: {exp}") @@ -795,33 +808,18 @@ def test_app_fun(self) -> None: def test_app(self) -> None: exp = App(Var("f"), [Atom(42)]) # (f 42) - self.assertEqual(make_closures_explicit(exp, {}), App(Var("f"), [Atom(42)])) + self.assertEqual( + make_closures_explicit(exp, {}), + # (($clo-fn f) f 42) + App(Prim("clo-fn", [Var("f")]), [Var("f"), Atom(42)]), + ) def test_add_function(self) -> None: exp = cps(parse(tokenize("x -> y -> x + y")), Var("k")) exp = spin_opt(exp) - # (k (fun (this x v2) - # (v2 (fun (this y v3) - # ($+ ($clo this 0) y v3))))) self.assertEqual( make_closures_explicit(exp, {}), - App( - Var("k"), - [ - Fun( - [Var("this"), Var("x"), Var("v2")], - App( - Var("v2"), - [ - Fun( - [Var("this"), Var("y"), Var("v3")], - Prim("+", [Prim("clo", [Var("this"), Atom(0)]), Var("y"), Var("v3")]), - ) - ], - ), - ) - ], - ), + 1, ) @@ -934,6 +932,18 @@ def test_add_function(self) -> None: # self.assertEqual(code, "return mkclosure(fun0);") +def compile_exp(exp: Object) -> CPSExpr: + prog = spin_opt(alphatise(cps(exp, Var("k")))) + annotate_free_in(prog) + classify_lambdas(prog) + return prog + + if __name__ == "__main__": __import__("sys").modules["unittest.util"]._MAX_LENGTH = 999999999 unittest.main() + with open("examples/0_home/combinators.scrap", "r") as f: + source = f.read() + + exp = parse(tokenize(source)) + print(compile_exp(exp)) diff --git a/examples/0_home/combinators.scrap b/examples/0_home/combinators.scrap index 64a71245..c381d180 100644 --- a/examples/0_home/combinators.scrap +++ b/examples/0_home/combinators.scrap @@ -14,7 +14,5 @@ Z factr 5 . Y = f -> (x -> f (x x)) (x -> f (x x)) . Z = f -> (x -> f (v -> (x x) v)) (x -> f (v -> (x x) v)) -. factr = facti -> - | 0 -> 1 - | n -> (mult n) (facti (n - 1)) -. mult = x -> y -> x * y \ No newline at end of file +. factr = facti -> n -> mult n 10 +. mult = x -> y -> x * y