Skip to content

Commit 4bedc3f

Browse files
committed
fix tf metric use
fix log flake.

Signed-off-by: arthurPignet <arthur.pignet@mines-paristech.fr>
1 parent 28db75b commit 4bedc3f

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

mplc/contributivity.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,16 @@ def __str__(self):
8585
+ str(self.first_charac_fct_calls_count)
8686
+ "\n"
8787
)
88-
output += f"Contributivity scores: {np.round(self.contributivity_scores, 3)}\n"
89-
output += f"Std of the contributivity scores: {np.round(self.scores_std, 3)}\n"
90-
output += f"Normalized contributivity scores: {np.round(self.normalized_scores, 3)}\n"
88+
if isinstance(self.contributivity_scores, dict):
89+
for key, value in self.contributivity_scores.items():
90+
output += f'Metric: {key}\n'
91+
output += f"Contributivity scores : {np.round(value, 3)}\n"
92+
output += f"Std of the contributivity scores: {np.round(self.scores_std[key], 3)}\n"
93+
output += f"Normalized contributivity scores: {np.round(self.normalized_scores[key], 3)}\n"
94+
else:
95+
output += f"Contributivity scores : {np.round(self.contributivity_scores, 3)}\n"
96+
output += f"Std of the contributivity scores: {np.round(self.scores_std, 3)}\n"
97+
output += f"Normalized contributivity scores: {np.round(self.normalized_scores, 3)}\n"
9198

9299
return output
93100

@@ -1119,24 +1126,29 @@ def statistcal_distances_via_smodel(self):
11191126
start = timer()
11201127
try:
11211128
mpl_pretrain = self.scenario.mpl.pretrain_epochs
1122-
except AttributeError as e:
1129+
except AttributeError:
11231130
mpl_pretrain = 2
11241131

11251132
mpl = fast_mpl.FastFedAvgSmodel(self.scenario, mpl_pretrain)
11261133
mpl.fit()
11271134
cross_entropy = tf.keras.metrics.CategoricalCrossentropy()
1128-
self.contributivity_scores = {'Kullbakc divergence': [0 for _ in mpl.partners_list],
1129-
'ma': [0 for _ in mpl.partners_list], 'Hennigen': [0 for _ in mpl.partners_list]}
1135+
self.contributivity_scores = {'Kullback Leiber divergence': [0 for _ in mpl.partners_list],
1136+
'Bhattacharyya distance': [0 for _ in mpl.partners_list],
1137+
'Hellinger metric': [0 for _ in mpl.partners_list]}
1138+
self.scores_std = {'Kullback Leiber divergence': [0 for _ in mpl.partners_list],
1139+
'Bhattacharyya distance': [0 for _ in mpl.partners_list],
1140+
'Hellinger metric': [0 for _ in mpl.partners_list]}
1141+
# TODO; The variance of our estimation is likely to be estimated.
1142+
11301143
for i, partnerMpl in enumerate(mpl.partners_list):
11311144
y_global = mpl.model.predict(partnerMpl.x_train)
11321145
y_local = mpl.smodel_list[i].predict(y_global)
11331146
cross_entropy.update_state(y_global, y_local)
11341147
cs = cross_entropy.result().numpy()
1135-
cross_entropy.reset_state()
1148+
cross_entropy.reset_states()
11361149
cross_entropy.update_state(y_global, y_global)
11371150
e = cross_entropy.result().numpy()
1138-
cross_entropy.reset_state()
1139-
self.contributivity_scores['Kullbakc divergence'][i] = cs - e
1151+
cross_entropy.reset_states()
11401152
BC = 0
11411153
for y_g, y_l in zip(y_global, y_local):
11421154
BC += np.sum(np.sqrt(y_g * y_l))
@@ -1146,7 +1158,6 @@ def statistcal_distances_via_smodel(self):
11461158
self.contributivity_scores['Hellinger metric'][i] = np.sqrt(1 - BC)
11471159

11481160
self.name = "Statistic metric via S-model"
1149-
self.scores_std = np.zeros(mpl.partners_count)
11501161
self.normalized_scores = {}
11511162
for key, value in self.contributivity_scores.items():
11521163
self.normalized_scores[key] = value / np.sum(value)

0 commit comments

Comments
 (0)