Skip to content

Commit deab302

Browse files
committed
Create and use axl_filename
1 parent 6568fc4 commit deab302

File tree

8 files changed

+41
-19
lines changed

8 files changed

+41
-19
lines changed

axelrod/load_data_.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
1-
from typing import Dict, List, Tuple
1+
from typing import Dict, List, Text, Tuple
22

3+
import os
34
import pkg_resources
45

56

7+
def axl_filename(axl_path: Text) -> Text:
8+
"""Get the path to Axelrod/<axl_path> from the working directory."""
9+
# Working directory
10+
dirname = os.path.dirname(__file__)
11+
12+
# We go up a dir because this code is located in Axelrod/axelrod and
13+
# axl_path is from the top-level Axelrod dir.
14+
return os.path.join(dirname, "..", axl_path)
15+
16+
617
def load_file(filename: str, directory: str) -> List[List[str]]:
718
"""Loads a data file stored in the Axelrod library's data subdirectory,
819
likely for parameters for a strategy."""

axelrod/plot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy import arange, median, nan_to_num
99

1010
from .result_set import ResultSet
11+
from .load_data_ import axl_filename
1112

1213
titleType = List[str]
1314
namesType = List[str]
@@ -323,7 +324,7 @@ def save_all_plots(
323324

324325
for method, name in plots:
325326
f = getattr(self, method)(title="{} - {}".format(title_prefix, name))
326-
f.savefig("{}_{}.{}".format(prefix, method, filetype))
327+
f.savefig(axl_filename("{}_{}.{}".format(prefix, method, filetype)))
327328
plt.close(f)
328329

329330
if progress_bar:

axelrod/tests/integration/test_tournament.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from hypothesis import given, settings
55

66
import axelrod
7+
from axelrod.load_data_ import axl_filename
78
from axelrod.strategy_transformers import FinalTransformer
89
from axelrod.tests.property import tournaments
910

@@ -45,7 +46,7 @@ def setUpClass(cls):
4546
def test_big_tournaments(self, tournament):
4647
"""A test to check that tournament runs with a sample of non-cheating
4748
strategies."""
48-
filename = "test_outputs/test_tournament.csv"
49+
filename = axl_filename("test_outputs/test_tournament.csv")
4950
self.assertIsNone(
5051
tournament.play(progress_bar=False, filename=filename, build_results=False)
5152
)
@@ -90,7 +91,7 @@ def test_repeat_tournament_deterministic(self):
9091
turns=2,
9192
repetitions=2,
9293
)
93-
files.append("test_outputs/stochastic_tournament_{}.csv".format(_))
94+
files.append(axl_filename("test_outputs/stochastic_tournament_{}.csv".format(_)))
9495
tournament.play(progress_bar=False, filename=files[-1], build_results=False)
9596
self.assertTrue(filecmp.cmp(files[0], files[1]))
9697

@@ -113,7 +114,7 @@ def test_repeat_tournament_stochastic(self):
113114
turns=2,
114115
repetitions=2,
115116
)
116-
files.append("test_outputs/stochastic_tournament_{}.csv".format(_))
117+
files.append(axl_filename("test_outputs/stochastic_tournament_{}.csv".format(_)))
117118
tournament.play(progress_bar=False, filename=files[-1], build_results=False)
118119
self.assertTrue(filecmp.cmp(files[0], files[1]))
119120

axelrod/tests/unit/test_deterministic_cache.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest
44

55
from axelrod import Action, Defector, DeterministicCache, Random, TitForTat
6+
from axelrod.load_data_ import axl_filename
67

78
C, D = Action.C, Action.D
89

@@ -12,8 +13,8 @@ class TestDeterministicCache(unittest.TestCase):
1213
def setUpClass(cls):
1314
cls.test_key = (TitForTat(), Defector())
1415
cls.test_value = [(C, D), (D, D), (D, D)]
15-
cls.test_save_file = "test_cache_save.txt"
16-
cls.test_load_file = "test_cache_load.txt"
16+
cls.test_save_file = axl_filename("test_outputs/test_cache_save.txt")
17+
cls.test_load_file = axl_filename("test_outputs/test_cache_load.txt")
1718
test_data_to_pickle = {("Tit For Tat", "Defector"): [(C, D), (D, D), (D, D)]}
1819
cls.test_pickle = pickle.dumps(test_data_to_pickle)
1920

@@ -92,7 +93,7 @@ def test_load(self):
9293
self.assertEqual(self.cache[self.test_key], self.test_value)
9394

9495
def test_load_error_for_inccorect_format(self):
95-
filename = "test_outputs/test.cache"
96+
filename = axl_filename("test_outputs/test.cache")
9697
with open(filename, "wb") as io:
9798
pickle.dump(range(5), io)
9899

axelrod/tests/unit/test_fingerprint.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import axelrod as axl
1111
from axelrod.fingerprint import AshlockFingerprint, Point, TransitiveFingerprint
12+
from axelrod.load_data_ import axl_filename
1213
from axelrod.strategy_transformers import DualTransformer, JossAnnTransformer
1314
from axelrod.tests.property import strategy_lists
1415

@@ -195,7 +196,7 @@ def test_temp_file_creation(self):
195196

196197
RecordedMksTemp.reset_record()
197198
af = AshlockFingerprint(axl.TitForTat)
198-
filename = "test_outputs/test_fingerprint.csv"
199+
filename = axl_filename("test_outputs/test_fingerprint.csv")
199200

200201
self.assertEqual(RecordedMksTemp.record, [])
201202

@@ -211,7 +212,7 @@ def test_temp_file_creation(self):
211212
self.assertFalse(os.path.isfile(filename))
212213

213214
def test_fingerprint_with_filename(self):
214-
filename = "test_outputs/test_fingerprint.csv"
215+
filename = axl_filename("test_outputs/test_fingerprint.csv")
215216
af = AshlockFingerprint(axl.TitForTat)
216217
af.fingerprint(
217218
turns=1, repetitions=1, step=0.5, progress_bar=False, filename=filename
@@ -425,7 +426,7 @@ def test_init_with_not_default_number(self):
425426
)
426427

427428
def test_fingerprint_with_filename(self):
428-
filename = "test_outputs/test_fingerprint.csv"
429+
filename = axl_filename("test_outputs/test_fingerprint.csv")
429430
strategy = axl.TitForTat()
430431
tf = TransitiveFingerprint(strategy)
431432
tf.fingerprint(turns=1, repetitions=1, progress_bar=False, filename=filename)
@@ -437,7 +438,9 @@ def test_serial_fingerprint(self):
437438
strategy = axl.TitForTat()
438439
tf = TransitiveFingerprint(strategy)
439440
tf.fingerprint(
440-
repetitions=1, progress_bar=False, filename="test_outputs/tran_fin.csv"
441+
repetitions=1,
442+
progress_bar=False,
443+
filename=axl_filename("test_outputs/tran_fin.csv"),
441444
)
442445
self.assertEqual(tf.data.shape, (50, 50))
443446

@@ -450,7 +453,7 @@ def test_parallel_fingerprint(self):
450453

451454
def test_analyse_cooperation_ratio(self):
452455
tf = TransitiveFingerprint(axl.TitForTat)
453-
filename = "test_outputs/test_fingerprint.csv"
456+
filename = axl_filename("test_outputs/test_fingerprint.csv")
454457
with open(filename, "w") as f:
455458
f.write(
456459
"""Interaction index,Player index,Opponent index,Repetition,Player name,Opponent name,Actions

axelrod/tests/unit/test_plot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import unittest
33

44
import axelrod
5+
from axelrod.load_data_ import axl_filename
56
import matplotlib
67
import matplotlib.pyplot as plt
78
from numpy import mean
@@ -10,7 +11,7 @@
1011
class TestPlot(unittest.TestCase):
1112
@classmethod
1213
def setUpClass(cls):
13-
cls.filename = "test_outputs/test_results.csv"
14+
cls.filename = axl_filename("test_outputs/test_results.csv")
1415

1516
cls.players = [axelrod.Alternator(), axelrod.TitForTat(), axelrod.Defector()]
1617
cls.repetitions = 3

axelrod/tests/unit/test_resultset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import axelrod
66
import axelrod.interaction_utils as iu
77
import pandas as pd
8+
from axelrod.load_data_ import axl_filename
89
from axelrod.result_set import create_counter_dict
910
from axelrod.tests.property import prob_end_tournaments, tournaments
1011
from numpy import mean, nanmedian, std
@@ -19,7 +20,7 @@ class TestResultSet(unittest.TestCase):
1920
@classmethod
2021
def setUpClass(cls):
2122

22-
cls.filename = "test_outputs/test_results.csv"
23+
cls.filename = axl_filename("test_outputs/test_results.csv")
2324

2425
cls.players = [axelrod.Alternator(), axelrod.TitForTat(), axelrod.Defector()]
2526
cls.repetitions = 3
@@ -647,7 +648,7 @@ class TestResultSetSpatialStructure(TestResultSet):
647648
@classmethod
648649
def setUpClass(cls):
649650

650-
cls.filename = "test_outputs/test_results_spatial.csv"
651+
cls.filename = axl_filename("test_outputs/test_results_spatial.csv")
651652
cls.players = [axelrod.Alternator(), axelrod.TitForTat(), axelrod.Defector()]
652653
cls.turns = 5
653654
cls.edges = [(0, 1), (0, 2)]

axelrod/tests/unit/test_tournament.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from unittest.mock import MagicMock, patch
1313

1414
import axelrod
15+
from axelrod.load_data_ import axl_filename
1516
import numpy as np
1617
import pandas as pd
1718
from axelrod.tests.property import (
@@ -88,7 +89,7 @@ def setUpClass(cls):
8889
[200, 200, 1, 200, 200],
8990
]
9091

91-
cls.filename = "test_outputs/test_tournament.csv"
92+
cls.filename = axl_filename("test_outputs/test_tournament.csv")
9293

9394
def setUp(self):
9495
self.test_tournament = axelrod.Tournament(
@@ -733,7 +734,9 @@ def test_write_to_csv_with_results(self):
733734
)
734735
tournament.play(filename=self.filename, progress_bar=False)
735736
df = pd.read_csv(self.filename)
736-
expected_df = pd.read_csv("test_outputs/expected_test_tournament.csv")
737+
expected_df = pd.read_csv(
738+
axl_filename("test_outputs/expected_test_tournament.csv")
739+
)
737740
self.assertTrue(df.equals(expected_df))
738741

739742
def test_write_to_csv_without_results(self):
@@ -747,7 +750,7 @@ def test_write_to_csv_without_results(self):
747750
tournament.play(filename=self.filename, progress_bar=False, build_results=False)
748751
df = pd.read_csv(self.filename)
749752
expected_df = pd.read_csv(
750-
"test_outputs/expected_test_tournament_no_results.csv"
753+
axl_filename("test_outputs/expected_test_tournament_no_results.csv")
751754
)
752755
self.assertTrue(df.equals(expected_df))
753756

0 commit comments

Comments
 (0)