Skip to content

Commit 466e3b4

Browse files
committed
Ran black, all tests pass, final lookover
1 parent ab30261 commit 466e3b4

File tree

3 files changed

+14
-7
lines changed

3 files changed

+14
-7
lines changed

pyomo/contrib/parmest/graphics.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ def pairwise_plot(
273273
Add a legend to the plot
274274
filename: string, optional
275275
Filename used to save the figure
276+
seed: int, optional
277+
Random seed used to generate theta values if theta_values is a tuple.
278+
If None, the seed is not set.
276279
"""
277280
assert isinstance(theta_values, (pd.DataFrame, tuple))
278281
assert isinstance(theta_star, (type(None), dict, pd.Series, pd.DataFrame))

pyomo/contrib/parmest/parmest.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def _Q_at_theta(self, thetavals, initialize_parmest_model=False):
817817

818818
return retval, thetavals, WorstStatus
819819

820-
def _get_sample_list(self, samplesize, num_samples, replacement=True, seed=None):
820+
def _get_sample_list(self, samplesize, num_samples, replacement=True):
821821
samplelist = list()
822822

823823
scenario_numbers = list(range(len(self.exp_list)))
@@ -834,8 +834,6 @@ def _get_sample_list(self, samplesize, num_samples, replacement=True, seed=None)
834834
while (unique_samples <= len(self._return_theta_names())) and (
835835
not duplicate
836836
):
837-
# if seed is not None:
838-
# np.random.seed(seed) # set seed for reproducibility
839837
sample = np.random.choice(
840838
scenario_numbers, samplesize, replace=replacement
841839
)
@@ -1037,9 +1035,7 @@ def theta_est_leaveNout(
10371035
if seed is not None:
10381036
np.random.seed(seed)
10391037

1040-
global_list = self._get_sample_list(
1041-
samplesize, lNo_samples, replacement=False, seed=seed
1042-
)
1038+
global_list = self._get_sample_list(samplesize, lNo_samples, replacement=False)
10431039

10441040
task_mgr = utils.ParallelTaskManager(len(global_list))
10451041
local_list = task_mgr.global_to_local_data(global_list)

pyomo/contrib/parmest/tests/test_graphics.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
import pyomo.contrib.parmest.parmest as parmest
3232
import pyomo.contrib.parmest.graphics as graphics
3333

34+
from pyomo.contrib.parmest.tests.test_parmest import _RANDOM_SEED_FOR_TESTING
35+
3436
testdir = os.path.dirname(os.path.abspath(__file__))
3537

3638

@@ -47,6 +49,7 @@
4749
)
4850
class TestGraphics(unittest.TestCase):
4951
def setUp(self):
52+
np.random.seed(_RANDOM_SEED_FOR_TESTING)
5053
self.A = pd.DataFrame(
5154
np.random.randint(0, 100, size=(100, 4)), columns=list('ABCD')
5255
)
@@ -55,7 +58,12 @@ def setUp(self):
5558
)
5659

5760
def test_pairwise_plot(self):
58-
graphics.pairwise_plot(self.A, alpha=0.8, distributions=['Rect', 'MVN', 'KDE'])
61+
graphics.pairwise_plot(
62+
self.A,
63+
alpha=0.8,
64+
distributions=['Rect', 'MVN', 'KDE'],
65+
seed=_RANDOM_SEED_FOR_TESTING,
66+
)
5967

6068
def test_grouped_boxplot(self):
6169
graphics.grouped_boxplot(self.A, self.B, normalize=True, group_names=['A', 'B'])

0 commit comments

Comments
 (0)