|
8 | 8 |
|
9 | 9 | import argparse |
10 | 10 |
|
11 | | -parser = argparse.ArgumentParser() |
12 | | -parser.add_argument("participant", choices=["A", "B"]) |
13 | | -parser.add_argument("--config", "-c", default="../precice-config.xml") |
14 | | -parser.add_argument("--no-remesh", dest="remesh", action="store_false") |
15 | | -args = parser.parse_args() |
16 | 11 |
|
17 | | -participant_name = args.participant |
18 | | -remote_name = "A" if participant_name == "B" else "B" |
| 12 | +def main(): |
| 13 | + parser = argparse.ArgumentParser() |
| 14 | + parser.add_argument("participant", choices=["A", "B"]) |
| 15 | + parser.add_argument("--config", "-c", default="../precice-config.xml") |
| 16 | + parser.add_argument("--no-remesh", dest="remesh", action="store_false") |
| 17 | + args = parser.parse_args() |
19 | 18 |
|
20 | | -# x is partitioned per rank and doesn't change |
21 | | -nx = 256 * 3 |
22 | | -x = 0.0, 1.0 |
23 | | -ny = 256 * 3 |
24 | | -y = 0.0, 1.0 |
| 19 | + participant_name = args.participant |
| 20 | + remote_name = "A" if participant_name == "B" else "B" |
25 | 21 |
|
26 | | -# y grows over time |
27 | | -newNodesPerEvent = 2 |
28 | | -eventFrequency = 3 # time windows |
29 | | -dz = 0.1 |
| 22 | + # x is partitioned per rank and doesn't change |
| 23 | + nx = 256 * 3 |
| 24 | + x = 0.0, 1.0 |
| 25 | + ny = 256 * 3 |
| 26 | + y = 0.0, 1.0 |
30 | 27 |
|
| 28 | + # y grows over time |
| 29 | + newNodesPerEvent = 2 |
| 30 | + eventFrequency = 3 # time windows |
| 31 | + dz = 0.1 |
31 | 32 |
|
32 | | -# Handle partitioning |
33 | | -world = MPI.COMM_WORLD |
34 | | -size: int = world.size |
35 | | -rank: int = world.rank |
| 33 | + # Handle partitioning |
| 34 | + world = MPI.COMM_WORLD |
| 35 | + size: int = world.size |
| 36 | + rank: int = world.rank |
36 | 37 |
|
37 | | -parts: int = int(math.sqrt(size)) |
38 | | -assert parts**2 == size, "size must be a square value" |
39 | | -assert math.remainder(nx, parts) == 0, f"{nx=} must be dividable by {parts=}" |
| 38 | + parts: int = int(math.sqrt(size)) |
| 39 | + assert parts**2 == size, "size must be a square value" |
| 40 | + assert math.remainder(nx, parts) == 0, f"{nx=} must be dividable by {parts=}" |
40 | 41 |
|
41 | | -# current parition in x, y |
42 | | -px = rank % parts |
43 | | -py = rank // parts |
| 42 | + # current parition in x, y |
| 43 | + px = rank % parts |
| 44 | + py = rank // parts |
44 | 45 |
|
45 | | -# nodes per parition |
46 | | -nxp = nx // parts |
47 | | -nyp = ny // parts |
| 46 | + # nodes per parition |
| 47 | + nxp = nx // parts |
| 48 | + nyp = ny // parts |
48 | 49 |
|
49 | | -# node slide per partition |
50 | | -nxps = slice(nxp * px, nxp * (px + 1)) |
51 | | -nyps = slice(nyp * py, nyp * (py + 1)) |
| 50 | + # node slide per partition |
| 51 | + nxps = slice(nxp * px, nxp * (px + 1)) |
| 52 | + nyps = slice(nyp * py, nyp * (py + 1)) |
52 | 53 |
|
53 | | -print(f"{rank=} {nxps=} {nyps=}") |
| 54 | + print(f"{rank=} {nxps=} {nyps=}") |
54 | 55 |
|
| 56 | + def requiresEvent(tw): |
| 57 | + return tw % eventFrequency == 0 |
55 | 58 |
|
56 | | -def requiresEvent(tw): |
57 | | - return tw % eventFrequency == 0 |
| 59 | + assert not requiresEvent(eventFrequency - 1) |
| 60 | + assert requiresEvent(eventFrequency) |
| 61 | + assert not requiresEvent(eventFrequency + 1) |
58 | 62 |
|
| 63 | + def eventsAt(tw): |
| 64 | + # First event block at tw=0, second at eventFrequency |
| 65 | + return 1 + math.floor(tw / eventFrequency) |
59 | 66 |
|
60 | | -assert not requiresEvent(eventFrequency - 1) |
61 | | -assert requiresEvent(eventFrequency) |
62 | | -assert not requiresEvent(eventFrequency + 1) |
| 67 | + assert eventsAt(0) == 1 |
| 68 | + assert eventsAt(eventFrequency - 1) == 1 |
| 69 | + assert eventsAt(eventFrequency) == 2 |
| 70 | + assert eventsAt(eventFrequency + 1) == 2 |
63 | 71 |
|
| 72 | + def getMeshAtTimeWindow(tw): |
| 73 | + znodes = eventsAt(tw) * newNodesPerEvent |
64 | 74 |
|
65 | | -def eventsAt(tw): |
66 | | - # First event block at tw=0, second at eventFrequency |
67 | | - return 1 + math.floor(tw / eventFrequency) |
| 75 | + xs = np.linspace(x[0], x[1], nx)[nxps] |
| 76 | + ys = np.linspace(y[0], y[1], ny)[nyps] |
| 77 | + zs = np.array(range(znodes)) * dz |
68 | 78 |
|
| 79 | + return np.reshape([(x, y, z) for z in zs for y in ys for x in xs], (-1, 3)) |
69 | 80 |
|
70 | | -assert eventsAt(0) == 1 |
71 | | -assert eventsAt(eventFrequency - 1) == 1 |
72 | | -assert eventsAt(eventFrequency) == 2 |
73 | | -assert eventsAt(eventFrequency + 1) == 2 |
| 81 | + participant = precice.Participant(participant_name, args.config, rank, size) |
74 | 82 |
|
| 83 | + mesh_name = participant_name + "-Mesh" |
| 84 | + read_data_name = "Data-" + remote_name |
| 85 | + write_data_name = "Data-" + participant_name |
75 | 86 |
|
76 | | -def getMeshAtTimeWindow(tw): |
77 | | - znodes = eventsAt(tw) * newNodesPerEvent |
| 87 | + coords = getMeshAtTimeWindow(0) |
| 88 | + vertex_ids = participant.set_mesh_vertices(mesh_name, coords) |
| 89 | + participant.initialize() |
78 | 90 |
|
79 | | - xs = np.linspace(x[0], x[1], nx)[nxps] |
80 | | - ys = np.linspace(y[0], y[1], ny)[nyps] |
81 | | - zs = np.array(range(znodes)) * dz |
| 91 | + tw = 1 |
| 92 | + while participant.is_coupling_ongoing(): |
| 93 | + dt = participant.get_max_time_step_size() |
82 | 94 |
|
83 | | - return np.reshape([(x, y, z) for z in zs for y in ys for x in xs], (-1, 3)) |
84 | | - |
85 | | - |
86 | | -participant = precice.Participant(participant_name, args.config, rank, size) |
87 | | - |
88 | | -mesh_name = participant_name + "-Mesh" |
89 | | -read_data_name = "Data-" + remote_name |
90 | | -write_data_name = "Data-" + participant_name |
91 | | - |
92 | | -coords = getMeshAtTimeWindow(0) |
93 | | -vertex_ids = participant.set_mesh_vertices(mesh_name, coords) |
94 | | -participant.initialize() |
95 | | - |
96 | | -tw = 1 |
97 | | -while participant.is_coupling_ongoing(): |
98 | | - dt = participant.get_max_time_step_size() |
99 | | - |
100 | | - data = participant.read_data(mesh_name, read_data_name, vertex_ids, dt) |
101 | | - if rank == 0: |
102 | | - print(data) |
103 | | - |
104 | | - if args.remesh and requiresEvent(tw): |
105 | | - oldCount = len(coords) |
106 | | - coords = getMeshAtTimeWindow(tw) |
| 95 | + data = participant.read_data(mesh_name, read_data_name, vertex_ids, dt) |
107 | 96 | if rank == 0: |
108 | | - print( |
109 | | - f"Event grows local mesh from {oldCount} to { |
110 | | - len(coords)} and global mesh from { |
111 | | - oldCount * |
112 | | - size} to { |
113 | | - len(coords) * |
114 | | - size}") |
115 | | - participant.reset_mesh(mesh_name) |
116 | | - vertex_ids = participant.set_mesh_vertices(mesh_name, coords) |
117 | | - |
118 | | - data = np.full(len(coords), tw) |
119 | | - participant.write_data(mesh_name, write_data_name, vertex_ids, data) |
120 | | - |
121 | | - participant.advance(dt) |
122 | | - tw += 1 |
| 97 | + print(data) |
| 98 | + |
| 99 | + if args.remesh and requiresEvent(tw): |
| 100 | + oldCount = len(coords) |
| 101 | + coords = getMeshAtTimeWindow(tw) |
| 102 | + if rank == 0: |
| 103 | + print( |
| 104 | + f"Event grows local mesh from {oldCount} to { |
| 105 | + len(coords)} and global mesh from { |
| 106 | + oldCount * |
| 107 | + size} to { |
| 108 | + len(coords) * |
| 109 | + size}") |
| 110 | + participant.reset_mesh(mesh_name) |
| 111 | + vertex_ids = participant.set_mesh_vertices(mesh_name, coords) |
| 112 | + |
| 113 | + data = np.full(len(coords), tw) |
| 114 | + participant.write_data(mesh_name, write_data_name, vertex_ids, data) |
| 115 | + |
| 116 | + participant.advance(dt) |
| 117 | + tw += 1 |
| 118 | + |
| 119 | + |
| 120 | +if __name__ == "__main__": |
| 121 | + main() |
0 commit comments