77from keras .src import backend
88from keras .src .backend .mlx .core import convert_to_tensor
99from keras .src .backend .mlx .core import to_mlx_dtype
10+ from keras .src .backend .mlx .random import mlx_draw_seed
1011
1112
1213def rgb_to_grayscale (images , data_format = None ):
@@ -657,17 +658,55 @@ def _compute_weight_mat(
657658 )
658659
659660
660- def elastic_transform (
661- images ,
662- alpha = 20.0 ,
663- sigma = 5.0 ,
664- interpolation = "bilinear" ,
665- fill_mode = "reflect" ,
666- fill_value = 0.0 ,
667- seed = None ,
668- data_format = None ,
669- ):
670- raise NotImplementedError ("elastic_transform not yet implemented in mlx." )
def compute_homography_matrix(start_points, end_points):
    """Solve for the eight homography coefficients relating two point sets.

    Each (start, end) point pair contributes two linear equations in the
    eight unknown coefficients, so four pairs yield the square 8x8 system
    that is handed to the solver. The resulting coefficients map `end`
    coordinates onto `start` coordinates.

    Args:
        start_points: Array-like whose last axis holds `(x, y)`. Either
            `(N, 4, 2)` for a batch or `(4, 2)` for a single point set.
            Converted to float32.
        end_points: Array-like with the same shape as `start_points`.
            Converted to float32.

    Returns:
        The solver output with its trailing unit axis removed: shape
        `(N, 8)` for batched input, `(8,)` for a single point set.
    """
    start_points = convert_to_tensor(start_points, dtype=mx.float32)
    end_points = convert_to_tensor(end_points, dtype=mx.float32)

    start_x, start_y = start_points[..., 0], start_points[..., 1]
    end_x, end_y = end_points[..., 0], end_points[..., 1]

    zeros = mx.zeros_like(end_x)
    ones = mx.ones_like(end_x)

    # One equation per point for the x coordinate ...
    x_rows = mx.stack(
        [
            end_x,
            end_y,
            ones,
            zeros,
            zeros,
            zeros,
            -start_x * end_x,
            -start_x * end_y,
        ],
        axis=-1,
    )
    # ... and one per point for the y coordinate.
    y_rows = mx.stack(
        [
            zeros,
            zeros,
            zeros,
            end_x,
            end_y,
            ones,
            -start_y * end_x,
            -start_y * end_y,
        ],
        axis=-1,
    )

    # Stack the equations along the row axis. Using -2 instead of a fixed
    # axis keeps the batched case unchanged and also supports unbatched
    # (4, 2) inputs.
    coefficient_matrix = mx.concatenate([x_rows, y_rows], axis=-2)

    # Right-hand side ordered to match the rows above: all x targets, then
    # all y targets.
    target_vector = mx.expand_dims(
        mx.concatenate([start_x, start_y], axis=-1), axis=-1
    )

    # Solve coefficient_matrix @ homography = target_vector. The solve is
    # explicitly scheduled on the CPU stream.
    with mx.stream(mx.cpu):
        homography_matrix = mx.linalg.solve(coefficient_matrix, target_vector)

    return homography_matrix.squeeze(-1)
671710
672711
def perspective_transform(
    images,
    start_points,
    end_points,
    interpolation="bilinear",
    fill_value=0,
    data_format=None,
):
    """Warp images with the perspective defined by two sets of four points.

    A homography is computed per point set with
    `compute_homography_matrix`, applied to the output pixel grid, and each
    channel is resampled with `map_coordinates` using constant fill.

    Args:
        images: Rank-3 (single image) or rank-4 (batch) array-like.
        start_points: Point set(s) of shape `(4, 2)` or `(N, 4, 2)`.
        end_points: Point set(s) with the same shape as `start_points`.
        interpolation: Key of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_value: Value used for samples that fall outside the image.
        data_format: `"channels_last"`, `"channels_first"` or `None`
            (resolved by `backend.standardize_data_format`).

    Returns:
        The warped images in the same layout, rank and dtype as `images`.

    Raises:
        ValueError: On an unknown interpolation, an unsupported image rank,
            malformed or mismatched point shapes, or a number of point sets
            that is neither 1 nor the batch size.
    """
    # NOTE(review): the four leading parameters were elided in the reviewed
    # view; names are taken from the body, order and the `interpolation`
    # default should be confirmed against the full file.
    data_format = backend.standardize_data_format(data_format)
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected one of "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )

    # Convert before validating so array-likes without `.shape` / `.ndim`
    # (lists, tuples) are checked instead of raising AttributeError.
    images = convert_to_tensor(images)
    start_points = convert_to_tensor(start_points)
    end_points = convert_to_tensor(end_points)

    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )
    if tuple(start_points.shape[-2:]) != (4, 2) or start_points.ndim not in (
        2,
        3,
    ):
        raise ValueError(
            "Invalid start_points shape: expected (4,2) for a single image"
            f" or (N,4,2) for a batch. Received shape: {start_points.shape}"
        )
    if tuple(end_points.shape[-2:]) != (4, 2) or end_points.ndim not in (2, 3):
        raise ValueError(
            "Invalid end_points shape: expected (4,2) for a single image"
            f" or (N,4,2) for a batch. Received shape: {end_points.shape}"
        )
    if tuple(start_points.shape) != tuple(end_points.shape):
        raise ValueError(
            "start_points and end_points must have the same shape."
            f" Received start_points.shape={start_points.shape}, "
            f"end_points.shape={end_points.shape}"
        )

    need_squeeze = False
    if len(images.shape) == 3:
        images = mx.expand_dims(images, axis=0)
        need_squeeze = True

    if len(start_points.shape) == 2:
        start_points = mx.expand_dims(start_points, axis=0)
    if len(end_points.shape) == 2:
        end_points = mx.expand_dims(end_points, axis=0)

    # Work in channels-last internally; transposed back before returning.
    if data_format == "channels_first":
        images = mx.transpose(images, (0, 2, 3, 1))

    batch_size, height, width, channels = images.shape

    # Indexing past the end of an MLX array is not bounds-checked, so make
    # sure every image has a transform: either one shared point set or one
    # per image.
    num_transforms = start_points.shape[0]
    if num_transforms not in (1, batch_size):
        raise ValueError(
            "The number of point sets must be 1 or match the number of "
            f"images. Received {num_transforms} point sets for "
            f"{batch_size} images."
        )

    transforms = compute_homography_matrix(
        start_points.astype(mx.float32),
        end_points.astype(mx.float32),
    )

    # Flattened output pixel coordinates, row-major over (height, width).
    x, y = mx.meshgrid(mx.arange(width), mx.arange(height), indexing="xy")
    grid_x = x.flatten()
    grid_y = y.flatten()

    order = AFFINE_TRANSFORM_INTERPOLATIONS[interpolation]

    outputs = []
    for b in range(batch_size):
        transform = transforms[b if num_transforms > 1 else 0]

        # Apply the homography to the grid; the ninth coefficient is
        # fixed to 1.
        denom = transform[6] * grid_x + transform[7] * grid_y + 1.0
        x_in = (
            transform[0] * grid_x + transform[1] * grid_y + transform[2]
        ) / denom
        y_in = (
            transform[3] * grid_x + transform[4] * grid_y + transform[5]
        ) / denom

        # map_coordinates expects (row, column) order.
        coords = mx.stack([y_in, x_in], axis=0)

        planes = [
            map_coordinates(
                images[b, :, :, c],
                coords,
                order=order,
                fill_mode="constant",
                fill_value=fill_value,
            ).reshape(height, width)
            for c in range(channels)
        ]
        outputs.append(mx.stack(planes, axis=-1).astype(images.dtype))

    output = mx.stack(outputs, axis=0)

    if data_format == "channels_first":
        output = mx.transpose(output, (0, 3, 1, 2))
    if need_squeeze:
        output = mx.squeeze(output, axis=0)

    return output
819+
685820
def gaussian_blur(
    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None
):
    """Blur images with a normalised 2D Gaussian kernel, per channel.

    The kernel is the outer product of two normalised 1D Gaussians and is
    applied as a grouped convolution with one group per channel. The input
    is padded so the spatial size is preserved.

    Args:
        images: Rank-3 (single image) or rank-4 (batch) array-like.
        kernel_size: Pair of kernel extents, cast to `int`. As the kernel
            is built here, index 0 (with `sigma[0]`) runs along the width
            axis and index 1 (with `sigma[1]`) along the height axis.
        sigma: Pair of standard deviations, indexed like `kernel_size`.
        data_format: `"channels_last"`, `"channels_first"` or `None`
            (resolved by `backend.standardize_data_format`).

    Returns:
        The blurred images in the same layout and rank as `images`.

    Raises:
        ValueError: If `images` is not rank 3 or rank 4.
    """

    def _gaussian_kernel1d(size, std, dtype):
        # Taps centred on zero, normalised to sum to one.
        offsets = mx.arange(size, dtype=dtype) - (size - 1) / 2
        weights = mx.exp(-0.5 * (offsets / std) ** 2)
        return weights / mx.sum(weights)

    def _create_gaussian_kernel(kernel_size, sigma, dtype, num_channels):
        kernel1d_x = _gaussian_kernel1d(kernel_size[0], sigma[0], dtype)
        kernel1d_y = _gaussian_kernel1d(kernel_size[1], sigma[1], dtype)
        # Rows follow the y kernel, columns the x kernel, so the shape is
        # (kernel_size[1], kernel_size[0]).
        kernel2d = mx.outer(kernel1d_y, kernel1d_x)

        # mlx expects weights shaped (C_out, spatial..., C_in); for a
        # depthwise convolution with groups=C that is (C, H, W, 1). The
        # reshape must keep the (rows, cols) order of `kernel2d`, otherwise
        # non-square kernels get scrambled.
        kernel = kernel2d.reshape(1, kernel_size[1], kernel_size[0], 1)
        return mx.tile(kernel, (num_channels, 1, 1, 1))

    # Convert first so plain array-likes can be rank-checked.
    images = convert_to_tensor(images)
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )

    data_format = backend.standardize_data_format(data_format)
    sigma = convert_to_tensor(sigma)
    dtype = images.dtype

    need_squeeze = False
    if images.ndim == 3:
        images = images[mx.newaxis, ...]
        need_squeeze = True

    # Convolve in channels-last; transposed back before returning.
    if data_format == "channels_first":
        images = mx.transpose(images, (0, 2, 3, 1))

    num_channels = images.shape[-1]

    # mx.arange can only take integer input values
    kernel_size = tuple(int(k) for k in kernel_size)
    kernel = _create_gaussian_kernel(kernel_size, sigma, dtype, num_channels)

    # 'same' padding: k - 1 samples in total per axis, with the extra
    # sample on the high side for even kernel sizes. conv_general takes
    # explicit padding as (low-per-axis, high-per-axis).
    kernel_h, kernel_w = kernel_size[1], kernel_size[0]
    pad_low = [max(0, (kernel_h - 1) // 2), max(0, (kernel_w - 1) // 2)]
    pad_high = [kernel_h // 2, kernel_w // 2]

    blurred_images = mx.conv_general(
        images,
        kernel,
        stride=1,
        padding=(pad_low, pad_high),
        kernel_dilation=1,
        input_dilation=1,
        groups=num_channels,
        flip=False,
    )

    if data_format == "channels_first":
        blurred_images = mx.transpose(blurred_images, (0, 3, 1, 2))

    if need_squeeze:
        blurred_images = mx.squeeze(blurred_images, axis=0)

    return blurred_images
894+
895+
def elastic_transform(
    images,
    alpha=20.0,
    sigma=5.0,
    interpolation="bilinear",
    fill_mode="reflect",
    fill_value=0.0,
    seed=None,
    data_format=None,
):
    """Apply a random elastic distortion to an image or batch of images.

    Two normal noise fields (one per axis) are smoothed with
    `gaussian_blur`, scaled by `alpha`, added to the pixel grid, and every
    channel is resampled at the displaced coordinates via
    `map_coordinates`.

    Args:
        images: Rank-3 (single image) or rank-4 (batch) array-like.
        alpha: Multiplier applied to the smoothed displacement fields.
        sigma: Scale of the noise and of the Gaussian smoothing; the blur
            kernel extent is `int(6 * sigma) | 1`.
        interpolation: Key of `AFFINE_TRANSFORM_INTERPOLATIONS`.
        fill_mode: Member of `AFFINE_TRANSFORM_FILL_MODES`, forwarded to
            `map_coordinates`.
        fill_value: Fill value forwarded to `map_coordinates`.
        seed: Passed to `mlx_draw_seed` to obtain the random keys.
        data_format: `"channels_last"`, `"channels_first"` or `None`
            (resolved by `backend.standardize_data_format`).

    Returns:
        The distorted images, with the rank and dtype of `images`.

    Raises:
        ValueError: On an unknown `interpolation` or `fill_mode`, or when
            `images` is not rank 3 or rank 4.
    """
    data_format = backend.standardize_data_format(data_format)
    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:
        raise ValueError(
            "Invalid value for argument `interpolation`. Expected one of "
            f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
            f"interpolation={interpolation}"
        )
    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
        raise ValueError(
            "Invalid value for argument `fill_mode`. Expected one of "
            f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
        )
    if len(images.shape) not in (3, 4):
        raise ValueError(
            "Invalid images rank: expected rank 3 (single image) "
            "or rank 4 (batch of images). Received input with shape: "
            f"images.shape={images.shape}"
        )

    images = convert_to_tensor(images)
    alpha = convert_to_tensor(alpha)
    sigma = convert_to_tensor(sigma)
    input_dtype = images.dtype
    # `| 1` forces an odd kernel extent.
    blur_extent = int(6 * sigma) | 1
    kernel_size = (blur_extent, blur_extent)

    need_squeeze = len(images.shape) == 3
    if need_squeeze:
        images = mx.expand_dims(images, axis=0)

    channels_last = data_format == "channels_last"
    if channels_last:
        batch_size, height, width, channels = images.shape
        channel_axis = -1
    else:
        batch_size, channels, height, width = images.shape
        channel_axis = 1

    # One key per displacement field: split a drawn key when there is one,
    # otherwise draw twice.
    mlx_seed = mlx_draw_seed(seed)
    if mlx_seed is None:
        seed_dx, seed_dy = mlx_draw_seed(None), mlx_draw_seed(None)
    else:
        seed_dx, seed_dy = mx.random.split(mlx_seed)

    def _raw_noise(key):
        return mx.random.normal(
            shape=(batch_size, height, width),
            loc=0.0,
            scale=sigma,
            dtype=input_dtype,
            key=key,
        )

    def _smooth(field):
        # gaussian_blur needs a channel axis; add a unit one and drop it.
        blurred = gaussian_blur(
            mx.expand_dims(field, axis=channel_axis),
            kernel_size=kernel_size,
            sigma=(sigma, sigma),
            data_format=data_format,
        )
        return mx.squeeze(blurred, axis=channel_axis)

    dx = _raw_noise(seed_dx)
    dy = _raw_noise(seed_dy)
    dx = _smooth(dx)
    dy = _smooth(dy)

    # Base pixel grid with a leading unit axis to broadcast over the batch.
    x, y = mx.meshgrid(mx.arange(width), mx.arange(height), indexing="xy")
    distorted_x = mx.expand_dims(x, axis=0) + alpha * dx
    distorted_y = mx.expand_dims(y, axis=0) + alpha * dy

    order = AFFINE_TRANSFORM_INTERPOLATIONS[interpolation]
    everything = slice(None)

    transformed_images = mx.zeros_like(images)
    for i in range(channels):
        # Index of channel `i` within one image, for either layout.
        if channels_last:
            plane = (everything, everything, i)
        else:
            plane = (i, everything, everything)

        warped = [
            map_coordinates(
                images[(b,) + plane],
                [distorted_y[b], distorted_x[b]],
                order=order,
                fill_mode=fill_mode,
                fill_value=fill_value,
            )
            for b in range(batch_size)
        ]
        transformed_images = transformed_images.at[(everything,) + plane].add(
            mx.stack(warped)
        )

    if need_squeeze:
        transformed_images = mx.squeeze(transformed_images, axis=0)

    return transformed_images.astype(input_dtype)
0 commit comments