Skip to content

Commit 457bf7f

Browse files
committed
Fix ResultSet __eq__ to handle nans (again)
1 parent 85a8973 commit 457bf7f

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

axelrod/result_set.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import csv
33
import itertools
44
from multiprocessing import cpu_count
5+
from typing import List
56
import warnings
67

78
import numpy as np
@@ -611,6 +612,18 @@ def __eq__(self, other):
611612
other : axelrod.ResultSet
612613
Another results set against which to check equality
613614
"""
615+
616+
def list_equal_with_nans(v1: List[float], v2: List[float]) -> bool:
617+
"""Matches lists, accounting for NaNs."""
618+
if len(v1) != len(v2):
619+
return False
620+
for i1, i2 in zip(v1, v2):
621+
if np.isnan(i1) and np.isnan(i2):
622+
continue
623+
if i1 != i2:
624+
return False
625+
return True
626+
614627
return all(
615628
[
616629
self.wins == other.wins,

axelrod/tests/unit/test_eigen.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ def test_identity_matrices(self):
1515
self.assertAlmostEqual(evalue, 1)
1616
assert_array_almost_equal(evector, _normalise(numpy.ones(size)))
1717

18+
def test_zero_matrix(self):
19+
mat = numpy.array([[0, 0], [0, 0]])
20+
evector, evalue = principal_eigenvector(mat)
21+
self.assertTrue(numpy.isnan(evalue))
22+
self.assertTrue(numpy.isnan(evector[0]))
23+
self.assertTrue(numpy.isnan(evector[1]))
24+
1825
def test_2x2_matrix(self):
1926
mat = numpy.array([[2, 1], [1, 2]])
2027
evector, evalue = principal_eigenvector(mat)

axelrod/tests/unit/test_resultset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pandas as pd
88
from axelrod.result_set import create_counter_dict
99
from axelrod.tests.property import prob_end_tournaments, tournaments
10-
from numpy import mean, nanmedian, std
10+
from numpy import mean, nan, nanmedian, std
1111

1212
from dask.dataframe.core import DataFrame
1313
from hypothesis import given, settings

0 commit comments

Comments
 (0)