Skip to content

Commit 2537feb

Browse files
committed
Update TensorFlow create_raw_prediction_request
1 parent 33c1c03 commit 2537feb

File tree

1 file changed

+6
-12
lines changed
  • pkg/workloads/cortex/tf_api

1 file changed

+6
-12
lines changed

pkg/workloads/cortex/tf_api/api.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -154,20 +154,14 @@ def create_raw_prediction_request(sample):
154154
prediction_request.model_spec.signature_name = signature_key
155155

156156
for column_name, value in sample.items():
157-
if util.is_list(value):
158-
shape = [len(value)]
159-
for dim in signature_def[signature_key]["inputs"][column_name]["tensorShape"]["dim"][
160-
1:
161-
]:
162-
shape.append(int(dim["size"]))
163-
else:
164-
shape = [1]
165-
value = [value]
157+
shape = []
158+
for dim in signature_def[signature_key]["inputs"][column_name]["tensorShape"]["dim"]:
159+
shape.append(int(dim["size"]))
160+
166161
sig_type = signature_def[signature_key]["inputs"][column_name]["dtype"]
162+
167163
try:
168-
tensor_proto = tf.make_tensor_proto(
169-
value, dtype=DTYPE_TO_TF_TYPE[sig_type], shape=shape
170-
)
164+
tensor_proto = tf.make_tensor_proto(value, dtype=DTYPE_TO_TF_TYPE[sig_type])
171165
prediction_request.inputs[column_name].CopyFrom(tensor_proto)
172166
except Exception as e:
173167
raise UserException(

0 commit comments

Comments
 (0)