|
1 | 1 | import numpy as np |
| 2 | + |
2 | 3 | from .graphical_simulator import GraphicalSimulator |
3 | 4 |
|
4 | 5 |
|
@@ -60,21 +61,21 @@ def meta_fn(): |
60 | 61 | simulator = GraphicalSimulator(meta_fn=meta_fn) |
61 | 62 | simulator.add_node( |
62 | 63 | "schools", |
63 | | - sampling_fn=sample_school, |
| 64 | + sample_fn=sample_school, |
64 | 65 | ) |
65 | 66 | simulator.add_node( |
66 | 67 | "exams", |
67 | | - sampling_fn=sample_exam, |
| 68 | + sample_fn=sample_exam, |
68 | 69 | reps="num_exams", |
69 | 70 | ) |
70 | 71 | simulator.add_node( |
71 | 72 | "questions", |
72 | | - sampling_fn=sample_question, |
| 73 | + sample_fn=sample_question, |
73 | 74 | reps="num_questions", |
74 | 75 | ) |
75 | 76 | simulator.add_node( |
76 | 77 | "students", |
77 | | - sampling_fn=sample_student, |
| 78 | + sample_fn=sample_student, |
78 | 79 | reps="num_students", |
79 | 80 | ) |
80 | 81 |
|
@@ -109,8 +110,8 @@ def meta(): |
109 | 110 |
|
110 | 111 | simulator = GraphicalSimulator(meta_fn=meta) |
111 | 112 |
|
112 | | - simulator.add_node("prior", sampling_fn=prior) |
113 | | - simulator.add_node("likelihood", sampling_fn=likelihood) |
| 113 | + simulator.add_node("prior", sample_fn=prior) |
| 114 | + simulator.add_node("likelihood", sample_fn=likelihood) |
114 | 115 |
|
115 | 116 | simulator.add_edge("prior", "likelihood") |
116 | 117 |
|
@@ -140,15 +141,15 @@ def sample_y(local_mean, shared_std): |
140 | 141 | return {"y": float(y)} |
141 | 142 |
|
142 | 143 | simulator = GraphicalSimulator() |
143 | | - simulator.add_node("hypers", sampling_fn=sample_hypers, reps=1) |
| 144 | + simulator.add_node("hypers", sample_fn=sample_hypers, reps=5) |
144 | 145 |
|
145 | 146 | simulator.add_node( |
146 | 147 | "locals", |
147 | 148 | sampling_fn=sample_locals, |
148 | 149 | reps=6, |
149 | 150 | ) |
150 | | - simulator.add_node("shared", sampling_fn=sample_shared, reps=1) |
151 | | - simulator.add_node("y", sampling_fn=sample_y, reps=10) |
| 151 | + simulator.add_node("shared", sample_fn=sample_shared, reps=1) |
| 152 | + simulator.add_node("y", sample_fn=sample_y, reps=10) |
152 | 153 |
|
153 | 154 | simulator.add_edge("hypers", "locals") |
154 | 155 | simulator.add_edge("locals", "y") |
@@ -184,19 +185,19 @@ def sample_y(level_3_mean, shared_std): |
184 | 185 | return {"y": y} |
185 | 186 |
|
186 | 187 | simulator = GraphicalSimulator() |
187 | | - simulator.add_node("level1", sampling_fn=sample_level_1) |
| 188 | + simulator.add_node("level1", sample_fn=sample_level_1) |
188 | 189 | simulator.add_node( |
189 | 190 | "level2", |
190 | | - sampling_fn=sample_level_2, |
| 191 | + sample_fn=sample_level_2, |
191 | 192 | reps=10, |
192 | 193 | ) |
193 | 194 | simulator.add_node( |
194 | 195 | "level3", |
195 | | - sampling_fn=sample_level_3, |
| 196 | + sample_fn=sample_level_3, |
196 | 197 | reps=20, |
197 | 198 | ) |
198 | | - simulator.add_node("shared", sampling_fn=sample_shared) |
199 | | - simulator.add_node("y", sampling_fn=sample_y, reps=10) |
| 199 | + simulator.add_node("shared", sample_fn=sample_shared) |
| 200 | + simulator.add_node("y", sample_fn=sample_y, reps=10) |
200 | 201 |
|
201 | 202 | simulator.add_edge("level1", "level2") |
202 | 203 | simulator.add_edge("level2", "level3") |
|
0 commit comments