@@ -166,6 +166,7 @@ struct Sema {
166166 }
167167 return expr_to_type.at (ref);
168168 }
169+
169170 // associate a type with this expression
170171 TreeRef withType (TreeRef expr, TreeRef type) {
171172 auto inserted = expr_to_type.emplace (expr, type).second ;
@@ -179,6 +180,7 @@ struct Sema {
179180 }
180181 return TensorType (typ);
181182 }
183+
182184 TreeRef matchAllTypes (TreeRef list, TreeRef matched_type = nullptr ) {
183185 for (auto e : list->trees ()) {
184186 if (!matched_type)
@@ -188,23 +190,27 @@ struct Sema {
188190 }
189191 return matched_type;
190192 }
193+
191194 TreeRef expectIntegral (TreeRef e) {
192195 if (TypeInfo (typeOfExpr (e)).code () == TypeInfo::Float) {
193196 throw ErrorReport (e) << " expected integral type but found "
194197 << kindToString (typeOfExpr (e)->kind ());
195198 }
196199 return e;
197200 }
201+
198202 void expectBool (TreeRef anchor, int token) {
199203 if (token != TK_BOOL) {
200204 throw ErrorReport (anchor)
201205 << " expected boolean but found " << kindToString (token);
202206 }
203207 }
208+
204209 TreeRef expectBool (TreeRef exp) {
205210 expectBool (exp, typeOfExpr (exp)->kind ());
206211 return exp;
207212 }
213+
208214 TreeRef lookupVarOrCreateIndex (Ident ident) {
209215 TreeRef type = lookup (ident, false );
210216 if (!type) {
@@ -216,6 +222,7 @@ struct Sema {
216222 }
217223 return type;
218224 }
225+
219226 TreeRef checkExp (TreeRef exp, bool allow_access) {
220227 switch (exp->kind ()) {
221228 case TK_APPLY: {
@@ -339,6 +346,7 @@ struct Sema {
339346 throw ErrorReport (exp) << " NYI - semantic checking for " << exp;
340347 }
341348 }
349+
342350 // This is the entry function for semantic analysis. It is called by
343351 // tc2halide to associate type with each node of the tree and to also make
344352 // sure that the tree is sematically correct. For example: a variable
@@ -352,7 +360,7 @@ struct Sema {
352360 //
353361 // Type checking is also done by small amount of code
354362 //
355- // The method 'withType' can be used to associate the type with a given node
363+ // The method 'withType' is used to associate the type with a given node
356364 //
357365 TreeRef checkFunction (TreeRef func_) {
358366 auto func = Def (func_);
@@ -385,21 +393,27 @@ struct Sema {
385393 Def::create (func.range (), func.name (), params_, returns_, statements_);
386394 return r;
387395 }
396+
388397 TreeRef indexType (TreeRef anchor) {
389- return c (TK_INT32, anchor->range (), {});
398+ return createCompound (TK_INT32, anchor->range (), {});
390399 }
400+
391401 TreeRef dimType (TreeRef anchor) {
392402 return indexType (anchor);
393403 }
404+
394405 TreeRef floatType (TreeRef anchor) {
395- return c (TK_FLOAT, anchor->range (), {});
406+ return createCompound (TK_FLOAT, anchor->range (), {});
396407 }
408+
397409 TreeRef boolType (TreeRef anchor) {
398- return c (TK_BOOL, anchor->range (), {});
410+ return createCompound (TK_BOOL, anchor->range (), {});
399411 }
412+
400413 void checkDim (Ident dim) {
401414 insert (env, dim, dimType (dim), false );
402415 }
416+
403417 TreeRef checkTensorType (TreeRef type) {
404418 auto tt = TensorType (type);
405419 for (const auto & d : tt.dims ()) {
@@ -409,18 +423,21 @@ struct Sema {
409423 }
410424 return type;
411425 }
426+
412427 TreeRef checkParam (TreeRef param) {
413428 auto p = Param (param);
414429 TreeRef type_ = checkTensorType (p.type ());
415430 insert (env, p.ident (), type_, true );
416431 live_input_names.insert (p.ident ().name ());
417432 return param;
418433 }
434+
419435 TreeRef checkReturn (TreeRef ret) {
420436 auto r = Param (ret);
421437 TreeRef real_type = lookup (env, r.ident (), true );
422438 return ret;
423439 }
440+
424441 TreeRef checkList (TreeRef list, std::function<TreeRef(TreeRef)> fn) {
425442 TC_ASSERT (list, list->kind () == TK_LIST);
426443 TreeList r;
@@ -429,6 +446,7 @@ struct Sema {
429446 }
430447 return List::create (list->range (), std::move (r));
431448 }
449+
432450 TreeRef checkRangeConstraint (RangeConstraint rc) {
433451 // RCs are checked _before_ the rhs of the TC, so
434452 // it is possible the index is not in the environment yet
@@ -441,11 +459,13 @@ struct Sema {
441459 auto e = expectIntegral (checkExp (rc.end (), false ));
442460 return RangeConstraint::create (rc.range (), rc.ident (), s, e);
443461 }
462+
444463 TreeRef checkLet (Let l) {
445464 auto rhs = checkExp (l.rhs (), true );
446465 insert (let_env, l.name (), typeOfExpr (rhs), true );
447466 return Let::create (l.range (), l.name (), rhs);
448467 }
468+
449469 TreeRef checkWhereClause (TreeRef ref) {
450470 if (ref->kind () == TK_LET) {
451471 return checkLet (Let (ref));
@@ -456,6 +476,7 @@ struct Sema {
456476 return checkRangeConstraint (RangeConstraint (ref));
457477 }
458478 }
479+
459480 // Semantic checking for the statements/comprehensions in a TC Def.
460481 TreeRef checkStmt (TreeRef stmt_) {
461482 auto stmt = Comprehension (stmt_);
@@ -467,11 +488,13 @@ struct Sema {
467488 insert (index_env, index, typ, true );
468489 }
469490
470- // make dimension variables for each dimension of the output tensor
491+ // check that the input is not used for output - inputs are immutable
471492 std::string name = stmt.ident ().name ();
472493 if (inputParameters.count (name) > 0 ) {
473494 throw ErrorReport (stmt_) << " TC inputs are immutable" ;
474495 }
496+
497+ // make dimension variables for each dimension of the output tensor
475498 TreeList output_indices;
476499 int n = stmt.indices ().size ();
477500 for (int i = 0 ; i < n; ++i) {
@@ -578,6 +601,7 @@ struct Sema {
578601
579602 return result;
580603 }
604+
581605 static bool isUninitializedReductionOperation (TreeRef assignment) {
582606 switch (assignment->kind ()) {
583607 case TK_PLUS_EQ:
@@ -589,6 +613,7 @@ struct Sema {
589613 return false ;
590614 }
591615 }
616+
592617 bool isNotInplace (TreeRef assignment) {
593618 switch (assignment->kind ()) {
594619 case TK_PLUS_EQ_B:
@@ -600,6 +625,7 @@ struct Sema {
600625 return false ;
601626 }
602627 }
628+
603629 std::string dumpEnv () {
604630 std::stringstream ss;
605631 std::vector<std::pair<std::string, TreeRef>> elems (env.begin (), env.end ());
@@ -618,6 +644,7 @@ struct Sema {
618644
619645 private:
620646 using Env = std::unordered_map<std::string, TreeRef>;
647+
621648 void
622649 insert (Env& the_env, Ident ident, TreeRef value, bool must_be_undefined) {
623650 std::string name = ident.name ();
@@ -630,6 +657,7 @@ struct Sema {
630657 throw ErrorReport (ident) << name << " already defined" ;
631658 }
632659 }
660+
633661 TreeRef lookup (Ident ident, bool required) {
634662 TreeRef v = lookup (index_env, ident, false );
635663 if (!v)
@@ -638,6 +666,7 @@ struct Sema {
638666 v = lookup (env, ident, required);
639667 return v;
640668 }
669+
641670 TreeRef lookup (Env& the_env, Ident ident, bool required) {
642671 std::string name = ident.name ();
643672 auto it = the_env.find (name);
@@ -647,10 +676,12 @@ struct Sema {
647676 }
648677 return it == the_env.end () ? nullptr : it->second ;
649678 }
650- TreeRef c (int kind, const SourceRange& range, TreeList&& trees) {
679+
680+ TreeRef createCompound (int kind, const SourceRange& range, TreeList&& trees) {
651681 return Compound::create (kind, range, std::move (trees));
652682 }
653- TreeRef s (const std::string& s) {
683+
684+ TreeRef createString (const std::string& s) {
654685 return String::create (s);
655686 }
656687
0 commit comments