 import json
 import os
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, Set
 
 import pandas as pd
 
@@ -25,6 +25,9 @@ class MetricReturn:
     meta: Dict[str, Any] = field(default_factory=dict)
     """Any useful metadata in a JSON serializable dict."""
 
+    added_cols: Set[str] = field(default_factory=set)
+    """Columns added to the dataset."""
+
 
 @dataclass
 class Dataset:
@@ -42,6 +45,12 @@ class Dataset:
     output_path: str
     """The path to the dataset outputs."""
 
+    data_format: str
+    """The format of the written dataset. E.g. 'csv' or 'json'."""
+
+    added_cols: Set[str] = field(default_factory=set)
+    """Columns added to the dataset."""
+
 
 class MetricRunner:
     """A class to run a list of metrics."""
@@ -68,6 +77,9 @@ def run_metrics(self, metrics: List[BaseMetric]) -> None:
 
         self._compute_metrics(metrics)
 
+        # Write the updated datasets to the output location
+        self._write_updated_datasets_to_output()
+
     def _parse_args(self) -> None:
         parser = argparse.ArgumentParser(description="Compute custom metrics.")
         parser.add_argument(
@@ -124,13 +136,21 @@ def _load_datasets(self) -> None:
                 # Load the dataset into a pandas DataFrame
                 if os.path.exists(os.path.join(dataset_path, "dataset.csv")):
                     dataset_df = pd.read_csv(os.path.join(dataset_path, "dataset.csv"))
+                    data_format = "csv"
                 elif os.path.exists(os.path.join(dataset_path, "dataset.json")):
                     dataset_df = pd.read_json(os.path.join(dataset_path, "dataset.json"), orient="records")
+                    data_format = "json"
                 else:
                     raise ValueError(f"No dataset found in {dataset_folder}.")
 
                 datasets.append(
-                    Dataset(name=dataset_folder, config=dataset_config, df=dataset_df, output_path=dataset_path)
+                    Dataset(
+                        name=dataset_folder,
+                        config=dataset_config,
+                        df=dataset_df,
+                        output_path=dataset_path,
+                        data_format=data_format,
+                    )
                 )
         else:
             raise ValueError("No model found in the openlayer.json file. Cannot compute metric.")
@@ -148,6 +168,31 @@ def _compute_metrics(self, metrics: List[BaseMetric]) -> None:
                 continue
             metric.compute(self.datasets)
 
+    def _write_updated_datasets_to_output(self) -> None:
+        """Write the updated datasets to the output location."""
+        for dataset in self.datasets:
+            if dataset.added_cols:
+                self._write_updated_dataset_to_output(dataset)
+
+    def _write_updated_dataset_to_output(self, dataset: Dataset) -> None:
+        """Write the updated dataset to the output location."""
+
+        # Determine the filename based on the dataset name and format
+        filename = f"dataset.{dataset.data_format}"
+        data_path = os.path.join(dataset.output_path, filename)
+
+        # TODO: Read the dataset again and only include the added columns
+
+        # Write the DataFrame to the file based on the specified format
+        if dataset.data_format == "csv":
+            dataset.df.to_csv(data_path, index=False)
+        elif dataset.data_format == "json":
+            dataset.df.to_json(data_path, orient="records", indent=4, index=False)
+        else:
+            raise ValueError("Unsupported format. Please choose 'csv' or 'json'.")
+
+        print(f"Updated dataset {dataset.name} written to {data_path}")
+
 
 class BaseMetric(abc.ABC):
     """Interface for the Base metric.
@@ -163,7 +208,7 @@ def key(self) -> str:
     def compute(self, datasets: List[Dataset]) -> None:
         """Compute the metric on the model outputs."""
         for dataset in datasets:
-            metric_return = self.compute_on_dataset(dataset.config, dataset.df)
+            metric_return = self.compute_on_dataset(dataset)
             metric_value = metric_return.value
             if metric_return.unit:
                 metric_value = f"{metric_value} {metric_return.unit}"
@@ -172,8 +217,12 @@ def compute(self, datasets: List[Dataset]) -> None:
             output_dir = os.path.join(dataset.output_path, "metrics")
             self._write_metric_return_to_file(metric_return, output_dir)
 
+            # Add the added columns to the dataset
+            if metric_return.added_cols:
+                dataset.added_cols.update(metric_return.added_cols)
+
     @abc.abstractmethod
-    def compute_on_dataset(self, config: dict, df: pd.DataFrame) -> MetricReturn:
+    def compute_on_dataset(self, dataset: Dataset) -> MetricReturn:
         """Compute the metric on a specific dataset."""
         pass
 
@@ -183,6 +232,11 @@ def _write_metric_return_to_file(self, metric_return: MetricReturn, output_dir:
         # Create the directory if it doesn't exist
         os.makedirs(output_dir, exist_ok=True)
 
+        # Turn the metric return into a dict
+        metric_return_dict = asdict(metric_return)
+        # Convert the set of added columns to a list so it is JSON serializable
+        metric_return_dict["added_cols"] = list(metric_return.added_cols)
+
         with open(os.path.join(output_dir, f"{self.key}.json"), "w", encoding="utf-8") as f:
-            json.dump(asdict(metric_return), f, indent=4)
+            json.dump(metric_return_dict, f, indent=4)
         print(f"Metric ({self.key}) value written to {output_dir}/{self.key}.json")
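
A minimal sketch of how a custom metric might use the interface introduced in this diff (illustration only, not part of the change): compute_on_dataset now receives the full Dataset object, and any columns the metric appends to dataset.df are reported through MetricReturn.added_cols so the runner knows to write the updated dataset back to dataset.output_path. The OutputLengthMetric class, the "output" input column, and the "output_length" column it adds are hypothetical names, and it assumes key can be overridden as a property on BaseMetric.

# Assumes BaseMetric, Dataset, and MetricReturn are importable from the module
# shown in this diff; the module path below is a placeholder.
# from custom_metrics import BaseMetric, Dataset, MetricReturn


class OutputLengthMetric(BaseMetric):
    """Hypothetical metric: adds a per-row column and returns an aggregate value."""

    @property
    def key(self) -> str:
        return "output_length"

    def compute_on_dataset(self, dataset: Dataset) -> MetricReturn:
        # Append a per-row column to the dataset's DataFrame. The "output"
        # column is an assumption about the dataset's schema.
        col_name = "output_length"
        dataset.df[col_name] = dataset.df["output"].astype(str).str.len()

        # Declaring the column in added_cols is what triggers the runner to
        # write the updated dataset back to dataset.output_path.
        return MetricReturn(
            value=float(dataset.df[col_name].mean()),
            unit="characters",
            meta={"column": col_name},
            added_cols={col_name},
        )

With a metric like this, MetricRunner.run_metrics([OutputLengthMetric()]) would write both metrics/output_length.json and the updated dataset file under the dataset's output path.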