@@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
6262 }
6363}
6464
65+ // Translate the TC def input params to corresponding Halide components.
66+ // params, inputs will be populated here.
6567void translateParam (
6668 const lang::Param& p,
6769 map<string, Parameter>* params,
6870 vector<ImageParam>* inputs) {
71+ // Check if the param has already been converted to halide components.
6972 if (params->find (p.ident ().name ()) != params->end ()) {
7073 return ;
71- } else {
72- lang::TensorType type = p.tensorType ();
73- int dimensions = (int )type.dims ().size ();
74- ImageParam imageParam (
75- translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
76- inputs->push_back (imageParam);
77- vector<Expr> dims;
78- for (auto d_ : type.dims ()) {
79- if (d_->kind () == lang::TK_IDENT) {
80- auto d = lang::Ident (d_);
81- auto it = params->find (d.name ());
82- Parameter p;
83- if (it != params->end ()) {
84- p = it->second ;
85- } else {
86- p = Parameter (Int (32 ), false , 0 , d.name (), true );
87- (*params)[d.name ()] = p;
88- }
89- dims.push_back (Variable::make (Int (32 ), p.name (), p));
74+ }
75+ lang::TensorType type = p.tensorType ();
76+ int dimensions = (int )type.dims ().size ();
77+ ImageParam imageParam (
78+ translateScalarType (type.scalarType ()), dimensions, p.ident ().name ());
79+ inputs->push_back (imageParam);
80+ vector<Expr> dims;
81+ for (auto d_ : type.dims ()) {
82+ if (d_->kind () == lang::TK_IDENT) {
83+ auto d = lang::Ident (d_);
84+ auto it = params->find (d.name ());
85+ Parameter p;
86+ if (it != params->end ()) {
87+ p = it->second ;
9088 } else {
91- CHECK (d_->kind () == lang::TK_CONST);
92- int32_t value = lang::Const (d_).value ();
93- dims.push_back (Expr (value));
89+ p = Parameter (Int (32 ), false , 0 , d.name (), true );
90+ (*params)[d.name ()] = p;
9491 }
92+ dims.push_back (Variable::make (Int (32 ), p.name (), p));
93+ } else {
94+ CHECK (d_->kind () == lang::TK_CONST);
95+ int32_t value = lang::Const (d_).value ();
96+ dims.push_back (Expr (value));
9597 }
98+ }
9699
97- for (int i = 0 ; i < imageParam.dimensions (); i++) {
98- imageParam.dim (i).set_bounds (0 , dims[i]);
99- }
100- (*params)[imageParam.name ()] = imageParam.parameter ();
100+ for (int i = 0 ; i < imageParam.dimensions (); i++) {
101+ imageParam.dim (i).set_bounds (0 , dims[i]);
101102 }
103+ (*params)[imageParam.name ()] = imageParam.parameter ();
102104}
103105
104106void translateOutput (
@@ -156,6 +158,8 @@ Expr translateExpr(
156158 return t (0 ) * t (1 );
157159 case ' /' :
158160 return t (0 ) / t (1 );
161+ case ' %' :
162+ return t (0 ) % t (1 );
159163 case lang::TK_MIN:
160164 return min (t (0 ), t (1 ));
161165 case lang::TK_MAX:
@@ -488,22 +492,25 @@ Expr reductionUpdate(Expr e) {
488492 return Call::make (e.type (), kReductionUpdate , {e}, Call::Intrinsic);
489493}
490494
495+ // Translate a single TC comprehension/statement to Halide components: funcs,
496+ // bounds, reductions.
497+ //
491498// Note that the function definitions created by translateComprehension may
492499// contain kReductionUpdate intrinsics. These may have to be removed
493500// in order to be able to apply internal Halide analysis passes on them.
494501void translateComprehension (
495- const lang::Comprehension& c ,
502+ const lang::Comprehension& comprehension ,
496503 const map<string, Parameter>& params,
497504 bool throwWarnings,
498505 map<string, Function>* funcs,
499506 FunctionBounds* bounds) {
500507 Function f;
501- auto it = funcs->find (c .ident ().name ());
508+ auto it = funcs->find (comprehension .ident ().name ());
502509 if (it != funcs->end ()) {
503510 f = it->second ;
504511 } else {
505- f = Function (c .ident ().name ());
506- (*funcs)[c .ident ().name ()] = f;
512+ f = Function (comprehension .ident ().name ());
513+ (*funcs)[comprehension .ident ().name ()] = f;
507514 }
508515 // Function is the internal Halide IR type for a pipeline
509516 // stage. Func is the front-end class that wraps it. Here it's
@@ -512,7 +519,7 @@ void translateComprehension(
512519
513520 vector<Var> lhs;
514521 vector<Expr> lhs_as_exprs;
515- for (lang::Ident id : c .indices ()) {
522+ for (lang::Ident id : comprehension .indices ()) {
516523 lhs.push_back (Var (id.name ()));
517524 lhs_as_exprs.push_back (lhs.back ());
518525 }
@@ -521,17 +528,17 @@ void translateComprehension(
521528 // in the future we may consider using Halide Let bindings when they
522529 // are supported later
523530 map<string, Expr> lets;
524- for (auto wc : c .whereClauses ()) {
531+ for (auto wc : comprehension .whereClauses ()) {
525532 if (wc->kind () == lang::TK_LET) {
526533 auto let = lang::Let (wc);
527534 lets[let.name ().name ()] = translateExpr (let.rhs (), params, *funcs, lets);
528535 }
529536 }
530537
531- Expr rhs = translateExpr (c .rhs (), params, *funcs, lets);
538+ Expr rhs = translateExpr (comprehension .rhs (), params, *funcs, lets);
532539
533540 std::vector<Expr> all_exprs;
534- for (auto wc : c .whereClauses ()) {
541+ for (auto wc : comprehension .whereClauses ()) {
535542 if (wc->kind () == lang::TK_EXISTS) {
536543 all_exprs.push_back (
537544 translateExpr (lang::Exists (wc).exp (), params, *funcs, lets));
@@ -555,7 +562,7 @@ void translateComprehension(
555562 // values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
556563 // for the reduction and then applies the reduction.
557564 bool should_zero = false ;
558- switch (c .assignment ()->kind ()) {
565+ switch (comprehension .assignment ()->kind ()) {
559566 case lang::TK_PLUS_EQ_B:
560567 should_zero = true ; // fallthrough
561568 case lang::TK_PLUS_EQ:
@@ -587,12 +594,13 @@ void translateComprehension(
587594 case ' =' :
588595 break ;
589596 default :
590- throw lang::ErrorReport (c) << " Unimplemented reduction "
591- << c.assignment ()->range ().text () << " \n " ;
597+ throw lang::ErrorReport (comprehension)
598+ << " Unimplemented reduction "
599+ << comprehension.assignment ()->range ().text () << " \n " ;
592600 }
593601
594602 // Tag reductions as such
595- if (c .assignment ()->kind () != ' =' ) {
603+ if (comprehension .assignment ()->kind () != ' =' ) {
596604 rhs = reductionUpdate (rhs);
597605 }
598606
@@ -632,7 +640,7 @@ void translateComprehension(
632640 Scope<Interval> solution;
633641
634642 // Put anything explicitly specified with a 'where' class in the solution
635- for (auto constraint_ : c .whereClauses ()) {
643+ for (auto constraint_ : comprehension .whereClauses ()) {
636644 if (constraint_->kind () != lang::TK_RANGE_CONSTRAINT)
637645 continue ;
638646 auto constraint = lang::RangeConstraint (constraint_);
@@ -653,7 +661,8 @@ void translateComprehension(
653661
654662 // Infer the rest
655663 all_exprs.push_back (rhs);
656- forwardBoundsInference (all_exprs, *bounds, c, throwWarnings, &solution);
664+ forwardBoundsInference (
665+ all_exprs, *bounds, comprehension, throwWarnings, &solution);
657666
658667 // TODO: What if subsequent updates have incompatible bounds
659668 // (e.g. an in-place stencil)?. The .bound directive will use the
@@ -664,7 +673,7 @@ void translateComprehension(
664673
665674 for (Var v : lhs) {
666675 if (!solution.contains (v.name ())) {
667- throw lang::ErrorReport (c )
676+ throw lang::ErrorReport (comprehension )
668677 << " Free variable " << v
669678 << " was not solved in range inference. May not be used right-hand side" ;
670679 }
@@ -688,7 +697,7 @@ void translateComprehension(
688697 for (size_t i = 0 ; i < unbound.size (); i++) {
689698 auto v = unbound[unbound.size () - 1 - i];
690699 if (!solution.contains (v->name )) {
691- throw lang::ErrorReport (c )
700+ throw lang::ErrorReport (comprehension )
692701 << " Free variable " << v << " is unconstrained. "
693702 << " Use a 'where' clause to set its range." ;
694703 }
@@ -736,6 +745,7 @@ void translateComprehension(
736745 stage.reorder (loop_nest);
737746}
738747
748+ // Translate a semantically checked TC def to HalideComponents struct.
739749HalideComponents translateDef (const lang::Def& def, bool throwWarnings) {
740750 map<string, Function> funcs;
741751 HalideComponents components;
@@ -895,6 +905,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
895905 lang::Def (lang::Sema ().checkFunction (treeRef)), throwWarnings);
896906}
897907
908+ // NOTE: there is no guarantee here that the tc string has only one def. It
909+ // could have many defs. Only first def will be converted in that case.
898910HalideComponents
899911translate (isl::ctx ctx, const std::string& tc, bool throwWarnings) {
900912 LOG_IF (INFO, tc::FLAGS_debug_halide) << tc;
0 commit comments