@@ -28,13 +28,15 @@ __pd = None
2828__pyspark = None
2929__tf = None
3030__K = None
31+ __torch = None
3132__ipywidgets = None
3233
3334
3435def _check_imported():
35- global __np, __pd, __pyspark, __tf, __K, __ipywidgets
36+ global __np, __pd, __pyspark, __tf, __K, __torch, __ipywidgets
3637
37- if 'numpy' in sys.modules:
38+ if '
39+ ' in sys.modules:
3840 # don't really need the try
3941 import numpy as __np
4042
@@ -55,6 +57,9 @@ def _check_imported():
5557 except ImportError:
5658 __K = None
5759
60+ if 'torch' in sys.modules:
61+ import torch as __torch
62+
5863 if 'ipywidgets' in sys.modules:
5964 import ipywidgets as __ipywidgets
6065
@@ -66,6 +71,8 @@ def _jupyterlab_variableinspector_getsizeof(x):
6671 return "?"
6772 elif __tf and isinstance(x, __tf.Variable):
6873 return "?"
74+ elif __torch and isinstance(x, __torch.Tensor):
75+ return x.element_size() * x.nelement()
6976 elif __pd and type(x).__name__ == 'DataFrame':
7077 return x.memory_usage().sum()
7178 else:
@@ -88,6 +95,9 @@ def _jupyterlab_variableinspector_getshapeof(x):
8895 if __tf and isinstance(x, __tf.Tensor):
8996 shape = " x ".join([str(int(i)) for i in x.shape])
9097 return "%s" % shape
98+ if __torch and isinstance(x, __torch.Tensor):
99+ shape = " x ".join([str(int(i)) for i in x.shape])
100+ return "%s" % shape
91101 if isinstance(x, list):
92102 return "%s" % len(x)
93103 if isinstance(x, dict):
@@ -129,6 +139,8 @@ def _jupyterlab_variableinspector_is_matrix(x):
129139 return True
130140 if __tf and isinstance(x, __tf.Tensor) and len(x.shape) <= 2:
131141 return True
142+ if __torch and isinstance(x, __torch.Tensor) and len(x.shape) <= 2:
143+ return True
132144 if isinstance(x, list):
133145 return True
134146 return False
@@ -153,7 +165,7 @@ def _jupyterlab_variableinspector_dict_list():
153165 return True
154166 if str(obj)[0] == "<":
155167 return False
156- if v in ['__np', '__pd', '__pyspark', '__tf', '__K', '__ipywidgets']:
168+ if v in ['__np', '__pd', '__pyspark', '__tf', '__K', '__torch', ' __ipywidgets']:
157169 return obj is not None
158170 if str(obj).startswith("_Feature"):
159171 # removes tf/keras objects
@@ -199,6 +211,9 @@ def _jupyterlab_variableinspector_getmatrixcontent(x, max_rows=10000):
199211 elif __tf and (isinstance(x, __tf.Variable) or isinstance(x, __tf.Tensor)):
200212 df = __K.get_value(x)
201213 return _jupyterlab_variableinspector_getmatrixcontent(df)
214+ elif __torch and __pd and isinstance(x, torch.Tensor):
215+ df = x.cpu().numpy()
216+ return _jupyterlab_variableinspector_getmatrixcontent(df)
202217 elif isinstance(x, list):
203218 s = __pd.Series(x)
204219 return _jupyterlab_variableinspector_getmatrixcontent(s)
0 commit comments