Skip to content

Commit c6b3bc4

Browse files
author
Release Manager
committed
sagemathgh-41119: slightly restructure the logic of first_terms, to avoid copying when calling a statistic. Currently, calling a statistic copies the dict of known terms. To avoid this, we create a new lazy attribute, for internal use. URL: sagemath#41119. Reported by: Martin Rubey. Reviewer(s):
2 parents 2579aaf + b412a48 commit c6b3bc4

File tree

1 file changed

+63
-22
lines changed

1 file changed

+63
-22
lines changed

src/sage/databases/findstat.py

Lines changed: 63 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -198,6 +198,7 @@ def mapping(sigma):
198198
# https://www.gnu.org/licenses/
199199
# ****************************************************************************
200200
from sage.misc.lazy_list import lazy_list
201+
from sage.misc.lazy_attribute import lazy_attribute
201202
from sage.misc.inherit_comparison import InheritComparisonClasscallMetaclass
202203
from sage.structure.element import Element
203204
from sage.structure.parent import Parent
@@ -631,7 +632,7 @@ def _data_from_iterable(iterable, mapping=False, domain=None,
631632
pre_data = [(elts, vals)]
632633

633634
# pre_data is a list of all elements of the iterator accessed so
634-
# far, for each of its elements and also the remainder ot the
635+
# far, for each of its elements and also the remainder of the
635636
# iterator, each element is either a pair ``(object, value)`` or
636637
# a pair ``(objects, values)``
637638
elts, vals = pre_data[0]
@@ -730,16 +731,15 @@ def _data_from_data(data, max_values):
730731
[0, 0, 1, 1, 2, 2, 1, 0, 0, 0, 1, 1, 1, 2, 3])]
731732
"""
732733
query = []
733-
total = min(max_values, FINDSTAT_MAX_VALUES)
734734
iterator = iter(data)
735-
while total > 0:
735+
while max_values > 0:
736736
try:
737737
elts, vals = next(iterator)
738738
except StopIteration:
739739
break
740-
if total >= len(elts):
740+
if max_values >= len(elts):
741741
query.append((elts, vals))
742-
total -= len(elts)
742+
max_values -= len(elts)
743743
else:
744744
break # assuming that the next pair is even larger
745745

@@ -1017,12 +1017,12 @@ def findstat(query=None, values=None, distribution=None, domain=None,
10171017
sage: findstat("Permutations", lambda x: 1, depth='x') # optional -- internet
10181018
Traceback (most recent call last):
10191019
...
1020-
ValueError: E021: Depth should be a nonnegative integer at most 9, but is x.
1020+
ValueError: E021: Depth should be a non-negative integer at most 9, but is x.
10211021
10221022
sage: findstat("Permutations", lambda x: 1, depth=100) # optional -- internet
10231023
Traceback (most recent call last):
10241024
...
1025-
ValueError: E021: Depth should be a nonnegative integer at most 9, but is 100.
1025+
ValueError: E021: Depth should be a non-negative integer at most 9, but is 100.
10261026
10271027
sage: S = Permutation
10281028
sage: findstat([(S([1,2]), 1), ([S([1,3,2]), S([1,2])], [2,3])]) # optional -- internet
@@ -1786,10 +1786,10 @@ def set_sage_code(self, value):
17861786
EXAMPLES::
17871787
17881788
sage: q = findstat([(d, randint(1,1000)) for d in DyckWords(4)]) # optional -- internet
1789-
sage: q.set_sage_code("def statistic(x):\n return randint(1, 1000)") # optional -- internet
1789+
sage: q.set_sage_code("def statistic(x):\n return randint(1, 1000)") # optional -- internet
17901790
sage: print(q.sage_code()) # optional -- internet
17911791
def statistic(x):
1792-
return randint(1,1000)
1792+
return randint(1, 1000)
17931793
"""
17941794
if value != self.sage_code():
17951795
self._modified = True
@@ -1818,9 +1818,22 @@ def __init__(self):
18181818
sage: FindStatCombinatorialStatistic()
18191819
<sage.databases.findstat.FindStatCombinatorialStatistic object at 0x...>
18201820
"""
1821-
self._first_terms_cache = None
18221821
self._first_terms_raw_cache = None
18231822

1823+
@lazy_attribute
1824+
def _first_terms_cache(self):
1825+
"""
1826+
Return the first terms of the (compound) statistic as a
1827+
dictionary.
1828+
1829+
EXAMPLES::
1830+
1831+
sage: findstat(41)._first_terms_cache[PerfectMatching([(1,6),(2,5),(3,4)])] # optional -- internet
1832+
3
1833+
"""
1834+
# this indirectly initializes self._first_terms_raw_cache
1835+
return dict(self._fetch_first_terms())
1836+
18241837
def first_terms(self):
18251838
r"""
18261839
Return the first terms of the (compound) statistic as a
@@ -1838,10 +1851,6 @@ def first_terms(self):
18381851
sage: findstat(41).first_terms()[PerfectMatching([(1,6),(2,5),(3,4)])] # optional -- internet
18391852
3
18401853
"""
1841-
# initialize self._first_terms_cache and
1842-
# self._first_terms_raw_cache on first call
1843-
if self._first_terms_cache is None:
1844-
self._first_terms_cache = self._fetch_first_terms()
18451854
# a shallow copy suffices - tuples are immutable
18461855
return dict(self._first_terms_cache)
18471856

@@ -1944,7 +1953,7 @@ def _generating_functions_dict(self,
19441953
domain = self.domain()
19451954
levels_with_sizes = domain.levels_with_sizes()
19461955
total = 0
1947-
for elt, val in self.first_terms().items():
1956+
for elt, val in self._first_terms_cache.items():
19481957
if total == max_values:
19491958
break
19501959
lvl = domain.element_level(elt)
@@ -2153,7 +2162,7 @@ def __call__(self, elt):
21532162
sage: q(graphs.PetersenGraph().copy(immutable=True)) # optional -- internet
21542163
2
21552164
"""
2156-
val = self.first_terms().get(elt, None)
2165+
val = self._first_terms_cache.get(elt, None)
21572166
if val is None:
21582167
return FindStatFunction.__call__(self, elt)
21592168
return val
@@ -2267,12 +2276,22 @@ def set_first_terms(self, values):
22672276
[(1, 4), (2, 3)] => 3
22682277
sage: s.reset() # optional -- internet
22692278
"""
2270-
to_str = self.domain().to_string()
2279+
domain = self.domain()
2280+
from_str = domain.from_string()
2281+
to_str = domain.to_string()
2282+
2283+
def to_domain(elt):
2284+
if domain.is_element(elt):
2285+
return elt
2286+
if not isinstance(elt, str):
2287+
elt = str(elt)
2288+
return from_str(elt)
2289+
22712290
new = [(to_str(obj), value) for obj, value in values]
22722291
if sorted(new) != sorted(self.first_terms_str()):
22732292
self._modified = True
22742293
self._first_terms_raw_cache = new
2275-
self._first_terms_cache = values
2294+
self._first_terms_cache = {to_domain(elt): v for elt, v in values}
22762295

22772296
def code(self):
22782297
r"""
@@ -2584,6 +2603,7 @@ def __init__(self, data=None, values_of=None, distribution_of=None,
25842603
self._known_terms = data
25852604
else:
25862605
self._known_terms = known_terms
2606+
self._known_terms_number = 0
25872607
self._values_of = None
25882608
self._distribution_of = None
25892609
self._depth = depth
@@ -2647,9 +2667,26 @@ def __init__(self, data=None, values_of=None, distribution_of=None,
26472667
function=function)
26482668
Element.__init__(self, FindStatStatistics()) # this is not completely correct, but it works
26492669

2670+
@lazy_attribute
2671+
def _first_terms_cache(self):
2672+
"""
2673+
Return the pairs of the known terms which contain
2674+
singletons, as a dictionary.
2675+
2676+
EXAMPLES::
2677+
2678+
sage: PM = PerfectMatchings
2679+
sage: l = [(PM(2*n), [m.number_of_nestings() for m in PM(2*n)]) for n in range(5)]
2680+
sage: r = findstat(l, depth=0) # optional -- internet
2681+
sage: r._first_terms_cache # optional -- internet
2682+
{}
2683+
"""
2684+
return dict()
2685+
26502686
def first_terms(self, max_values=FINDSTAT_MAX_SUBMISSION_VALUES):
26512687
"""
2652-
Return the pairs of the known terms which contain singletons as a dictionary.
2688+
Return the pairs of the known terms which contain
2689+
singletons, as a dictionary.
26532690
26542691
EXAMPLES::
26552692
@@ -2660,10 +2697,14 @@ def first_terms(self, max_values=FINDSTAT_MAX_SUBMISSION_VALUES):
26602697
1: St000042 (quality [99, 100])
26612698
sage: r.first_terms() # optional -- internet
26622699
{[]: 0, [(1, 2)]: 0}
2700+
26632701
"""
2664-
return dict(itertools.islice(((objs[0], vals[0])
2665-
for objs, vals in self._known_terms
2666-
if len(vals) == 1), max_values))
2702+
new_terms = self._known_terms[self._known_terms_number:max_values]
2703+
self._first_terms_cache.update((objs[0], vals[0])
2704+
for objs, vals in new_terms
2705+
if len(vals) == 1)
2706+
self._known_terms_number = max(max_values, self._known_terms_number)
2707+
return dict(self._first_terms_cache)
26672708

26682709
def _first_terms_raw(self, max_values):
26692710
"""

0 commit comments

Comments
 (0)