@@ -393,6 +393,89 @@ auto expand_registrations TORCHTRT_UNUSED =
                auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], collapse->getOutput(0));
                LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
 
+               return true;
+             }})
+        .pattern(
+            {"aten::meshgrid(Tensor[] tensors) -> (Tensor[])",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               // torch.meshgrid only supports 1D or 0D input tensors
+               auto arg_tensors = args[0].IValue()->toListRef();
+               std::vector<nvinfer1::ITensor*> tensors;
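+               // Gather the inputs as ITensors: constant torch tensors are frozen into the
+               // network, tensors produced by earlier layers are unwrapped from their containers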
+               for (auto t : arg_tensors) {
+                 if (t.isTensor()) {
+                   auto torch_tensor = t.toTensor();
+                   tensors.push_back(tensor_to_const(ctx, torch_tensor));
+                 } else {
+                   auto cont = t.toCustomClass<TensorContainer>();
+                   tensors.push_back(cont->tensor());
+                 }
+               }
+
+               // Build the common output shape: dimension idx takes its extent from input idx
+               nvinfer1::Dims output_dims;
+               output_dims.nbDims = tensors.size();
+               for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
+                 auto dims = tensors[idx]->getDimensions();
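+                 // a 0-D (scalar) input contributes an extent of 1, as in torch.meshgrid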
+                 output_dims.d[idx] = dims.nbDims == 0 ? 1 : dims.d[0];
+               }
+               std::vector<nvinfer1::ITensor*> out_tensors;
+               // Reshape each tensor to the output rank, then expand it (shuffle + slice)
+               for (size_t idx = 0UL; idx < tensors.size(); ++idx) {
+                 auto t = tensors[idx];
+                 auto dims = t->getDimensions();
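+                 // keep this tensor's extent on its own axis and set every other axis to 1,
+                 // e.g. the second of three inputs is reshaped to (1, N, 1)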
+                 nvinfer1::Dims reshape_dims;
+                 reshape_dims.nbDims = tensors.size();
+                 for (size_t reshape_idx = 0UL; reshape_idx < tensors.size(); ++reshape_idx) {
+                   if (reshape_idx == idx) {
+                     reshape_dims.d[reshape_idx] = dims.nbDims == 0 ? 1 : dims.d[0];
+                   } else {
+                     reshape_dims.d[reshape_idx] = 1;
+                   }
+                 }
+                 // Add a reshape layer before expanding dims
+                 auto reshape_layer = ctx->net->addShuffle(*t);
+                 reshape_layer->setReshapeDimensions(reshape_dims);
+                 std::stringstream reshape_layer_name;
+                 reshape_layer_name << util::node_info(n) << "_meshgrid_reshape_" << std::to_string(idx);
+                 reshape_layer->setName(reshape_layer_name.str().c_str());
+                 auto reshaped = reshape_layer->getOutput(0);
+                 LOG_DEBUG("Tensor " << idx << " reshaped to: " << reshaped->getDimensions() << " from " << dims);
+
+                 // Expand with a slice layer: stride 0 along a size-1 axis re-reads its
+                 // single element, reproducing torch.expand without copying data
+                 std::vector<int64_t> start_vec(output_dims.nbDims, 0);
+                 auto start_offset = util::toDims(c10::IntArrayRef(start_vec));
+
+                 std::vector<int64_t> strides_vec(output_dims.nbDims, 0);
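+                 // stride 1 walks the real axes, stride 0 pins the broadcast (size-1) axes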
+                 for (int64_t i = 0; i < output_dims.nbDims; i++) {
+                   strides_vec[i] = (reshaped->getDimensions().d[i] != 1);
+                 }
+
+                 auto strides = util::toDims(c10::IntArrayRef(strides_vec));
+
+                 auto slice_layer = ctx->net->addSlice(*reshaped, start_offset, output_dims, strides);
+                 std::stringstream slice_layer_name;
+                 slice_layer_name << util::node_info(n) << "_meshgrid_slice_" << std::to_string(idx);
+                 slice_layer->setName(slice_layer_name.str().c_str());
+                 auto slice_output = slice_layer->getOutput(0);
+                 LOG_DEBUG("Tensor " << idx << " expanded to: " << slice_output->getDimensions());
+                 out_tensors.push_back(slice_output);
+               }
+
+               // Pack the expanded tensors into a TorchScript list
+               c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
+               c10::TypePtr elementType = lt->getElementType();
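+               // the generic list is typed to match the node's declared output (Tensor[])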
+               auto list = c10::impl::GenericList(elementType);
+               list.reserve(out_tensors.size());
+
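+               // Wrap each ITensor in a TensorContainer IValue so that downstream
+               // converters can unwrap it, mirroring the input handling above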
+               for (auto t : out_tensors) {
+                 auto tensor_holder = TensorContainer();
+                 tensor_holder.hold_tensor(t);
+                 auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                 list.emplace_back(ival);
+               }
+
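+               // Register the packed list against the node output so later converters can find it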
+               auto output_list = std::move(torch::jit::IValue(list));
+               ctx->AssociateValueAndIValue(n->outputs()[0], output_list);
                return true;
              }});
 
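For reference, the expansion above leans on TensorRT's ISliceLayer semantics: a stride of 0 along a size-1 axis re-reads the same element, which lets the converter reproduce torch.expand without materializing copies. A minimal standalone sketch of that trick, assuming a pre-built nvinfer1::INetworkDefinition* network and an nvinfer1::ITensor* input with static shape [3, 1] (these names are illustrative, not part of this PR):

    // Expand [3, 1] -> [3, 4]: stride 0 on dim 1 repeats its lone element
    nvinfer1::Dims start{2, {0, 0}};
    nvinfer1::Dims size{2, {3, 4}};
    nvinfer1::Dims stride{2, {1, 0}};
    auto* slice = network->addSlice(*input, start, size, stride);
    // slice->getOutput(0) now has shape [3, 4], equivalent to tensor.expand({3, 4})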