@@ -218,7 +218,7 @@ def propagate_data_to_dependencies(
218218 "OneOfDict" ,
219219 "LoadImage" ,
220220 "SampleToMasks" , # TODO ***MG***
221- "AsType" , # TODO ***MG***
221+ "AsType" ,
222222 "ChannelFirst2d" ,
223223 "Upscale" , # TODO ***AL***
224224 "NonOverlapping" , # TODO ***AL***
@@ -7785,9 +7785,9 @@ def _process_and_get(
77857785class AsType (Feature ):
77867786 """Convert the data type of images.
77877787
7788- This feature changes the data type (`dtype`) of input images to a specified
7789- type. The accepted types are the same as those used by NumPy arrays, such
7790- as ` float64`, `int32`, `uint16 `, `int16 `, `uint8`, and `int8` .
7788+ `Astype` changes the data type (`dtype`) of input images to a specified
7789+ type. The accepted types are standard NumPy or PyTorch data types (e.g.,
7790+ `" float64" `, `" int32" `, `"uint8" `, `"int8" `, and `"torch.float32"`) .
77917791
77927792 Parameters
77937793 ----------
@@ -7810,7 +7810,7 @@ class AsType(Feature):
78107810 >>>
78117811 >>> input_image = np.array([1.5, 2.5, 3.5])
78127812
7813- Apply an AsType feature to convert to `int32`:
7813+ Apply an AsType feature to convert to " `int32" `:
78147814 >>> astype_feature = dt.AsType(dtype="int32")
78157815 >>> output_image = astype_feature.get(input_image, dtype="int32")
78167816 >>> output_image
@@ -7827,8 +7827,7 @@ def __init__(
78277827 dtype : PropertyLike [str ] = "float64" ,
78287828 ** kwargs : Any ,
78297829 ):
7830- """
7831- Initialize the AsType feature.
7830+ """Initialize the AsType feature.
78327831
78337832 Parameters
78347833 ----------
@@ -7867,7 +7866,39 @@ def get(
78677866
78687867 """
78697868
7870- return image .astype (dtype )
7869+ if apc .is_torch_array (image ):
7870+ # Mapping from string to torch dtype
7871+ torch_dtypes = {
7872+ "float64" : torch .float64 ,
7873+ "double" : torch .float64 ,
7874+ "float32" : torch .float32 ,
7875+ "float" : torch .float32 ,
7876+ "float16" : torch .float16 ,
7877+ "half" : torch .float16 ,
7878+ "int64" : torch .int64 ,
7879+ "int32" : torch .int32 ,
7880+ "int16" : torch .int16 ,
7881+ "int8" : torch .int8 ,
7882+ "uint8" : torch .uint8 ,
7883+ "bool" : torch .bool ,
7884+ "complex64" : torch .complex64 ,
7885+ "complex128" : torch .complex128 ,
7886+ }
7887+
7888+ # Ensure `"torch.float32"` and `"float32"` are treated the same by
7889+ # removing the `torch.` prefix if present
7890+ dtype_str = str (dtype ).replace ("torch." , "" )
7891+ torch_dtype = torch_dtypes .get (dtype_str )
7892+
7893+ if torch_dtype is None :
7894+ raise ValueError (
7895+ f"Unsupported dtype for torch.Tensor: { dtype } "
7896+ )
7897+
7898+ return image .to (dtype = torch_dtype )
7899+
7900+ else :
7901+ return image .astype (dtype )
78717902
78727903
78737904class ChannelFirst2d (Feature ): # DEPRECATED
0 commit comments