@@ -43,7 +43,8 @@ class Metrics(object):
4343 ROUGE_2_F = "rouge_2_fscore"
4444 ROUGE_L_F = "rouge_L_fscore"
4545 EDIT_DISTANCE = "edit_distance"
46-
46+ SET_PRECISION = 'set_precision'
47+ SET_RECALL = 'set_recall'
4748
4849def padded_rmse (predictions , labels , weights_fn = common_layers .weights_all ):
4950 predictions , labels = common_layers .pad_with_zeros (predictions , labels )
@@ -188,6 +189,48 @@ def padded_accuracy(predictions,
188189 padded_labels = tf .to_int32 (padded_labels )
189190 return tf .to_float (tf .equal (outputs , padded_labels )), weights
190191
192+ def set_precision (predictions ,
193+ labels ,
194+ weights_fn = common_layers .weights_nonzero ):
195+ """Precision of set predictions.
196+
197+ Args:
198+ predictions : A Tensor of scores of shape (batch, nlabels)
199+ labels: A Tensor of int32s giving true set elements of shape (batch, seq_length)
200+
201+ Returns:
202+ hits: A Tensor of shape (batch, nlabels)
203+ weights: A Tensor of shape (batch, nlabels)
204+ """
205+ with tf .variable_scope ("set_precision" , values = [predictions , labels ]):
206+ labels = tf .squeeze (labels , [2 , 3 ])
207+ labels = tf .one_hot (labels , predictions .shape [- 1 ])
208+ labels = tf .reduce_max (labels , axis = 1 )
209+ labels = tf .cast (labels , tf .bool )
210+ predictions = predictions > 0
211+ return tf .to_float (tf .equal (labels , predictions )), tf .to_float (predictions )
212+
213+ def set_recall (predictions ,
214+ labels ,
215+ weights_fn = common_layers .weights_nonzero ):
216+ """Recall of set predictions.
217+
218+ Args:
219+ predictions : A Tensor of scores of shape (batch, nlabels)
220+ labels: A Tensor of int32s giving true set elements of shape (batch, seq_length)
221+
222+ Returns:
223+ hits: A Tensor of shape (batch, nlabels)
224+ weights: A Tensor of shape (batch, nlabels)
225+ """
226+ with tf .variable_scope ("set_recall" , values = [predictions , labels ]):
227+ labels = tf .squeeze (labels , [2 , 3 ])
228+ labels = tf .one_hot (labels , predictions .shape [- 1 ])
229+ labels = tf .reduce_max (labels , axis = 1 )
230+ labels = tf .cast (labels , tf .bool )
231+ predictions = predictions > 0
232+ return tf .to_float (tf .equal (labels , predictions )), tf .to_float (labels )
233+
191234
192235def create_evaluation_metrics (problems , model_hparams ):
193236 """Creates the evaluation metrics for the model.
@@ -278,4 +321,6 @@ def wrapped_metric_fn():
278321 Metrics .ROUGE_2_F : rouge .rouge_2_fscore ,
279322 Metrics .ROUGE_L_F : rouge .rouge_l_fscore ,
280323 Metrics .EDIT_DISTANCE : sequence_edit_distance ,
324+ Metrics .SET_PRECISION : set_precision ,
325+ Metrics .SET_RECALL : set_recall ,
281326}
0 commit comments