@@ -183,15 +183,23 @@ void emitTensorView(
183183 stringstream& ss,
184184 Halide::OutputImageParam p,
185185 const map<string, Halide::Expr>& paramValues,
186- bool constInput = false ) {
186+ bool constInput = false ,
187+ const TensorInfo* tinfo = NULL ) {
187188 WS ws;
188189 stringstream ssViewType;
189190 for (int i = 1 ; i < p.dimensions (); ++i) { // Skip the outermost dimension
190191 Halide::Expr extent = p.parameter ().extent_constraint (i);
191192 extent = Halide::Internal::substitute (paramValues, extent);
192193 CHECK (extent.defined ())
193194 << " Undefined extent on input/output tensor. Forward bounds inference should have set these\n " ;
194- ssViewType << " [" << extent << " ]" ;
195+ // TODO: Handle non-unit stride in the innermost dimension
196+ if (tinfo && tinfo->strides .size () == p.dimensions () &&
197+ tinfo->strides [p.dimensions () - 1 ] == 1 &&
198+ tinfo->strides [i - 1 ] != (tinfo->shape [i] * tinfo->strides [i])) {
199+ ssViewType << " [" << tinfo->strides [i - 1 ] << " ]" ;
200+ } else {
201+ ssViewType << " [" << extent << " ]" ;
202+ }
195203 }
196204 ss << ws.tab ();
197205 ss << (constInput ? " const " : " " ) << p.type () << " (*" << p.name () << " )"
@@ -216,9 +224,12 @@ void emitTensorViews(
216224void emitTensorViews (
217225 stringstream& ss,
218226 const vector<Halide::ImageParam>& params,
219- const map<string, Halide::Expr>& paramValues) {
220- for (auto p : params) {
221- emitTensorView (ss, p, paramValues, true );
227+ const map<string, Halide::Expr>& paramValues,
228+ const std::vector<TensorInfo>& inputsInfo) {
229+ for (size_t i = 0 ; i < params.size (); ++i) {
230+ inputsInfo.size ()
231+ ? emitTensorView (ss, params[i], paramValues, true , &inputsInfo[i])
232+ : emitTensorView (ss, params[i], paramValues, true );
222233 }
223234}
224235
@@ -738,7 +749,8 @@ std::unordered_set<isl::id, isl::IslIdIslHash> gatherReadOnlySet(
738749
739750string emitCudaKernel (
740751 const std::string& specializedName,
741- const MappedScop& mscop) {
752+ const MappedScop& mscop,
753+ const std::vector<TensorInfo>& inputsInfo) {
742754 // Expecting a schedule with domain root and context first child.
743755 CHECK (mscop.schedule ()->elemAs <detail::ScheduleTreeElemDomain>());
744756 CHECK (
@@ -755,7 +767,7 @@ string emitCudaKernel(
755767 emitKernelSignature (ss, specializedName, scop);
756768 emitThreadIdInit (ss, mscop);
757769 emitTensorViews (ss, scop.halide .outputs , paramValues);
758- emitTensorViews (ss, scop.halide .inputs , paramValues);
770+ emitTensorViews (ss, scop.halide .inputs , paramValues, inputsInfo );
759771 emitTmpDecl (ss, scop);
760772 emitPromotedArrayViewsHalide (ss, scop);
761773 NodeInfoMapType nodeInfoMap;
0 commit comments