Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit fb45cf6

Browse files
committed
compiler API: take CompilerOptions as an optional argument
Make all tc::compile and related functions take an instance of CompilerOptions. The options are defaulted to a default-constructed instance of CompilerOptions that preserves the original behavior of the compilaton flow. Making the argument optional preserves the old way of calling tc::compile since it is a user-facing function.
1 parent 207ea4b commit fb45cf6

File tree

5 files changed

+34
-20
lines changed

5 files changed

+34
-20
lines changed

tc/aten/aten_compiler-inl.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
3232
const std::string& tc,
3333
const std::string& entryPoint,
3434
const std::vector<at::Tensor>& inputs,
35-
const typename Backend::MappingOptionsType& options) {
35+
const typename Backend::MappingOptionsType& options,
36+
const CompilerOptions& compilerOptions) {
3637
auto inputDLTensors = makeDLConstTensors(inputs);
3738
return tc::compile<Backend>(
38-
tc, entryPoint, extractRawPtrs(inputDLTensors), options);
39+
tc, entryPoint, extractRawPtrs(inputDLTensors), options, compilerOptions);
3940
}
4041

4142
template <typename Executor>

tc/aten/aten_compiler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
5858
const std::string& tc,
5959
const std::string& entryPoint,
6060
const std::vector<at::Tensor>& inputs,
61-
const typename Backend::MappingOptionsType& options);
61+
const typename Backend::MappingOptionsType& options,
62+
const CompilerOptions& compilerOptions = CompilerOptions());
6263

6364
/// Given an executor resulting from compiling a TC, run the TC and fill the
6465
/// outputs vector with the results. The output vector must have as many

tc/core/compiler-inl.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
3737
const std::string& entryPoint,
3838
const std::vector<const DLConstTensor*>& inputs,
3939
/* TODO: in the future also pass outputs for stride and alignment info */
40-
const typename Backend::MappingOptionsType& options) {
40+
const typename Backend::MappingOptionsType& options,
41+
const CompilerOptions& compilerOptions) {
4142
auto parsedTcs = detail::parse(tc);
4243
TC_CHECK_EQ(parsedTcs.count(entryPoint), 1u)
4344
<< "attempting to access undefined function " << entryPoint;
44-
return detail::compile<Backend>(parsedTcs[entryPoint], inputs, options);
45+
return detail::compile<Backend>(
46+
parsedTcs[entryPoint], inputs, options, compilerOptions);
4547
}
4648

4749
namespace detail {
@@ -50,13 +52,15 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
5052
lang::TreeRef tcDefinition,
5153
const std::vector<const DLConstTensor*>& inputs,
5254
/* TODO: in the future also pass outputs for stride and alignment info */
53-
const typename Backend::MappingOptionsType& options) {
55+
const typename Backend::MappingOptionsType& options,
56+
const CompilerOptions& compilerOptions) {
5457
using CompilationResultType = typename Backend::CompilationResultType;
5558

5659
auto inputsInfo = makeTensorInfoVector(inputs);
57-
auto outputsInfo = detail::inferOutputTensorInfo(tcDefinition, inputs);
60+
auto outputsInfo =
61+
detail::inferOutputTensorInfo(tcDefinition, inputs, compilerOptions);
5862
auto halideComponents = tc2halide::translate(
59-
isl::with_exceptions::globalIslCtx(), tcDefinition, CompilerOptions());
63+
isl::with_exceptions::globalIslCtx(), tcDefinition, compilerOptions);
6064
detail::checkInputsCompliant(halideComponents, inputs);
6165

6266
auto tcName = lang::Def(tcDefinition).name().name();

tc/core/compiler.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,19 @@
2424
#include "tc/core/halide_utils.h"
2525
#include "tc/core/tensor.h"
2626
#include "tc/lang/canonicalize.h"
27+
#include "tc/utils/compiler_options.h"
2728

2829
namespace tc {
2930
std::vector<TensorInfo> inferOutputTensorInfo(
3031
const std::string& tc,
3132
const std::string& entryPoint,
32-
const std::vector<const DLConstTensor*> inputs) {
33+
const std::vector<const DLConstTensor*> inputs,
34+
const CompilerOptions& compilerOptions) {
3335
auto parsedTcs = detail::parse(tc);
3436
TC_CHECK_EQ(parsedTcs.count(entryPoint), 1u)
3537
<< "attempting to access undefined function " << entryPoint;
36-
return tc::detail::inferOutputTensorInfo(parsedTcs[entryPoint], inputs);
38+
return tc::detail::inferOutputTensorInfo(
39+
parsedTcs[entryPoint], inputs, compilerOptions);
3740
}
3841

3942
namespace detail {
@@ -101,12 +104,11 @@ void checkInputsCompliant(
101104

102105
std::vector<TensorInfo> inferOutputTensorInfo(
103106
lang::TreeRef tcDefinition,
104-
const std::vector<const DLConstTensor*> inputs) {
107+
const std::vector<const DLConstTensor*> inputs,
108+
const CompilerOptions& compilerOptions) {
105109
return tc::inferOutputTensorInfo(
106110
tc2halide::translate(
107-
isl::with_exceptions::globalIslCtx(),
108-
tcDefinition,
109-
CompilerOptions()),
111+
isl::with_exceptions::globalIslCtx(), tcDefinition, compilerOptions),
110112
inputs);
111113
}
112114

tc/core/compiler.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "tc/core/mapping_options.h"
2323
#include "tc/core/tensor.h"
2424
#include "tc/lang/tree.h"
25+
#include "tc/utils/compiler_options.h"
2526

2627
/**
2728
* This provides a simple functional-style C++ API with multi-backend
@@ -62,8 +63,9 @@ namespace tc {
6263
/// "entryPoint", this function compiles a new TcExecutor for the specified
6364
/// Backend. For now, contiguous output sizes are inferred given input sizes.
6465
/// If you need another kernel for another entryPoint or other inputs or
65-
// other options then just compile another TcExecutor; because atm we fully
66-
/// JIT specialize on all sizes.
66+
/// other options then just compile another TcExecutor; because atm we fully
67+
/// JIT specialize on all sizes. General compilation options (warnings, debug
68+
/// info) are provided in "compilerOptions".
6769
/// \returns a new TcExecutor on which the run method can be called to run
6870
/// entryPoint
6971
template <typename Backend>
@@ -72,7 +74,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
7274
const std::string& entryPoint,
7375
const std::vector<const DLConstTensor*>& inputs,
7476
/* TODO: in the future also pass outputs for stride and alignment info */
75-
const typename Backend::MappingOptionsType& options);
77+
const typename Backend::MappingOptionsType& options,
78+
const CompilerOptions& compilerOptions = CompilerOptions());
7679

7780
/// Given a TC representation as a TC + TC function name entryPoint and a list
7881
/// of input tensors that match the definition in the TC function definition
@@ -85,7 +88,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
8588
std::vector<TensorInfo> inferOutputTensorInfo(
8689
const std::string& tc,
8790
const std::string& entryPoint,
88-
const std::vector<const DLConstTensor*> inputs);
91+
const std::vector<const DLConstTensor*> inputs,
92+
const CompilerOptions& compilerOptions = CompilerOptions());
8993

9094
namespace detail {
9195
/// Given a TC representation, this parses the TC functions into a map of
@@ -105,7 +109,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
105109
lang::TreeRef tcDefinition,
106110
const std::vector<const DLConstTensor*>& inputs,
107111
/* TODO: in the future also pass outputs for stride and alignment info */
108-
const typename Backend::MappingOptionsType& options);
112+
const typename Backend::MappingOptionsType& options,
113+
const CompilerOptions& compilerOptions = CompilerOptions());
109114

110115
/// Given a TC representation as a TreeRef and a list of input tensors that
111116
/// match the definition in the TC function definition (in positional order),
@@ -116,7 +121,8 @@ std::unique_ptr<typename Backend::ExecutorType> compile(
116121
/// performing output shape validation.
117122
std::vector<TensorInfo> inferOutputTensorInfo(
118123
lang::TreeRef tcDefinition,
119-
const std::vector<const DLConstTensor*> inputs);
124+
const std::vector<const DLConstTensor*> inputs,
125+
const CompilerOptions& compilerOptions = CompilerOptions());
120126

121127
} // namespace detail
122128
} // namespace tc

0 commit comments

Comments
 (0)