@@ -21,6 +21,28 @@ Var::Var(nvinfer1::ITensor* p) : type_(Type::kITensor) {
2121 ptr_.tensor = p;
2222}
2323
24+ Var::IValueType Var::determineIValueType (torch::jit::IValue* p) {
25+ if (p->isInt ()) {
26+ return IValueType::kInt ;
27+ } else if (p->isDouble ()) {
28+ return IValueType::kDouble ;
29+ } else if (p->isBool ()) {
30+ return IValueType::kBool ;
31+ } else if (p->isTensor ()) {
32+ return IValueType::kTensor ;
33+ } else if (p->isIntList ()) {
34+ return IValueType::kIntList ;
35+ } else if (p->isDoubleList ()) {
36+ return IValueType::kDoubleList ;
37+ } else if (p->isBoolList ()) {
38+ return IValueType::kBoolList ;
39+ } else if (p->isTensorList ()) {
40+ return IValueType::kTensorList ;
41+ } else if (p->isList ()) {
42+ return IValueType::kITensorList ;
43+ }
44+ }
45+
2446Var::Var (const Var& a) {
2547 switch (a.type_ ) {
2648 case Type::kITensor :
@@ -30,6 +52,7 @@ Var::Var(const Var& a) {
3052 case Type::kIValue :
3153 ptr_.ivalue = a.ptr_ .ivalue ;
3254 type_ = Type::kIValue ;
55+ ivalue_type_ = determineIValueType (ptr_.ivalue );
3356 break ;
3457 case Type::kNone :
3558 default :
@@ -47,6 +70,7 @@ Var& Var::operator=(const Var& a) {
4770 case Type::kIValue :
4871 ptr_.ivalue = a.ptr_ .ivalue ;
4972 type_ = Type::kIValue ;
73+ ivalue_type_ = determineIValueType (ptr_.ivalue );
5074 break ;
5175 case Type::kNone :
5276 default :
@@ -59,6 +83,7 @@ Var& Var::operator=(const Var& a) {
5983Var& Var::operator =(torch::jit::IValue* in) {
6084 ptr_.ivalue = in;
6185 type_ = Type::kIValue ;
86+ ivalue_type_ = determineIValueType (ptr_.ivalue );
6287 return (*this );
6388}
6489
@@ -72,6 +97,10 @@ Var::Type Var::type() const {
7297 return type_;
7398}
7499
100+ Var::IValueType Var::ivalue_type () const {
101+ return ivalue_type_;
102+ }
103+
75104std::string Var::type_name () const {
76105 switch (type_) {
77106 case Type::kITensor :
@@ -147,7 +176,39 @@ bool Var::isITensor() const {
147176}
148177
149178bool Var::isITensorList () const {
150- if (type_ == Type::kITensor ) {
179+ if (ivalue_type_ == IValueType::kITensorList ) {
180+ return true ;
181+ } else {
182+ return false ;
183+ }
184+ }
185+
186+ bool Var::isIntList () const {
187+ if (ivalue_type_ == IValueType::kIntList ) {
188+ return true ;
189+ } else {
190+ return false ;
191+ }
192+ }
193+
194+ bool Var::isDoubleList () const {
195+ if (ivalue_type_ == IValueType::kDoubleList ) {
196+ return true ;
197+ } else {
198+ return false ;
199+ }
200+ }
201+
202+ bool Var::isTensorList () const {
203+ if (ivalue_type_ == IValueType::kTensorList ) {
204+ return true ;
205+ } else {
206+ return false ;
207+ }
208+ }
209+
210+ bool Var::isBoolList () const {
211+ if (ivalue_type_ == IValueType::kBoolList ) {
151212 return true ;
152213 } else {
153214 return false ;
@@ -157,6 +218,8 @@ bool Var::isITensorList() const {
157218std::vector<nvinfer1::ITensor*> Var::unwrapToITensorList () {
158219 TORCHTRT_CHECK (
159220 isIValue (), " Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name ());
221+ LOG_DEBUG (" === Is INT list: " << ptr_.ivalue ->isIntList ());
222+ LOG_DEBUG (" === Is List: " << ptr_.ivalue ->isList ());
160223 auto ivalue_list = ptr_.ivalue ->toList ();
161224 std::vector<nvinfer1::ITensor*> outputs;
162225 for (int i = 0 ; i < ivalue_list.size (); i++) {
0 commit comments