Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 207ea4b

Browse files
committed
tc2halide: put throwWarnings in CompilerOptions and use the latter
Most functions in the tc2halide namespace take a boolean argument that controls whether the conversion warnings should be treated as exceptions. Move this flag to CompilerOptions and pass these options instead. Propagate the options to Scop::makeScop calls used only in tests.
1 parent b13de92 commit 207ea4b

File tree

13 files changed

+71
-37
lines changed

13 files changed

+71
-37
lines changed

tc/aten/aten_compiler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tc/aten/aten.h"
2323
#include "tc/core/tensor.h"
2424
#include "tc/core/utils/time.h"
25+
#include "tc/utils/compiler_options.h"
2526

2627
namespace tc {
2728
namespace aten {

tc/core/compiler-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
5555

5656
auto inputsInfo = makeTensorInfoVector(inputs);
5757
auto outputsInfo = detail::inferOutputTensorInfo(tcDefinition, inputs);
58-
auto halideComponents =
59-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tcDefinition);
58+
auto halideComponents = tc2halide::translate(
59+
isl::with_exceptions::globalIslCtx(), tcDefinition, CompilerOptions());
6060
detail::checkInputsCompliant(halideComponents, inputs);
6161

6262
auto tcName = lang::Def(tcDefinition).name().name();

tc/core/compiler.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ std::vector<TensorInfo> inferOutputTensorInfo(
103103
lang::TreeRef tcDefinition,
104104
const std::vector<const DLConstTensor*> inputs) {
105105
return tc::inferOutputTensorInfo(
106-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tcDefinition),
106+
tc2halide::translate(
107+
isl::with_exceptions::globalIslCtx(),
108+
tcDefinition,
109+
CompilerOptions()),
107110
inputs);
108111
}
109112

tc/core/polyhedral/scop.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "tc/core/polyhedral/schedule_utils.h"
3535
#include "tc/core/scope_guard.h"
3636
#include "tc/core/tc2halide.h"
37+
#include "tc/utils/compiler_options.h"
3738

3839
using namespace std;
3940

@@ -69,12 +70,18 @@ ScopUPtr Scop::makeScop(
6970
return scop;
7071
}
7172

72-
ScopUPtr Scop::makeScop(isl::ctx ctx, const string& tc) {
73-
return makeScop(ctx, tc2halide::translate(ctx, tc));
73+
ScopUPtr Scop::makeScop(
74+
isl::ctx ctx,
75+
const string& tc,
76+
const CompilerOptions& compilerOptions) {
77+
return makeScop(ctx, tc2halide::translate(ctx, tc, compilerOptions));
7478
}
7579

76-
ScopUPtr Scop::makeScop(isl::ctx ctx, const lang::TreeRef& treeRef) {
77-
return makeScop(ctx, tc2halide::translate(ctx, treeRef));
80+
ScopUPtr Scop::makeScop(
81+
isl::ctx ctx,
82+
const lang::TreeRef& treeRef,
83+
const CompilerOptions& compilerOptions) {
84+
return makeScop(ctx, tc2halide::translate(ctx, treeRef, compilerOptions));
7885
}
7986

8087
isl::union_set& Scop::domainRef() {

tc/core/polyhedral/scop.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "tc/core/tc2halide.h"
3333
#include "tc/core/tensor.h"
3434
#include "tc/external/isl.h"
35+
#include "tc/utils/compiler_options.h"
3536

3637
namespace tc {
3738
namespace polyhedral {
@@ -56,11 +57,15 @@ struct Scop {
5657
// Halide IR is constructed and made a member by setting halideComponents.
5758
// These operations are grouped and scheduled in a halide::Stmt which becomes
5859
// the unit from which the scop is constructed.
59-
static std::unique_ptr<Scop> makeScop(isl::ctx ctx, const std::string& tc);
60+
static std::unique_ptr<Scop> makeScop(
61+
isl::ctx ctx,
62+
const std::string& tc,
63+
const CompilerOptions& compilerOptions = CompilerOptions());
6064

6165
static std::unique_ptr<Scop> makeScop(
6266
isl::ctx ctx,
63-
const lang::TreeRef& treeRef);
67+
const lang::TreeRef& treeRef,
68+
const CompilerOptions& compilerOptions = CompilerOptions());
6469

6570
// Clone a Scop
6671
static std::unique_ptr<Scop> makeScop(const Scop& scop) {

tc/core/tc2halide.cc

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "tc/core/tc2halide.h"
2121
#include "tc/lang/parser.h"
2222
#include "tc/lang/sema.h"
23+
#include "tc/utils/compiler_options.h"
2324

2425
namespace tc2halide {
2526

@@ -270,7 +271,7 @@ void forwardBoundsInference(
270271
const std::vector<Expr>& exprs,
271272
const FunctionBounds& bounds,
272273
const lang::TreeRef& comprehension,
273-
bool throwWarnings,
274+
const tc::CompilerOptions& compilerOptions,
274275
Scope<Interval>* solution) {
275276
class CreateConstraints : public IRVisitor {
276277
using IRVisitor::visit;
@@ -488,10 +489,10 @@ void forwardBoundsInference(
488489
lang::ErrorReport err(comprehension);
489490
err << "Required precondition will not be checked at runtime: "
490491
<< remaining;
491-
if (throwWarnings) {
492+
if (compilerOptions.throwWarnings) {
492493
throw err;
493494
} else {
494-
warn(err);
495+
warn(err, compilerOptions);
495496
}
496497
}
497498
}
@@ -509,7 +510,7 @@ Expr reductionUpdate(Expr e) {
509510
void translateComprehension(
510511
const lang::Comprehension& comprehension,
511512
const map<string, Parameter>& params,
512-
bool throwWarnings,
513+
const tc::CompilerOptions& compilerOptions,
513514
map<string, Function>* funcs,
514515
FunctionBounds* bounds) {
515516
Function f;
@@ -670,7 +671,7 @@ void translateComprehension(
670671
// Infer the rest
671672
all_exprs.push_back(rhs);
672673
forwardBoundsInference(
673-
all_exprs, *bounds, comprehension, throwWarnings, &solution);
674+
all_exprs, *bounds, comprehension, compilerOptions, &solution);
674675

675676
// TODO: What if subsequent updates have incompatible bounds
676677
// (e.g. an in-place stencil)?. The .bound directive will use the
@@ -754,7 +755,9 @@ void translateComprehension(
754755
}
755756

756757
// Translate a semantically checked TC def to HalideComponents struct.
757-
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
758+
HalideComponents translateDef(
759+
const lang::Def& def,
760+
const tc::CompilerOptions& compilerOptions) {
758761
map<string, Function> funcs;
759762
HalideComponents components;
760763
components.def = def;
@@ -765,7 +768,7 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
765768
}
766769
for (auto c : def.statements()) {
767770
translateComprehension(
768-
c, components.params, throwWarnings, &funcs, &bounds);
771+
c, components.params, compilerOptions, &funcs, &bounds);
769772
}
770773
vector<Function> outputs;
771774
for (auto p : def.returns()) {
@@ -906,19 +909,23 @@ HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
906909
}
907910
} // namespace
908911

909-
HalideComponents
910-
translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
912+
HalideComponents translate(
913+
isl::ctx ctx,
914+
const lang::TreeRef& treeRef,
915+
const tc::CompilerOptions& compilerOptions = tc::CompilerOptions()) {
911916
LOG_IF(INFO, tc::FLAGS_debug_halide) << treeRef;
912917
return translateDef(
913-
lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings);
918+
lang::Def(lang::Sema().checkFunction(treeRef)), compilerOptions);
914919
}
915920

916921
// NOTE: there is no guarantee here that the tc string has only one def. It
917922
// could have many defs. Only first def will be converted in that case.
918-
HalideComponents
919-
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) {
923+
HalideComponents translate(
924+
isl::ctx ctx,
925+
const std::string& tc,
926+
const tc::CompilerOptions& compilerOptions = tc::CompilerOptions()) {
920927
LOG_IF(INFO, tc::FLAGS_debug_halide) << tc;
921-
return translate(ctx, lang::Parser(tc).parseFunction(), throwWarnings);
928+
return translate(ctx, lang::Parser(tc).parseFunction(), compilerOptions);
922929
}
923930

924931
} // namespace tc2halide

tc/core/tc2halide.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "tc/external/isl.h"
2121
#include "tc/lang/tree.h"
2222
#include "tc/lang/tree_views.h"
23+
#include "tc/utils/compiler_options.h"
2324

2425
namespace tc2halide {
2526

@@ -44,15 +45,19 @@ struct HalideComponents {
4445
Halide::Internal::Call::ConstString kReductionUpdate = "ReductionUpdate";
4546

4647
// Translate a TC parse tree into equivalent Halide imperative IR with
47-
// a naive schedule.
48+
// a naive schedule. Additional options, such as how to treat warnings, are
49+
// passed in as "compilerOptions".
4850
HalideComponents translate(
4951
isl::ctx ctx,
5052
const lang::TreeRef& treeRef,
51-
bool throwWarnings = false);
53+
const tc::CompilerOptions& compilerOptions);
5254

5355
// Translate TC source into equivalent Halide imperative IR with a
54-
// naive schedule.
55-
HalideComponents
56-
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings = false);
56+
// naive schedule. Additional options, such as how to treat warnings, are
57+
// passed in as "compilerOptions".
58+
HalideComponents translate(
59+
isl::ctx ctx,
60+
const std::string& tc,
61+
const tc::CompilerOptions& compilerOptions);
5762

5863
} // namespace tc2halide

tc/utils/compiler_options.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class CompilerOptions {
3232

3333
/// Print syntactic warnings.
3434
bool emitWarnings = true;
35+
/// Treat warnings in TC to Halide conversion as exceptions.
36+
bool throwWarnings = false;
3537
};
3638

3739
} // namespace tc

test/test_core.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "tc/lang/error_report.h"
3333
#include "tc/library/copy.h"
3434
#include "tc/library/matmul.h"
35+
#include "tc/utils/compiler_options.h"
3536

3637
using namespace std;
3738

@@ -60,8 +61,8 @@ dtype {
6061
struct GenericHalideCoreTest : public ::testing::Test {
6162
void CheckC(const std::string& tc, const std::vector<std::string>& expected) {
6263
auto curPos = std::string::npos;
63-
auto halide =
64-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tc);
64+
auto halide = tc2halide::translate(
65+
isl::with_exceptions::globalIslCtx(), tc, CompilerOptions());
6566
auto res = tc::halideCodegenC(halide.stmt);
6667
for (const auto& e : expected) {
6768
auto newPos = res.find(e);
@@ -243,8 +244,8 @@ struct TC2Isl : public ::testing::Test {
243244
DLConstTensorUPtr in = makeDLConstTensor(ti);
244245

245246
// Must reuse the same ctx or memleaks ensue!
246-
tc2halide::HalideComponents comps =
247-
tc2halide::translate(isl::with_exceptions::globalIslCtx(), tc);
247+
tc2halide::HalideComponents comps = tc2halide::translate(
248+
isl::with_exceptions::globalIslCtx(), tc, CompilerOptions());
248249
auto scop =
249250
polyhedral::Scop::makeScop(isl::with_exceptions::globalIslCtx(), comps);
250251
polyhedral::detail::validateSchedule(scop->scheduleRoot());

test/test_cuda_mapper.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct PolyhedralMapperTest : public ::testing::Test {
5050
std::unique_ptr<Scop> Prepare(std::string tc) {
5151
auto ctx = isl::with_exceptions::globalIslCtx();
5252
// Build the SCoP corresponding to the Tc
53-
return Scop::makeScop(ctx, tc);
53+
return Scop::makeScop(ctx, tc, CompilerOptions());
5454
}
5555

5656
std::unique_ptr<Scop> PrepareAndJoinBands(std::string tc) {

0 commit comments

Comments
 (0)