1+ package tutorial
2+
3+ import lms .core ._
4+ import lms .macros ._
5+ import lms .core .stub ._
6+ import lms .core .Backend ._
7+
8+ import gensym .lmsx ._
9+
10+ import scala .collection .immutable .{List => SList }
11+
12+ object ImpLang {
13+ sealed trait Stmt
14+ case class Skip () extends Stmt
15+ case class Break () extends Stmt
16+ case class Assign (x : String , e : Expr ) extends Stmt
17+ case class Cond (e : Expr , thn : Stmt , els : Stmt ) extends Stmt
18+ case class Seq (s1 : Stmt , s2 : Stmt ) extends Stmt
19+ case class While (b : Expr , s : Stmt ) extends Stmt
20+ case class Output (e : Expr ) extends Stmt
21+ case class Assert (e : Expr ) extends Stmt
22+
23+ sealed trait Expr {
24+ def toSExp : String
25+ }
26+ case class Input () extends Expr {
27+ def toSExp = ???
28+ }
29+ case class Lit (x : Any ) extends Expr {
30+ override def toString : String = s " Lit( ${x.toString}) "
31+ def toSExp : String = x.toString
32+ }
33+ case class Var (x : String ) extends Expr {
34+ override def toString : String = " Var(\" " + x.toString + " \" )"
35+ def toSExp : String = x.toString
36+ }
37+ case class Op1 (op : String , e : Expr ) extends Expr {
38+ override def toString : String = " Op1(\" " + op + " \" ," + s " ${e.toString}) "
39+ def toSExp : String = s " ( $op ${e.toSExp}) "
40+ }
41+ case class Op2 (op : String , e1 : Expr , e2 : Expr ) extends Expr {
42+ override def toString : String =
43+ " Op2(\" " + op + " \" ," + s " ${e1.toString}, ${e2.toString}) "
44+ def toSExp : String = s " ( $op ${e1.toSExp} ${e2.toSExp}) "
45+ }
46+
47+ def let_ (x : String , rhs : Int )(body : Var => Stmt ): Stmt =
48+ Seq (Assign (x, Lit (rhs)), body(Var (x)))
49+ def let_ (x : String , rhs : Expr )(body : Var => Stmt ): Stmt =
50+ Seq (Assign (x, rhs), body(Var (x)))
51+
52+ def set_ (x : String , rhs : Expr ): Stmt = Assign (x, rhs)
53+
54+ def while_ (e : Expr , s : Stmt ): Stmt = While (e, s)
55+
56+ object Examples {
57+ val fact5 =
58+ Seq (Assign (" i" , Lit (1 )),
59+ Seq (Assign (" fact" , Lit (1 )),
60+ While (Op2 (" <=" , Var (" i" ), Lit (5 )),
61+ Seq (Assign (" fact" , Op2 (" *" , Var (" fact" ), Var (" i" ))),
62+ Assign (" i" , Op2 (" +" , Var (" i" ), Lit (1 )))))))
63+
64+ val fact_n =
65+ Seq (Assign (" i" , Lit (1 )),
66+ Seq (Assign (" fact" , Lit (1 )),
67+ While (Op2 (" <=" , Var (" i" ), Var (" n" )),
68+ Seq (Assign (" fact" , Op2 (" *" , Var (" fact" ), Var (" i" ))),
69+ Assign (" i" , Op2 (" +" , Var (" i" ), Lit (1 )))))))
70+
71+ val w2 =
72+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
73+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
74+ Assign (" x" , Op2 (" -" , Var (" x" ), Lit (1 )))))
75+
76+ val w3 =
77+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
78+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
79+ While (Op2 (" <=" , Var (" i" ), Var (" x" )),
80+ Assign (" x" , Op2 (" -" , Var (" x" ), Lit (1 ))))))
81+
82+ val another_fact5 =
83+ let_(" i" , 1 ){ i =>
84+ let_(" fact" , 1 ){ fact =>
85+ while_(Op2 (" <=" , i, Lit (5 )),
86+ let_(" fact" , Op2 (" *" , fact, i)){ _ =>
87+ set_(" i" , Op2 (" +" , i, Lit (1 )))
88+ })}}
89+
90+
91+ // println(another_fact5)
92+ assert(fact5 == another_fact5)
93+
94+ val x = Var (" x" )
95+ val y = Var (" y" )
96+ val z = Var (" z" )
97+ val a = Var (" a" )
98+ val b = Var (" b" )
99+ val i = Var (" i" )
100+
101+ val cond1 =
102+ Cond (Op2 (" <=" , Lit (1 ), Lit (2 )),
103+ Assign (" x" , Lit (3 )),
104+ Assign (" x" , Lit (4 )))
105+
106+ /* if (x <= y) {
107+ * z = x
108+ * } else {
109+ * z = y
110+ * }
111+ * z = z + 1
112+ */
113+ val cond2 =
114+ Seq (Cond (Op2 (" <=" , Var (" x" ), Var (" y" )),
115+ Assign (" z" , Var (" x" )),
116+ Assign (" z" , Var (" y" ))),
117+ Assign (" z" , Op2 (" +" , Var (" z" ), Lit (1 ))))
118+
119+ /* if (x <= y) {
120+ * z = x
121+ * } else {
122+ * z = y
123+ * }
124+ * z = z - 1
125+ * if (z >= y) {
126+ * z = z * 2
127+ * } else {
128+ * z = z + 3
129+ * }
130+ */
131+ val cond3 =
132+ Seq (Cond (Op2 (" <=" , x, y),
133+ Assign (" z" , x),
134+ Assign (" z" , y)),
135+ Seq (Assign (" z" , Op2 (" -" , z, Lit (1 ))),
136+ Seq (Cond (Op2 (" >=" , z, y),
137+ Assign (" z" , Op2 (" *" , z, Lit (2 ))),
138+ Assign (" z" , Op2 (" +" , z, Lit (3 )))),
139+ Skip ())))
140+
141+ val condInput =
142+ Seq (Assign (" x" , Input ()),
143+ Seq (Cond (Op2 (" <=" , x, y),
144+ Assign (" z" , x),
145+ Assign (" z" , y)),
146+ Seq (Assign (" z" , Op2 (" +" , z, Lit (1 ))),
147+ Seq (Cond (Op2 (" >=" , z, y),
148+ Assign (" z" , Op2 (" +" , z, Lit (2 ))),
149+ Assign (" z" , Op2 (" +" , z, Lit (3 )))),
150+ Skip ()))))
151+
152+ val condAssert =
153+ Seq (Assign (" x" , Input ()),
154+ Seq (Assert (Op2 (" >=" , x, Lit (1 ))),
155+ Seq (Cond (Op2 (" <=" , x, y),
156+ Assign (" z" , x),
157+ Assign (" z" , y)),
158+ Seq (Assign (" z" , Op2 (" +" , z, Lit (1 ))),
159+ Seq (Cond (Op2 (" >=" , z, y),
160+ Assign (" z" , Op2 (" +" , z, Lit (2 ))),
161+ Assign (" z" , Op2 (" +" , z, Lit (3 )))),
162+ Skip ())))))
163+
164+ val unboundLoop =
165+ Seq (Assign (" i" , Input ()),
166+ While (Op2 (" <" , i, Lit (42 )),
167+ Assign (" i" , Op2 (" +" , i, Lit (1 )))))
168+ }
169+ }
170+
171+ import ImpLang ._
172+
173+ @ virtualize
174+ trait ImpureStagedImpSemantics extends SAIOps {
175+ trait Value
176+ def IntV (i : Rep [Int ]): Rep [Value ] = " IntV" .reflectWith[Value ](i)
177+ def BoolV (b : Rep [Boolean ]): Rep [Value ] = " BoolV" .reflectWith[Value ](b)
178+
179+ implicit def repIntProj (i : Rep [Value ]): Rep [Int ] = Unwrap (i) match {
180+ // case Adapter.g.Def("IntV", SList(v: Backend.Exp)) => Wrap[Int](v)
181+ case _ => " IntV-proj" .reflectWith[Int ](i)
182+ }
183+ implicit def repBoolProj (b : Rep [Value ]): Rep [Boolean ] = Unwrap (b) match {
184+ case Adapter .g.Def (" BoolV" , SList (v : Backend .Exp )) => Wrap [Boolean ](v)
185+ case _ => " BoolV-proj" .reflectWith[Boolean ](b)
186+ }
187+
188+ trait MutState
189+ def newMutState (kvs : (String , Rep [Value ])* ): Rep [MutState ] =
190+ " mutstate-new" .reflectMutableWith[MutState ](kvs.map({ case (k, v) => __liftTuple2RepLhs(k, v) }):_* )
191+ implicit class MutStateOps (s : Rep [MutState ]) {
192+ def apply (x : String ): Rep [Value ] = " mutstate-read" .reflectReadWith[Value ](s, x)(s)
193+ def += (x : String , v : Rep [Value ]): Rep [Unit ] = " mutstate-update" .reflectWriteWith[Unit ](s, x, v)(s)
194+ }
195+ def dummyRead (s : Rep [MutState ]): Rep [Unit ] = " mutstate-dummyread" .reflectRWWith[Unit ]()(s)(Adapter .CTRL )
196+
197+ def eval (e : Expr , σ : Rep [MutState ]): Rep [Value ] = e match {
198+ case Lit (i : Int ) => IntV (i)
199+ case Lit (b : Boolean ) => BoolV (b)
200+ case Var (x) => σ(x)
201+ case Op1 (" -" , e) =>
202+ val i : Rep [Int ] = eval(e, σ)
203+ IntV (- i)
204+ case Op2 (op, e1, e2) =>
205+ val i1 : Rep [Int ] = eval(e1, σ)
206+ val i2 : Rep [Int ] = eval(e2, σ)
207+ op match {
208+ case " +" => IntV (i1 + i2)
209+ case " -" => IntV (i1 - i2)
210+ case " *" => IntV (i1 * i2)
211+ case " ==" => BoolV (i1 == i2)
212+ case " <=" => BoolV (i1 <= i2)
213+ case " <" => BoolV (i1 < i2)
214+ case " >=" => BoolV (i1 >= i2)
215+ case " >" => BoolV (i1 > i2)
216+ }
217+ }
218+
219+ def exec (s : Stmt , σ : Rep [MutState ]): Rep [Unit ] = s match {
220+ case Skip () => ()
221+ case Assign (x, e) => σ += (x, eval(e, σ))
222+ case Cond (e, s1, s2) =>
223+ if (eval(e, σ)) exec(s1, σ) else exec(s2, σ)
224+ case Seq (s1, s2) => exec(s1, σ); exec(s2, σ)
225+ case While (e, b) => while (eval(e, σ)) exec(b, σ)
226+ }
227+ }
228+
229+ trait ImpureStagedImpGen extends SAICodeGenBase {
230+ override def traverse (n : Node ): Unit = n match {
231+ case Node (s, " mutstate-new" , kvs, _) =>
232+ es " val ${quote(s)} = Map[String, Value]( "
233+ kvs.zipWithIndex.map { case (kv, i) =>
234+ shallow(kv)
235+ if (i != kvs.length- 1 ) emit(" , " )
236+ }
237+ esln " ) "
238+ case Node (_, " mutstate-update" , List (s, x, v), _) => esln " $s( $x) = $v"
239+ case Node (_, " mutstate-dummyread" , _, _) => es " "
240+ case _ => super .traverse(n)
241+ }
242+ // shallow : code generation for pure node/expression
243+ override def shallow (n : Node ): Unit = n match {
244+ case Node (s, " IntV" , List (i), _) => es " IntV( $i) "
245+ case Node (s, " BoolV" , List (b), _) => es " BoolV( $b) "
246+ case Node (s, " IntV-proj" , List (i), _) => es " $i.I "
247+ case Node (s, " BoolV-proj" , List (i), _) => es " $i.B "
248+ case Node (_, " mutstate-read" , List (s, x), _) => es " $s( $x) "
249+ case _ => super .shallow(n)
250+ }
251+ }
252+
253+ trait ImpureStagedImpDriver [A , B ] extends SAIDriver [A , B ] with ImpureStagedImpSemantics { q =>
254+ override val codegen = new ScalaGenBase with ImpureStagedImpGen {
255+ val IR : q.type = q
256+ import IR ._
257+ override def remap (m : Manifest [_]): String = {
258+ if (m.toString.endsWith(" $Value" )) " Value"
259+ else if (m.toString.endsWith(" $MutState" )) " MutState"
260+ else super .remap(m)
261+ }
262+ }
263+
264+ override val prelude =
265+ """
266+ import scala.collection.mutable.Map
267+ import sai.lang.ImpLang._
268+ object Prelude {
269+ trait Value
270+ case class IntV(i: Int) extends Value
271+ case class BoolV(b: Boolean) extends Value
272+ implicit class ValueOps(v: Value) {
273+ def I: Int = v.asInstanceOf[IntV].i
274+ def B: Boolean = v.asInstanceOf[BoolV].b
275+ }
276+ }
277+ import Prelude._
278+ """
279+ }
280+
281+ object ImpureStagedImpTest {
282+ import ImpLang ._
283+ import ImpLang .Examples ._
284+ def main (args : Array [String ]): Unit = {
285+ val code = new ImpureStagedImpDriver [Int , Unit ] {
286+ @ virtualize
287+ def snippet (u : Rep [Int ]) = {
288+ // val st: Rep[MutState] = newMutState("x" -> IntV(3), "y" -> IntV(4))
289+ // exec(cond3, st)
290+ val st : Rep [MutState ] = newMutState()
291+ // exec(Seq(Assign("x", Lit(3)), Assign("y", Lit(4))), st)
292+ exec(fact5, st)
293+ dummyRead(st)
294+ println(st)
295+ }
296+ }
297+ println(code.code)
298+ // code.eval(0)
299+ }
300+ }
0 commit comments