From addf6fe4a91c54bfe871cca34274bdd3ac17d5f5 Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sat, 8 Jun 2024 15:03:27 +0800 Subject: [PATCH 1/2] Support int inputs for text models --- torch2trt/flattener.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch2trt/flattener.py b/torch2trt/flattener.py index 7fbf4070..d263aecd 100644 --- a/torch2trt/flattener.py +++ b/torch2trt/flattener.py @@ -3,7 +3,7 @@ def _default_condition(x): - return isinstance(x, torch.Tensor) and (x.dtype is torch.half or x.dtype is torch.float or x.dtype == torch.bool) + return isinstance(x, torch.Tensor) and (x.dtype is torch.half or x.dtype is torch.float or x.dtype == torch.bool or x.dtype == torch.int32 or x.dtype == torch.int64 or x.dtype == torch.long) def _make_schema_from_value(value, condition=_default_condition, size=0): @@ -90,4 +90,4 @@ def unflatten(self, flattened): result[child_key] = Flattener(child_schema, self.size).unflatten(flattened) return result else: - return None \ No newline at end of file + return None From d11692c4fbdffc419564eefb28f5bee8387b74cf Mon Sep 17 00:00:00 2001 From: Billy Cao Date: Sat, 8 Jun 2024 18:58:22 +0800 Subject: [PATCH 2/2] Fix utf char --- torch2trt/flattener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch2trt/flattener.py b/torch2trt/flattener.py index d263aecd..fca20595 100644 --- a/torch2trt/flattener.py +++ b/torch2trt/flattener.py @@ -3,7 +3,7 @@ def _default_condition(x): - return isinstance(x, torch.Tensor) and (x.dtype is torch.half or x.dtype is torch.float or x.dtype == torch.bool or x.dtype == torch.int32 or x.dtype == torch.int64 or x.dtype == torch.long) + return isinstance(x, torch.Tensor) and (x.dtype is torch.half or x.dtype is torch.float or x.dtype == torch.bool or x.dtype == torch.int32 or x.dtype == torch.int64 or x.dtype == torch.long) def _make_schema_from_value(value, condition=_default_condition, size=0):