Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit 8d4410c

Browse files
author
Sven Verdoolaege
committed
Scop::fixParameters: store parameter values in the specialized scop
This will allow the values to be reused in emitCudaKernel. The values are stored as int because that is the type used in emitCudaKernel. Any larger type would only get truncated there.
1 parent 1dd7487 commit 8d4410c

File tree

3 files changed

+17
-11
lines changed

3 files changed

+17
-11
lines changed

tc/core/cuda/cuda_tc_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ CudaCompilationResult CudaBackend::compileWithTcMapper(
9090
LOG_IF(INFO, FLAGS_debug_tc_mapper) << "Mapped schedule:" << std::endl
9191
<< *(mappedScop->schedule());
9292

93-
auto parameters = mappedScop->scop().getParameterValues(pvm);
93+
auto parameters = mappedScop->scop().getParameterValues();
9494
auto specializedName = specializeKernelName(tcName, parameters);
9595

9696
// This updates the launch bounds with the actual result from compilation

tc/core/polyhedral/scop.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -262,17 +262,17 @@ void Scop::promoteEverythingAt(std::vector<size_t> pos) {
262262
insertSyncsAroundCopies(tree);
263263
}
264264

265-
std::vector<long> Scop::getParameterValues(
266-
const std::unordered_map<std::string, int>& sizes) const {
265+
std::vector<long> Scop::getParameterValues() const {
267266
// Scop holds a vector of Variables.
268267
// Iterate over parameters in order, checking if the
269-
// context contains a parameter whose name corresponds to that
268+
// map of known parameter values contains a parameter
269+
// whose name corresponds to that
270270
// Variable and push respective parameter values.
271271
std::vector<long> paramValues;
272272
for (auto const& param : halide.params) {
273273
auto name = param.name();
274-
CHECK(sizes.count(name) == 1);
275-
paramValues.push_back(sizes.at(name));
274+
CHECK(parameterValues.count(name) == 1);
275+
paramValues.push_back(parameterValues.at(name));
276276
}
277277
return paramValues;
278278
}

tc/core/polyhedral/scop.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ struct Scop {
6363
// Clone a Scop
6464
static std::unique_ptr<Scop> makeScop(const Scop& scop) {
6565
auto res = std::unique_ptr<Scop>(new Scop());
66+
res->parameterValues = scop.parameterValues;
6667
res->globalParameterContext = scop.globalParameterContext;
6768
res->halide = scop.halide;
6869
res->reads = scop.reads;
@@ -137,18 +138,21 @@ struct Scop {
137138
}
138139

139140
// Fix the values of the specified parameters in the context
140-
// to the corresponding specified values.
141+
// to the corresponding specified values and keep track of them
142+
// in parameterValues.
141143
template <typename T>
142144
void fixParameters(const std::unordered_map<std::string, T>& sizes) {
145+
CHECK(parameterValues.size() == 0);
146+
for (const auto& kvp : sizes) {
147+
parameterValues.emplace(kvp.first, kvp.second);
148+
}
143149
intersectContext(makeContext(sizes));
144150
}
145151

146-
// Given a map between TC parametric tensor sizes, represented as strings,
147-
// and their numerical values, return the list of parameter values in the same
152+
// Return the list of parameter values in the same
148153
// order as codegen places them in the function signature, i.e. following the
149154
// order of scop.params.
150-
std::vector<long> getParameterValues(
151-
const std::unordered_map<std::string, int>& sizes) const;
155+
std::vector<long> getParameterValues() const;
152156

153157
isl::id nextGroupIdForTensor(isl::id tensorId) {
154158
auto ctx = domain().get_ctx();
@@ -490,6 +494,8 @@ struct Scop {
490494
// of the ScheduleTree "function".
491495
isl::union_set& domain();
492496
const isl::union_set domain() const;
497+
// The parameter values of a specialized Scop.
498+
std::unordered_map<std::string, int> parameterValues;
493499
// A globalParameterContext is kept. This represents (partial)
494500
// parameter specialization coming from the outside.
495501
// This may be further specialized before codegen.

0 commit comments

Comments
 (0)