@@ -69,11 +69,16 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
6969 }
7070}
7171
72+ bool valid_input_domain (std::vector<double > domain) {
73+ return (domain.size () == 2 ) && (domain[0 ] < domain[1 ]);
74+ }
75+
7276Input::Input (
7377 std::vector<int64_t > shape,
7478 at::ScalarType dtype,
7579 nvinfer1::TensorFormat format,
76- bool dtype_is_user_defined) {
80+ bool dtype_is_user_defined,
81+ std::vector<double > tensor_domain) {
7782 if (shape.size () > 5 ) {
7883 LOG_WARNING (" Verify that this dim size is accepted" );
7984 }
@@ -93,6 +98,11 @@ Input::Input(
9398 << " ), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported" );
9499 this ->format = format;
95100 this ->dtype_is_user_defined = dtype_is_user_defined;
101+
102+ TORCHTRT_CHECK (
103+ valid_input_domain (tensor_domain),
104+ " Unsupported tensor domain: [" << tensor_domain[0 ] << " , " << tensor_domain[1 ] << " )" );
105+ this ->tensor_domain = tensor_domain;
96106}
97107
98108Input::Input (
@@ -101,7 +111,8 @@ Input::Input(
101111 std::vector<int64_t > max_shape,
102112 at::ScalarType dtype,
103113 nvinfer1::TensorFormat format,
104- bool dtype_is_user_defined) {
114+ bool dtype_is_user_defined,
115+ std::vector<double > tensor_domain) {
105116 if (min_shape.size () > 5 || opt_shape.size () > 5 || max_shape.size () > 5 ) {
106117 LOG_WARNING (" Verify that this dim size is accepted" );
107118 }
@@ -146,6 +157,10 @@ Input::Input(
146157 << " ), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported" );
147158 this ->format = format;
148159 this ->dtype_is_user_defined = dtype_is_user_defined;
160+ TORCHTRT_CHECK (
161+ valid_input_domain (tensor_domain),
162+ " Unsupported tensor domain: [" << tensor_domain[0 ] << " , " << tensor_domain[1 ] << " )" );
163+ this ->tensor_domain = tensor_domain;
149164}
150165
151166std::ostream& operator <<(std::ostream& os, const Input& input) {
0 commit comments