Changes from 1 commit
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -166,7 +166,6 @@ set(PT_LIBS
if (${TRITON_PYTORCH_NVSHMEM})
set(PT_LIBS
${PT_LIBS}
"libtorch_nvshmem.so"
@chdhr-harshal (Author) commented on Nov 16, 2025:

Removing this temporarily because I don't have a GPU on my personal machine.

)
endif() # TRITON_PYTORCH_NVSHMEM

Binary file added model_repository/dict_model/1/model.pt
24 changes: 24 additions & 0 deletions model_repository/dict_model/config.pbtxt
@@ -0,0 +1,24 @@
name: "dict_model"
platform: "pytorch_libtorch"
max_batch_size: 8

input [
{
name: "INPUT__0"
data_type: TYPE_FP32
dims: [ 10 ]
}
]

output [
{
name: "logits"
data_type: TYPE_FP32
dims: [ 20 ]
},
{
name: "embeddings"
data_type: TYPE_FP32
dims: [ 5 ]
}
]
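A quick way to sanity-check this config (a minimal sketch, not part of the PR; it assumes the traced model.pt produced by test_model.py further below) is to confirm that the model's dictionary keys match the output names declared here:

# check_dict_keys.py -- hypothetical helper, not part of this PR
import torch

# Load the traced model and run a dummy forward pass
model = torch.jit.load("model_repository/dict_model/1/model.pt")
out = model(torch.randn(1, 10))

# The dict keys must match the config.pbtxt output names
# ("logits", "embeddings") for the backend's key lookup to succeed.
print(sorted(out.keys()))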
26 changes: 24 additions & 2 deletions src/model_instance_state.cc
@@ -51,7 +51,7 @@ namespace triton::backend::pytorch {
ModelInstanceState::ModelInstanceState(
ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance)
: BackendModelInstance(model_state, triton_model_instance),
model_state_(model_state), device_(torch::kCPU), is_dict_input_(false),
model_state_(model_state), device_(torch::kCPU), is_dict_input_(false), is_dict_output_(false),
device_cnt_(0)
{
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
@@ -345,6 +345,18 @@ ModelInstanceState::Execute(
list_output.elementType()->str() + "]");
}
output_tensors->push_back(model_outputs_);
} else if (model_outputs_.isGenericDict()) {
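// Model returned a Dict[str, Tensor]: flatten its values into
// output_tensors and record each key's position so ReadOutputTensors
// can look outputs up by their configured names. (c10::Dict preserves
// insertion order, so the indices are deterministic.)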
is_dict_output_ = true;
auto dict_output = model_outputs_.toGenericDict();
output_dict_key_to_index_.clear();

int index = 0;
for (auto it = dict_output.begin(); it != dict_output.end(); ++it) {
std::string key = it->key().toStringRef();
output_tensors->push_back(it->value());
output_dict_key_to_index_[key] = index;
index++;
}
} else {
throw std::invalid_argument(
"output must be of type Tensor, List[str] or Tuple containing one of "
@@ -872,7 +884,17 @@ ModelInstanceState::ReadOutputTensors(
// The serialized string buffer must be valid until output copies are done
std::vector<std::unique_ptr<std::string>> string_buffer;
for (auto& output : model_state_->ModelOutputs()) {
int op_index = output_index_map_[output.first];
// Use dict key mapping if available
int op_index;
if (is_dict_output_) {
auto it = output_dict_key_to_index_.find(output.first);
if (it == output_dict_key_to_index_.end()) {
continue; // Skip outputs not in dict
}
op_index = it->second;
} else {
op_index = output_index_map_[output.first];
}
auto name = output.first;
auto output_tensor_pair = output.second;

4 changes: 4 additions & 0 deletions src/model_instance_state.hh
@@ -73,6 +73,10 @@ class ModelInstanceState : public BackendModelInstance {
// Map from configuration name for an output to the index of
// that output in the model.
std::unordered_map<std::string, int> output_index_map_;

// If the model output is a dictionary of tensors, map each dict key
// to the index of that output.
std::unordered_map<std::string, int> output_dict_key_to_index_;
bool is_dict_output_;
std::unordered_map<std::string, TRITONSERVER_DataType> output_dtype_map_;

// If the input to the model is a dictionary of tensors.
25 changes: 25 additions & 0 deletions test_client.py
@@ -0,0 +1,25 @@
# test_client.py
import tritonclient.http as httpclient
import numpy as np

# Create client
client = httpclient.InferenceServerClient(url="localhost:8000")

# Prepare input
input_data = np.random.randn(5, 10).astype(np.float32)
inputs = [httpclient.InferInput("INPUT__0", input_data.shape, "FP32")]
inputs[0].set_data_from_numpy(input_data)

# Request outputs by dict key names
outputs = [
httpclient.InferRequestedOutput("logits"),
httpclient.InferRequestedOutput("embeddings")
]

# Infer
results = client.infer("dict_model", inputs, outputs=outputs)

# Check output names
print("Output names:", results.get_response())
print("Logits shape:", results.as_numpy("logits").shape)
print("Embeddings shape:", results.as_numpy("embeddings").shape)
33 changes: 33 additions & 0 deletions test_model.py
@@ -0,0 +1,33 @@
# test_model.py
import torch
import torch.nn as nn

class DictOutputModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 50)
self.fc2 = nn.Linear(50, 20)
self.fc3 = nn.Linear(50, 5)

def forward(self, x):
features = self.fc1(x)
logits = self.fc2(features)
embeddings = self.fc3(features)

# Return dictionary
return {
"logits": logits,
"embeddings": embeddings
}

# Create and save model
model = DictOutputModel()
model.eval()

# Trace with example input
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input, strict=False)

# Save
torch.jit.save(traced_model, "model.pt")
print("Model saved!")