Skip to content

Commit 57ab706

Browse files
committed
feature: Added time summarizer
1 parent 278cdac commit 57ab706

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
""" A module that provides a class that save data about how the time of estimation is distributed """
2+
3+
from pathlib import Path
4+
5+
import numpy as np
6+
import yaml
7+
8+
from experimental_env.analysis.analyze_summarizers.analysis_summarizer import (
9+
AnalysisSummarizer,
10+
)
11+
from experimental_env.experiment.experiment_description import ExperimentDescription
12+
from experimental_env.utils import round_sig
13+
14+
15+
class TimeSummarizer(AnalysisSummarizer):
16+
"""
17+
A class that calculates the average error for a dataset using the selected metric.
18+
"""
19+
20+
def calculate(self, results: list[ExperimentDescription]) -> tuple:
21+
"""
22+
Helper function for calculating mean and standard deviation of time
23+
"""
24+
times = []
25+
for result in results:
26+
time = np.sum(step.time for step in result.steps)
27+
times.append(time)
28+
29+
mean = np.sum(times) / len(times)
30+
deviation = np.sqrt(np.sum((x - mean) ** 2 for x in times) / len(times))
31+
return float(mean), float(deviation)
32+
33+
def analyze_method(self, results: list[ExperimentDescription], method: str):
34+
mean, deviation = self.calculate(results)
35+
36+
info_dict = {
37+
"mean": round_sig(mean, 3),
38+
"standart_deviation": round_sig(deviation, 3),
39+
}
40+
yaml_path: Path = self._out_dir.joinpath("time_info.yaml")
41+
42+
with open(yaml_path, "w", encoding="utf-8") as file:
43+
yaml.dump(info_dict, file)
44+
45+
def compare_methods(
46+
self,
47+
results_1: list[ExperimentDescription],
48+
results_2: list[ExperimentDescription],
49+
method_1: str,
50+
method_2: str,
51+
):
52+
mean_1, deviation_1 = self.calculate(results_1)
53+
mean_2, deviation_2 = self.calculate(results_2)
54+
55+
info_dict = {
56+
f"{method_1}_mean": round_sig(mean_1, 3),
57+
f"{method_1}_standart_deviation": round_sig(deviation_1, 3),
58+
f"{method_2}_mean": round_sig(mean_2, 3),
59+
f"{method_2}_standart_deviation": round_sig(deviation_2, 3),
60+
}
61+
yaml_path: Path = self._out_dir.joinpath("time_info.yaml")
62+
63+
with open(yaml_path, "w", encoding="utf-8") as file:
64+
yaml.dump(info_dict, file)

0 commit comments

Comments
 (0)