Skip to content

Commit 12c091c

Browse files
Mg/features/astype (#402)
* adding torch compatibility * added unittesting for torch * minor change * update astype * update astype
1 parent 2a89062 commit 12c091c

File tree

2 files changed

+80
-12
lines changed

2 files changed

+80
-12
lines changed

deeptrack/features.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def propagate_data_to_dependencies(
218218
"OneOfDict",
219219
"LoadImage",
220220
"SampleToMasks", # TODO ***MG***
221-
"AsType", # TODO ***MG***
221+
"AsType",
222222
"ChannelFirst2d",
223223
"Upscale", # TODO ***AL***
224224
"NonOverlapping", # TODO ***AL***
@@ -7785,9 +7785,9 @@ def _process_and_get(
77857785
class AsType(Feature):
77867786
"""Convert the data type of images.
77877787
7788-
This feature changes the data type (`dtype`) of input images to a specified
7789-
type. The accepted types are the same as those used by NumPy arrays, such
7790-
as `float64`, `int32`, `uint16`, `int16`, `uint8`, and `int8`.
7788+
`Astype` changes the data type (`dtype`) of input images to a specified
7789+
type. The accepted types are standard NumPy or PyTorch data types (e.g.,
7790+
`"float64"`, `"int32"`, `"uint8"`, `"int8"`, and `"torch.float32"`).
77917791
77927792
Parameters
77937793
----------
@@ -7810,7 +7810,7 @@ class AsType(Feature):
78107810
>>>
78117811
>>> input_image = np.array([1.5, 2.5, 3.5])
78127812
7813-
Apply an AsType feature to convert to `int32`:
7813+
Apply an AsType feature to convert to "`int32"`:
78147814
>>> astype_feature = dt.AsType(dtype="int32")
78157815
>>> output_image = astype_feature.get(input_image, dtype="int32")
78167816
>>> output_image
@@ -7827,8 +7827,7 @@ def __init__(
78277827
dtype: PropertyLike[str] = "float64",
78287828
**kwargs: Any,
78297829
):
7830-
"""
7831-
Initialize the AsType feature.
7830+
"""Initialize the AsType feature.
78327831
78337832
Parameters
78347833
----------
@@ -7867,7 +7866,39 @@ def get(
78677866
78687867
"""
78697868

7870-
return image.astype(dtype)
7869+
if apc.is_torch_array(image):
7870+
# Mapping from string to torch dtype
7871+
torch_dtypes = {
7872+
"float64": torch.float64,
7873+
"double": torch.float64,
7874+
"float32": torch.float32,
7875+
"float": torch.float32,
7876+
"float16": torch.float16,
7877+
"half": torch.float16,
7878+
"int64": torch.int64,
7879+
"int32": torch.int32,
7880+
"int16": torch.int16,
7881+
"int8": torch.int8,
7882+
"uint8": torch.uint8,
7883+
"bool": torch.bool,
7884+
"complex64": torch.complex64,
7885+
"complex128": torch.complex128,
7886+
}
7887+
7888+
# Ensure `"torch.float32"` and `"float32"` are treated the same by
7889+
# removing the `torch.` prefix if present
7890+
dtype_str = str(dtype).replace("torch.", "")
7891+
torch_dtype = torch_dtypes.get(dtype_str)
7892+
7893+
if torch_dtype is None:
7894+
raise ValueError(
7895+
f"Unsupported dtype for torch.Tensor: {dtype}"
7896+
)
7897+
7898+
return image.to(dtype=torch_dtype)
7899+
7900+
else:
7901+
return image.astype(dtype)
78717902

78727903

78737904
class ChannelFirst2d(Feature): # DEPRECATED

deeptrack/tests/test_features.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2003,11 +2003,48 @@ def test_AsType(self):
20032003
np.all(output_image == np.array([1, 2, 3], dtype=dtype))
20042004
)
20052005

2006-
# Test for Image.
2007-
#TODO
2006+
### Test with PyTorch tensor (if available)
2007+
if TORCH_AVAILABLE:
2008+
input_image_torch = torch.tensor([1.5, 2.5, 3.5])
2009+
2010+
data_types_torch = [
2011+
"float64",
2012+
"int32",
2013+
"int16",
2014+
"uint8",
2015+
"int8",
2016+
"torch.float64",
2017+
"torch.int32",
2018+
]
20082019

2009-
# Test for PyTorch tensors.
2010-
#TODO
2020+
torch_dtypes_map = {
2021+
"float64": torch.float64,
2022+
"int32": torch.int32,
2023+
"int16": torch.int16,
2024+
"uint8": torch.uint8,
2025+
"int8": torch.int8,
2026+
"torch.float64": torch.float64,
2027+
"torch.int32": torch.int32,
2028+
}
2029+
2030+
for dtype in data_types_torch:
2031+
astype_feature = features.AsType(dtype=dtype)
2032+
output_image = astype_feature.get(
2033+
input_image_torch, dtype=dtype
2034+
)
2035+
expected_dtype = torch_dtypes_map[dtype]
2036+
self.assertEqual(output_image.dtype, expected_dtype)
2037+
2038+
# Additional check for specific behavior of integers.
2039+
if expected_dtype in [
2040+
torch.int8,
2041+
torch.int16,
2042+
torch.int32,
2043+
torch.uint8,
2044+
]:
2045+
# Verify that fractional parts are truncated
2046+
expected = torch.tensor([1, 2, 3], dtype=expected_dtype)
2047+
self.assertTrue(torch.equal(output_image, expected))
20112048

20122049

20132050
def test_ChannelFirst2d(self):

0 commit comments

Comments
 (0)