77import json
88import os
99from dataclasses import asdict , dataclass , field
10+ import traceback
1011from typing import Any , Dict , List , Optional , Set , Union
1112
1213import pandas as pd
1617class MetricReturn :
1718 """The return type of the `run` method in the BaseMetric."""
1819
19- value : Union [float , int , bool ]
20+ value : Optional [ Union [float , int , bool ] ]
2021 """The value of the metric."""
2122
2223 unit : Optional [str ] = None
@@ -25,6 +26,9 @@ class MetricReturn:
2526 meta : Dict [str , Any ] = field (default_factory = dict )
2627 """Any useful metadata in a JSON serializable dict."""
2728
29+ error : Optional [str ] = None
30+ """An error message if the metric computation failed."""
31+
2832 added_cols : Set [str ] = field (default_factory = set )
2933 """Columns added to the dataset."""
3034
@@ -73,8 +77,7 @@ def run_metrics(self, metrics: List[BaseMetric]) -> None:
7377 # Load the datasets from the openlayer.json file
7478 self ._load_datasets ()
7579
76- # TODO: Auto-load all the metrics in the current directory
77-
80+ # Compute the metric values
7881 self ._compute_metrics (metrics )
7982
8083 # Write the updated datasets to the output location
@@ -213,10 +216,9 @@ class BaseMetric(abc.ABC):
213216 Your metric's class should inherit from this class and implement the compute method.
214217 """
215218
216- @abc .abstractmethod
217219 def get_key (self ) -> str :
218220 """Return the key of the metric. This should correspond to the folder name."""
219- pass
221+ return os . path . basename ( os . getcwd ())
220222
221223 @property
222224 def key (self ) -> str :
@@ -225,11 +227,27 @@ def key(self) -> str:
225227 def compute (self , datasets : List [Dataset ]) -> None :
226228 """Compute the metric on the model outputs."""
227229 for dataset in datasets :
228- metric_return = self .compute_on_dataset (dataset )
230+ # Check if the metric has already been computed
231+ if os .path .exists (
232+ os .path .join (dataset .output_path , "metrics" , f"{ self .key } .json" )
233+ ):
234+ print (
235+ f"Metric ({ self .key } ) already computed on { dataset .name } . "
236+ "Skipping."
237+ )
238+ continue
239+
240+ try :
241+ metric_return = self .compute_on_dataset (dataset )
242+ except Exception as e : # pylint: disable=broad-except
243+ print (f"Error computing metric ({ self .key } ) on { dataset .name } :" )
244+ print (traceback .format_exc ())
245+ metric_return = MetricReturn (error = str (e ), value = None )
246+
229247 metric_value = metric_return .value
230248 if metric_return .unit :
231249 metric_value = f"{ metric_value } { metric_return .unit } "
232- print (f"Metric ({ self .key } ) value for { dataset .name } : { metric_value } " )
250+ print (f"Metric ({ self .key } ) value on { dataset .name } : { metric_value } " )
233251
234252 output_dir = os .path .join (dataset .output_path , "metrics" )
235253 self ._write_metric_return_to_file (metric_return , output_dir )
0 commit comments