11import base64
22from pickle import dumps , loads
3- from random import randrange
43from typing import Dict , List
54
65from .player import Player
@@ -22,15 +21,26 @@ class EvolvablePlayer(Player):
2221 parent_class = Player
2322 parent_kwargs = [] # type: List[str]
2423
24+ def __init__ (self , seed = None ):
25+ # Parameter seed is required for reproducibility. Player will throw
26+ # a warning to the user otherwise.
27+ super ().__init__ ()
28+ self .set_seed (seed = seed )
29+
2530 def overwrite_init_kwargs (self , ** kwargs ):
2631 """Use to overwrite parameters for proper cloning and testing."""
2732 for k , v in kwargs .items ():
2833 self .init_kwargs [k ] = v
2934
3035 def create_new (self , ** kwargs ):
31- """Creates a new variant with parameters overwritten by kwargs."""
36+ """Creates a new variant with parameters overwritten by kwargs. This differs from
37+ cloning the Player because it propagates a seed forward, and is intended to be
38+ used by the mutation and crossover methods."""
3239 init_kwargs = self .init_kwargs .copy ()
3340 init_kwargs .update (kwargs )
41+ # Propagate seed forward for reproducibility.
42+ if "seed" not in kwargs :
43+ init_kwargs ["seed" ] = self ._random .random_seed_int ()
3444 return self .__class__ (** init_kwargs )
3545
3646 # Serialization and deserialization. You may overwrite to obtain more human readable serializations
@@ -74,15 +84,15 @@ def copy_lists(lists: List[List]) -> List[List]:
7484 return list (map (list , lists ))
7585
7686
77- def crossover_lists (list1 : List , list2 : List ) -> List :
78- cross_point = randrange ( len (list1 ))
87+ def crossover_lists (list1 : List , list2 : List , rng ) -> List :
88+ cross_point = rng . randint ( 0 , len (list1 ))
7989 new_list = list (list1 [:cross_point ]) + list (list2 [cross_point :])
8090 return new_list
8191
8292
83- def crossover_dictionaries (table1 : Dict , table2 : Dict ) -> Dict :
93+ def crossover_dictionaries (table1 : Dict , table2 : Dict , rng ) -> Dict :
8494 keys = list (table1 .keys ())
85- cross_point = randrange ( len (keys ))
95+ cross_point = rng . randint ( 0 , len (keys ))
8696 new_items = [(k , table1 [k ]) for k in keys [:cross_point ]]
8797 new_items += [(k , table2 [k ]) for k in keys [cross_point :]]
8898 new_table = dict (new_items )
0 commit comments