@@ -21,28 +21,6 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) {
2121 ptr_.tensor = p;
2222}
2323
24- Var::IValueType Var::determineIValueType (torch::jit::IValue* p) {
25- if (p->isInt ()) {
26- return IValueType::kInt ;
27- } else if (p->isDouble ()) {
28- return IValueType::kDouble ;
29- } else if (p->isBool ()) {
30- return IValueType::kBool ;
31- } else if (p->isTensor ()) {
32- return IValueType::kTensor ;
33- } else if (p->isIntList ()) {
34- return IValueType::kIntList ;
35- } else if (p->isDoubleList ()) {
36- return IValueType::kDoubleList ;
37- } else if (p->isBoolList ()) {
38- return IValueType::kBoolList ;
39- } else if (p->isTensorList ()) {
40- return IValueType::kTensorList ;
41- } else if (p->isList ()) {
42- return IValueType::kITensorList ;
43- }
44- }
45-
4624Var::Var (const Var& a) {
4725 switch (a.type_ ) {
4826 case Type::kITensor :
@@ -52,7 +30,6 @@ Var::Var(const Var& a) {
5230 case Type::kIValue :
5331 ptr_.ivalue = a.ptr_ .ivalue ;
5432 type_ = Type::kIValue ;
55- ivalue_type_ = determineIValueType (ptr_.ivalue );
5633 break ;
5734 case Type::kNone :
5835 default :
@@ -70,7 +47,6 @@ Var& Var::operator=(const Var& a) {
7047 case Type::kIValue :
7148 ptr_.ivalue = a.ptr_ .ivalue ;
7249 type_ = Type::kIValue ;
73- ivalue_type_ = determineIValueType (ptr_.ivalue );
7450 break ;
7551 case Type::kNone :
7652 default :
@@ -83,7 +59,6 @@ Var& Var::operator=(const Var& a) {
8359Var& Var::operator =(torch::jit::IValue* in) {
8460 ptr_.ivalue = in;
8561 type_ = Type::kIValue ;
86- ivalue_type_ = determineIValueType (ptr_.ivalue );
8762 return (*this );
8863}
8964
@@ -97,10 +72,6 @@ Var::Type Var::type() const {
9772 return type_;
9873}
9974
100- Var::IValueType Var::ivalue_type () const {
101- return ivalue_type_;
102- }
103-
10475std::string Var::type_name () const {
10576 switch (type_) {
10677 case Type::kITensor :
@@ -175,40 +146,8 @@ bool Var::isITensor() const {
175146 }
176147}
177148
178- bool Var::isITensorList () const {
179- if (ivalue_type_ == IValueType::kITensorList ) {
180- return true ;
181- } else {
182- return false ;
183- }
184- }
185-
186- bool Var::isIntList () const {
187- if (ivalue_type_ == IValueType::kIntList ) {
188- return true ;
189- } else {
190- return false ;
191- }
192- }
193-
194- bool Var::isDoubleList () const {
195- if (ivalue_type_ == IValueType::kDoubleList ) {
196- return true ;
197- } else {
198- return false ;
199- }
200- }
201-
202- bool Var::isTensorList () const {
203- if (ivalue_type_ == IValueType::kTensorList ) {
204- return true ;
205- } else {
206- return false ;
207- }
208- }
209-
210- bool Var::isBoolList () const {
211- if (ivalue_type_ == IValueType::kBoolList ) {
149+ bool Var::isITensorList () {
150+ if (isList () && ptr_.ivalue ->isCustomClass ()) {
212151 return true ;
213152 } else {
214153 return false ;
@@ -218,10 +157,7 @@ bool Var::isBoolList() const {
218157std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList () {
219158 TORCHTRT_CHECK (
220159 isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
221- TORCHTRT_CHECK (
222- isITensorList (),
223- " Expected IValue to be an ITensorList, however the type is "
224- << static_cast <std::underlying_type<IValueType>::type>(ivalue_type_));
160+ TORCHTRT_CHECK (isITensorList (), " Expected IValue to be an ITensorList" );
225161 auto ivalue_list = ptr_.ivalue ->toList ();
226162 std::vector<nvinfer1::ITensor*> outputs;
227163 for (int i = 0 ; i < ivalue_list.size (); i++) {
0 commit comments