@@ -269,7 +269,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
269269 val levelScope = getExactScope(currentScope)(result) // could provide as input ...
270270 // TODO: in general we cannot fuse several effectful loops (one effectful loop plus several pure ones is ok, though),
271271 // so we need a strategy. A simple one would be to exclude all effectful loops right away (TODO).
272- levelScope collect { case e @ TTP (_, _, SimpleFatLoop (_,_,_)) => e }
272+ levelScope collect { case e @ TTP (_, SimpleFatLoop (_,_,_)) => e }
273273 }
274274
275275 // FIXME: more than one super call means exponential cost -- is there a better way?
@@ -284,7 +284,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
284284 var done = false
285285
286286 // keep track of loops in inner scopes
287- var UloopSyms = currentScope collect { case e @ TTP (lhs, _, SimpleFatLoop (_,_,_)) if ! Wloops .contains(e) => lhs }
287+ var UloopSyms = currentScope collect { case e @ TTP (_, SimpleFatLoop (_,_,_)) if ! Wloops .contains(e) => e. lhs }
288288
289289 // do{
290290
@@ -367,7 +367,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
367367 var partitionsIn = Wloops
368368 var partitionsOut = Nil : List [Stm ]
369369
370- for (b@ TTP (_,_,_ ) <- partitionsIn) {
370+ for (b@ TTP (_,_) <- partitionsIn) {
371371 // try to add to an item in partitionsOut, if not possible add as-is
372372 partitionsOut.find(a => canFuse(a,b)) match {
373373 case Some (a : TTP ) =>
@@ -394,13 +394,14 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
394394 shapeA
395395 }
396396
397- val lhs = a.lhs ++ b.lhs
397+ val tps = a.tps ++ b.tps
398398
399- val fused = TTP (lhs, a.mhs ++ b.mhs , SimpleFatLoop (shape, targetVar, WgetLoopRes (a)::: WgetLoopRes (b)))
399+ val fused = TTP (tps , SimpleFatLoop (shape, targetVar, WgetLoopRes (a)::: WgetLoopRes (b)))
400400 partitionsOut = fused :: (partitionsOut diff List (a))
401401
402- val preNeg = WtableNeg collect { case p if (lhs contains p._2) => p._1 }
403- val postNeg = WtableNeg collect { case p if (lhs contains p._1) => p._2 }
402+ val syms = tps.map(_.sym).toSet
403+ val preNeg = WtableNeg collect { case p if (syms contains p._2) => p._1 }
404+ val postNeg = WtableNeg collect { case p if (syms contains p._1) => p._2 }
404405
405406 val fusedNeg = preNeg flatMap { s1 => postNeg map { s2 => (s1,s2) } }
406407 WtableNeg = (fusedNeg ++ WtableNeg ).distinct
@@ -461,10 +462,10 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
461462
462463 // prune Wloops (some might be no longer necessary)
463464 Wloops = pOutT map {
464- case TTP (lhs, mhs , SimpleFatLoop (s, x, rhs)) =>
465- val ex = lhs map (s => currentScope exists (_.lhs contains s))
465+ case TTP (tps , SimpleFatLoop (s, x, rhs)) =>
466+ val ex = tps map (s => currentScope exists (_.lhs contains s.sym ))
466467 def select [A ](a : List [A ], b : List [Boolean ]) = (a zip b) collect { case (w, true ) => w }
467- TTP (select(lhs, ex), select(mhs , ex), SimpleFatLoop (s, x, select(rhs, ex)))
468+ TTP (select(tps , ex), SimpleFatLoop (s, x, select(rhs, ex)))
468469 }
469470
470471 currentScope = (currentScope diff pInT) ++ Wloops
@@ -520,7 +521,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
520521 val levelScope = getExactScope(currentScope)(result) // could provide as input ...
521522 // TODO: in general we cannot fuse several effectful loops (one effectful loop plus several pure ones is ok, though),
522523 // so we need a strategy. A simple one would be to exclude all effectful loops right away (TODO).
523- levelScope collect { case e @ TTP(_, _, SimpleFatLoop(_,_,_)) => e }
524+ levelScope collect { case e @ TTP(_, SimpleFatLoop(_,_,_)) => e }
524525 }
525526
526527 // FIXME: more than one super call means exponential cost -- is there a better way?
@@ -535,7 +536,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
535536 var done = false
536537
537538 // keep track of loops in inner scopes
538- var UloopSyms = currentScope collect { case e @ TTP(lhs, _, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs }
539+ var UloopSyms = currentScope collect { case e @ TTP(lhs, SimpleFatLoop(_,_,_)) if !Wloops.contains(e) => lhs.map(_.sym) }
539540
540541 do {
541542 // utils
@@ -630,7 +631,7 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
630631 var partitionsIn = Wloops
631632 var partitionsOut = Nil:List[Stm]
632633
633- for (b@ TTP(_,_,_ ) <- partitionsIn) {
634+ for (b@ TTP(_,_) <- partitionsIn) {
634635 // try to add to an item in partitionsOut, if not possible add as-is
635636 partitionsOut.find(a => canFuse(a,b)) match {
636637 case Some(a: TTP) =>
@@ -659,11 +660,12 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
659660
660661 val lhs = a.lhs ++ b.lhs
661662
662- val fused = TTP(lhs, a.mhs ++ b.mhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
663+ val fused = TTP(lhs, SimpleFatLoop(shape, targetVar, WgetLoopRes(a):::WgetLoopRes(b)))
663664 partitionsOut = fused :: (partitionsOut diff List(a))
664665
665- val preNeg = WtableNeg collect { case p if (lhs contains p._2) => p._1 }
666- val postNeg = WtableNeg collect { case p if (lhs contains p._1) => p._2 }
666+ val syms = lhs.map(_.sym).toSet
667+ val preNeg = WtableNeg collect { case p if (syms contains p._2) => p._1 }
668+ val postNeg = WtableNeg collect { case p if (syms contains p._1) => p._2 }
667669
668670 val fusedNeg = preNeg flatMap { s1 => postNeg map { s2 => (s1,s2) } }
669671 WtableNeg = (fusedNeg ++ WtableNeg).distinct
@@ -722,19 +724,19 @@ trait LoopFusionCore extends internal.FatScheduling with CodeMotion with Simplif
722724
723725 // prune Wloops (some might be no longer necessary)
724726 Wloops = Wloops map {
725- case TTP(lhs, mhs, SimpleFatLoop(s, x, rhs)) =>
726- val ex = lhs map (s => currentScope exists (_.lhs == List(s)))
727+ case TTP(lhs, SimpleFatLoop(s, x, rhs)) =>
728+ val ex = lhs map (s => currentScope exists (_.lhs == List(s.sym )))
727729 def select[A](a: List[A], b: List[Boolean]) = (a zip b) collect { case (w, true) => w }
728- TTP(select(lhs, ex), select(mhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
730+ TTP(select(lhs, ex), SimpleFatLoop(s, x, select(rhs, ex)))
729731 }
730732
731733 // PREVIOUS PROBLEM: don't throw out all loops, might have some that are *not* in levelScope
732734 // note: if we don't do it here, we will likely see a problem going back to innerScope in
733735 // FatCodegen.focusExactScopeFat below. --> how to go back from SimpleFatLoop to VectorPlus??
734736 // UPDATE: UloopSyms puts a tentative fix in place. check if it is sufficient!!
735737 // what is the reason we cannot just look at Wloops??
736- currentScope = currentScope.filter { case e@TTP(lhs, _, _ : AbstractFatLoop) =>
737- val keep = UloopSyms contains lhs
738+ currentScope = currentScope.filter { case e@TTP(lhs, _: AbstractFatLoop) =>
739+ val keep = UloopSyms contains lhs.map(_.sym)
738740 //if (!keep) println("dropping: " + e + ", not int UloopSyms: " + UloopSyms)
739741 keep case _ => true } ::: Wloops
740742
0 commit comments