diff --git a/examples/vision/metric_learning_tf_similarity.py b/examples/vision/metric_learning_tf_similarity.py index ffcfc1c4dc..7436fd10da 100644 --- a/examples/vision/metric_learning_tf_similarity.py +++ b/examples/vision/metric_learning_tf_similarity.py @@ -54,7 +54,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +import keras import tensorflow_similarity as tfsim @@ -216,13 +216,14 @@ val_steps = 50 # init similarity loss -loss = tfsim.losses.MultiSimilarityLoss() +loss = tfsim.losses.MultiSimilarityLoss(reduction='sum_over_batch_size') # compiling and training model.compile( optimizer=keras.optimizers.Adam(learning_rate), loss=loss, steps_per_execution=10, + run_eagerly=True, ) history = model.fit( train_ds, epochs=epochs, validation_data=val_ds, validation_steps=val_steps @@ -321,7 +322,7 @@ for idx in np.argsort(y_display): tfsim.visualization.viz_neigbors_imgs( x_display[idx], - y_display[idx], + y_display[idx].numpy(), nns[idx], class_mapping=class_mapping, fig_size=(16, 2), @@ -394,7 +395,7 @@ """ idx_no_match = np.where(np.array(matches) == 10) -no_match_queries = x_confusion[idx_no_match] +no_match_queries = keras.ops.take(x_confusion, keras.ops.cast(idx_no_match[0], dtype="int32"), axis=0) if len(no_match_queries): plt.imshow(no_match_queries[0]) else: