diff --git a/src/lighteval/tasks/lighteval_task.py b/src/lighteval/tasks/lighteval_task.py
index b84d421a6..bc1771ade 100644
--- a/src/lighteval/tasks/lighteval_task.py
+++ b/src/lighteval/tasks/lighteval_task.py
@@ -32,7 +32,7 @@
 
 from lighteval.metrics.metrics import Metrics
 from lighteval.metrics.metrics_sample import SamplingMetric
-from lighteval.metrics.utils.metric_utils import Metric
+from lighteval.metrics.utils.metric_utils import Metric, MetricGrouping
 from lighteval.tasks.prompt_manager import FewShotSampler
 from lighteval.tasks.requests import (
     Doc,
@@ -167,15 +167,8 @@ def __str__(self, lite: bool = False):
                 continue
             if k == "metrics":
                 for ix, metrics in enumerate(v):
-                    for metric_k, metric_v in metrics.items():
-                        if isinstance(metric_v, Callable):
-                            repr_v = metric_v.__name__
-                        elif isinstance(metric_v, Metric.get_allowed_types_for_metrics()):
-                            repr_v = str(metric_v)
-                        else:
-                            repr_v = repr(metric_v)
-                        values.append([f"{k} {ix}: {metric_k}", repr_v])
-
+                    item_str_list = self._build_metrics_item_str_list(ix, metrics)
+                    values.extend(item_str_list)
             else:
                 if isinstance(v, Callable):
                     values.append([k, v.__name__])
@@ -186,6 +179,42 @@ def __str__(self, lite: bool = False):
 
         return md_writer.dumps()
 
+    def _build_metrics_item_str_list(self, ix: int, metrics: dict) -> list:
+        """Build the ``[label, text]`` table rows for one metric of the task.
+
+        Args:
+            ix (int): Position of the metric in ``self.metrics``.
+            metrics (dict): Dict form of that metric, as iterated by ``__str__``.
+
+        Returns:
+            list: One ``[label, text]`` row per field; dict-valued fields of a
+                ``MetricGrouping`` are expanded into one row per sub-key.
+        """
+
+        def _to_text(value) -> str:
+            # Callables are shown by name; functools.partial objects and callable
+            # instances have no __name__, so fall back to repr for those.
+            if isinstance(value, Callable):
+                return getattr(value, "__name__", repr(value))
+            if isinstance(value, Metric.get_allowed_types_for_metrics()):
+                return str(value)
+            return repr(value)
+
+        is_metric_grouping = isinstance(self.metrics[ix], MetricGrouping)
+
+        values = []
+        for metric_k, metric_v in metrics.items():
+            if is_metric_grouping and isinstance(metric_v, dict):
+                # Expand a grouping's dict-valued field into one row per entry.
+                values.extend(
+                    [f"metrics {ix}: {metric_k}: {metric_sub_k}", _to_text(metric_sub_v)]
+                    for metric_sub_k, metric_sub_v in metric_v.items()
+                )
+            else:
+                values.append([f"metrics {ix}: {metric_k}", _to_text(metric_v)])
+
+        return values
+
     def print(self, lite: bool = False):
         print(str(self, lite))
 