From ba408a92509f1c44b385e2ec7851c807d5a0b2e0 Mon Sep 17 00:00:00 2001 From: Shipra <138140065+Shi-pra-19@users.noreply.github.com> Date: Tue, 28 Oct 2025 21:40:01 +0530 Subject: [PATCH 1/2] migrate metric learning example to keras 3 Refactor loss function and tensor handling in metric learning for compatibility with Keras 3. --- examples/vision/metric_learning_tf_similarity.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/vision/metric_learning_tf_similarity.py b/examples/vision/metric_learning_tf_similarity.py index ffcfc1c4dc..4d9a2dc4c0 100644 --- a/examples/vision/metric_learning_tf_similarity.py +++ b/examples/vision/metric_learning_tf_similarity.py @@ -54,7 +54,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +import keras import tensorflow_similarity as tfsim @@ -216,13 +216,14 @@ val_steps = 50 # init similarity loss -loss = tfsim.losses.MultiSimilarityLoss() +loss = tfsim.losses.MultiSimilarityLoss(reduction='sum_over_batch_size') # compiling and training model.compile( optimizer=keras.optimizers.Adam(learning_rate), loss=loss, steps_per_execution=10, + run_eagerly=True, ) history = model.fit( train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps @@ -321,7 +322,7 @@ for idx in np.argsort(y_display): tfsim.visualization.viz_neigbors_imgs( x_display[idx], - y_display[idx], + y_display[idx].numpy(), nns[idx], class_mapping=class_mapping, fig_size=(16, 2), @@ -394,7 +395,7 @@ """ idx_no_match = np.where(np.array(matches) == 10) -no_match_queries = x_confusion[idx_no_match] +no_match_queries = no_match_queries = keras.ops.take(x_confusion, keras.ops.cast(idx_no_match[0], dtype="int32"), axis=0) if len(no_match_queries): plt.imshow(no_match_queries[0]) else: From 87c5a8f0823a364267fc473d89683fd3bac3b39d Mon Sep 17 00:00:00 2001 From: Shipra <138140065+Shi-pra-19@users.noreply.github.com> Date: Tue, 28 Oct 2025 22:25:02 +0530 Subject: [PATCH 2/2] Fix duplicate assignment in no_match_queries --- examples/vision/metric_learning_tf_similarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/vision/metric_learning_tf_similarity.py b/examples/vision/metric_learning_tf_similarity.py index 4d9a2dc4c0..7436fd10da 100644 --- a/examples/vision/metric_learning_tf_similarity.py +++ b/examples/vision/metric_learning_tf_similarity.py @@ -395,7 +395,7 @@ """ idx_no_match = np.where(np.array(matches) == 10) -no_match_queries = no_match_queries = keras.ops.take(x_confusion, keras.ops.cast(idx_no_match[0], dtype="int32"), axis=0) +no_match_queries = keras.ops.take(x_confusion, keras.ops.cast(idx_no_match[0], dtype="int32"), axis=0) if len(no_match_queries): plt.imshow(no_match_queries[0]) else: