@@ -194,6 +194,27 @@ struct Sema {
194194 }
195195 return e;
196196 }
197+ void expectBool (TreeRef anchor, int token) {
198+ if (token != TK_BOOL) {
199+ throw ErrorReport (anchor)
200+ << " expected boolean but found " << kindToString (token);
201+ }
202+ }
203+ TreeRef expectBool (TreeRef exp) {
204+ expectBool (exp, typeOfExpr (exp)->kind ());
205+ return exp;
206+ }
207+ TreeRef lookupVarOrCreateIndex (Ident ident) {
208+ TreeRef type = lookup (ident, false );
209+ if (!type) {
210+ // variable exp is not defined, so a reduction variable is created
211+ // a reduction variable index i
212+ type = indexType (ident);
213+ insert (index_env, ident, type, true );
214+ reduction_variables.push_back (ident);
215+ }
216+ return type;
217+ }
197218 TreeRef checkExp (TreeRef exp, bool allow_access) {
198219 switch (exp->kind ()) {
199220 case TK_APPLY: {
@@ -205,6 +226,7 @@ struct Sema {
205226 throw ErrorReport (exp)
206227 << " tensor accesses cannot be used in this context" ;
207228 }
229+
208230 // also handle built-in functions log, exp, etc.
209231 auto ident = a.name ();
210232 if (builtin_functions.count (ident.name ()) > 0 ) {
@@ -239,14 +261,7 @@ struct Sema {
239261 } break ;
240262 case TK_IDENT: {
241263 auto ident = Ident (exp);
242- TreeRef type = lookup (ident, false );
243- if (!type) {
244- // variable exp is not defined, so a reduction variable is created
245- // a reduction variable index i
246- type = indexType (exp);
247- insert (index_env, ident, type, true );
248- reduction_variables.push_back (exp);
249- }
264+ auto type = lookupVarOrCreateIndex (ident);
250265 if (type->kind () == TK_TENSOR_TYPE) {
251266 auto tt = TensorType (type);
252267 if (tt.dims ().size () != 0 ) {
@@ -276,6 +291,35 @@ struct Sema {
276291 exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
277292 return withType (nexp, matchAllTypes (nexp));
278293 } break ;
294+ case TK_EQ:
295+ case TK_NE:
296+ case TK_GE:
297+ case TK_LE:
298+ case ' <' :
299+ case ' >' : {
300+ auto nexp =
301+ exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
302+ // make sure the types match but the return type
303+ // is always bool
304+ matchAllTypes (nexp);
305+ return withType (nexp, boolType (exp));
306+ } break ;
307+ case TK_AND:
308+ case TK_OR:
309+ case ' !' : {
310+ auto nexp =
311+ exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
312+ expectBool (exp, matchAllTypes (nexp)->kind ());
313+ return withType (nexp, boolType (exp));
314+ } break ;
315+ case ' ?' : {
316+ auto nexp =
317+ exp->map ([&](TreeRef c) { return checkExp (c, allow_access); });
318+ expectBool (nexp->tree (0 ));
319+ auto rtype =
320+ match_types (typeOfExpr (nexp->tree (1 )), typeOfExpr (nexp->tree (2 )));
321+ return withType (nexp, rtype);
322+ }
279323 case TK_CONST: {
280324 auto c = Const (exp);
281325 return withType (exp, c.type ());
@@ -322,7 +366,10 @@ struct Sema {
322366 TreeRef floatType (TreeRef anchor) {
323367 return c (TK_FLOAT, anchor->range (), {});
324368 }
325- void checkDim (const Ident& dim) {
369+ TreeRef boolType (TreeRef anchor) {
370+ return c (TK_BOOL, anchor->range (), {});
371+ }
372+ void checkDim (Ident dim) {
326373 insert (env, dim, dimType (dim), false );
327374 }
328375 TreeRef checkTensorType (TreeRef type) {
@@ -354,6 +401,33 @@ struct Sema {
354401 }
355402 return List::create (list->range (), std::move (r));
356403 }
404+ TreeRef checkRangeConstraint (RangeConstraint rc) {
405+ // RCs are checked _before_ the rhs of the TC, so
406+ // it is possible the index is not in the environment yet
407+ // calling lookupOrCreate ensures it exists
408+ lookupVarOrCreateIndex (rc.ident ());
409+ // calling looking directly in the index_env ensures that
410+ // we are actually constraining an index and not some other variable
411+ lookup (index_env, rc.ident (), true );
412+ auto s = expectIntegral (checkExp (rc.start (), false ));
413+ auto e = expectIntegral (checkExp (rc.end (), false ));
414+ return RangeConstraint::create (rc.range (), rc.ident (), s, e);
415+ }
416+ TreeRef checkLet (Let l) {
417+ auto rhs = checkExp (l.rhs (), true );
418+ insert (let_env, l.name (), typeOfExpr (rhs), true );
419+ return Let::create (l.range (), l.name (), rhs);
420+ }
421+ TreeRef checkWhereClause (TreeRef ref) {
422+ if (ref->kind () == TK_LET) {
423+ return checkLet (Let (ref));
424+ } else if (ref->kind () == TK_EXISTS) {
425+ auto exp = checkExp (Exists (ref).exp (), true );
426+ return Exists::create (ref->range (), exp);
427+ } else {
428+ return checkRangeConstraint (RangeConstraint (ref));
429+ }
430+ }
357431 TreeRef checkStmt (TreeRef stmt_) {
358432 auto stmt = Comprehension (stmt_);
359433
@@ -374,6 +448,11 @@ struct Sema {
374448 output_indices.push_back (new_var);
375449 }
376450
451+ // where clauses are checked _before_ the rhs because they
452+ // introduce let bindings that are in scope for the rhs
453+ auto where_clauses_ = stmt.whereClauses ().map (
454+ [&](TreeRef rc) { return checkWhereClause (rc); });
455+
377456 TreeRef rhs_ = checkExp (stmt.rhs (), true );
378457 TreeRef scalar_type = typeOfExpr (rhs_);
379458
@@ -408,20 +487,11 @@ struct Sema {
408487 // if we redefined an input, it is no longer valid for range expressions
409488 live_input_names.erase (stmt.ident ().name ());
410489
411- auto range_constraints =
412- stmt.rangeConstraints ().map ([&](const RangeConstraint& rc) {
413- lookup (index_env, rc.ident (), true );
414- auto s = expectIntegral (checkExp (rc.start (), false ));
415- auto e = expectIntegral (checkExp (rc.end (), false ));
416- return RangeConstraint::create (rc.range (), rc.ident (), s, e);
417- });
418-
419- auto equivalent_statement_ =
420- stmt.equivalent ().map ([&](const Equivalent& eq) {
421- auto indices_ = eq.accesses ().map (
422- [&](TreeRef index) { return checkExp (index, true ); });
423- return Equivalent::create (eq.range (), eq.name (), indices_);
424- });
490+ auto equivalent_statement_ = stmt.equivalent ().map ([&](Equivalent eq) {
491+ auto indices_ = eq.accesses ().map (
492+ [&](TreeRef index) { return checkExp (index, true ); });
493+ return Equivalent::create (eq.range (), eq.name (), indices_);
494+ });
425495
426496 TreeRef assignment = stmt.assignment ();
427497 // For semantic consistency we allow overwriting reductions like +=!
@@ -446,13 +516,16 @@ struct Sema {
446516 stmt.indices (),
447517 stmt.assignment (),
448518 rhs_,
449- range_constraints ,
519+ where_clauses_ ,
450520 equivalent_statement_,
451521 reduction_variable_list);
522+ // clear the per-statement environments to get ready for the next statement
452523 index_env.clear ();
524+ let_env.clear ();
525+
453526 return result;
454527 }
455- bool isNotInplace (const TreeRef& assignment) {
528+ bool isNotInplace (TreeRef assignment) {
456529 switch (assignment->kind ()) {
457530 case TK_PLUS_EQ_B:
458531 case TK_TIMES_EQ_B:
@@ -493,13 +566,15 @@ struct Sema {
493566 throw ErrorReport (ident) << name << " already defined" ;
494567 }
495568 }
496- TreeRef lookup (const Ident& ident, bool required) {
569+ TreeRef lookup (Ident ident, bool required) {
497570 TreeRef v = lookup (index_env, ident, false );
571+ if (!v)
572+ v = lookup (let_env, ident, false );
498573 if (!v)
499574 v = lookup (env, ident, required);
500575 return v;
501576 }
502- TreeRef lookup (Env& the_env, const Ident& ident, bool required) {
577+ TreeRef lookup (Env& the_env, Ident ident, bool required) {
503578 std::string name = ident.name ();
504579 auto it = the_env.find (name);
505580 if (required && it == the_env.end ()) {
@@ -517,6 +592,7 @@ struct Sema {
517592
518593 std::vector<TreeRef> reduction_variables; // per-statement
519594 Env index_env; // per-statement
595+ Env let_env; // per-statement, used for where i = <exp>
520596
521597 Env env; // name -> type
522598 Env annotated_output_types; // name -> type, for all annotated returns types
0 commit comments