Skip to content

Commit c3b1333

Browse files
committed
add withTransaction script
1 parent d267eb4 commit c3b1333

File tree

1 file changed

+111
-0
lines changed

1 file changed

+111
-0
lines changed

withTransaction.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import time
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
from pymongo import MongoClient
8+
9+
10+
class RunOrderTransaction:
11+
def __init__(self, client):
12+
super(RunOrderTransaction, self).__init__() # noqa:UP008
13+
self.retry_attempts = -1
14+
self.time = 0
15+
self.client = client
16+
17+
def run(self):
18+
start = time.time()
19+
with self.client.start_session() as session:
20+
try:
21+
session.with_transaction(self.callback)
22+
finally:
23+
self.time = time.time() - start
24+
return self
25+
26+
def callback(self, session):
27+
self.retry_attempts += 1
28+
return callback(session, self.client)
29+
30+
31+
def callback(session, client):
32+
order_id = client.test.orders1.insert_one({"sku": "foo", "qty": 1}, session=session).inserted_id
33+
res = client.test.inventory1.update_one(
34+
{"sku": "foo", "qty": {"$gte": 1}}, {"$inc": {"qty": -1}}, session=session
35+
)
36+
if not res.modified_count:
37+
raise TypeError("Insufficient inventory count")
38+
39+
return order_id
40+
41+
42+
def run(num_threads: int, local: bool):
43+
if local:
44+
client = MongoClient()
45+
else:
46+
client = MongoClient(os.getenv("ATLAS_URI"))
47+
try:
48+
client.drop_database("test")
49+
except Exception: # noqa: S110
50+
# fails on atlas?
51+
pass
52+
db = client.test
53+
db.drop_collection("orders1")
54+
db.create_collection("orders1")
55+
db.drop_collection("inventory1")
56+
inventory = db.create_collection("inventory1")
57+
inventory.insert_one({"sku": "foo", "qty": 1000000})
58+
59+
f.write("Testing %s threads\n" % num_threads)
60+
start = time.time()
61+
N_TXNS = 512
62+
results = []
63+
ops = [RunOrderTransaction(client) for _ in range(N_TXNS)]
64+
with ThreadPoolExecutor(max_workers=num_threads) as exc:
65+
futures = [exc.submit(op.run) for op in ops]
66+
for future in futures:
67+
result = future.result()
68+
results.append(result)
69+
70+
end = time.time()
71+
total_time = end - start
72+
total_attempts = sum(r.retry_attempts for r in results)
73+
74+
f.write("All threads completed after %s seconds\n" % (end - start))
75+
f.write(f"Total number of retry attempts: {total_attempts}\n")
76+
client.close()
77+
78+
latencies = sorted(r.time for r in results)
79+
avg_latency = sum(latencies) / N_TXNS
80+
p50 = latencies[int(N_TXNS * 0.5)]
81+
p90 = latencies[int(N_TXNS * 0.9)]
82+
p99 = latencies[int(N_TXNS * 0.99)]
83+
p100 = latencies[int(N_TXNS * 1.0) - 1]
84+
# print(f'avg latency: {avg_latency:.2f}s p50: {p50:.2f}s p90: {p90:.2f}s p99: {p99:.2f}s p100: {p100:.2f}s')
85+
return total_time, total_attempts, avg_latency, p50, p90, p99, p100
86+
87+
88+
def main(f, local=True):
89+
NUM_THREADS = [1, 2, 4, 8, 16, 32, 64, 128, 256]
90+
data = {}
91+
for num in NUM_THREADS:
92+
times, attempts, avg_latency, p50, p90, p99, p100 = run(num, local)
93+
data[num] = {
94+
"avg": avg_latency,
95+
"p50": p50,
96+
"p90": p90,
97+
"p99": p99,
98+
"p100": p100,
99+
}
100+
f.write("\n")
101+
time.sleep(10)
102+
f.write("\nthreads | avg | p50 | p90 | p99 | p100\n")
103+
for num in NUM_THREADS:
104+
f.write(
105+
f"{num:7} | {data[num]['avg']:5.2f} | {data[num]['p50']:5.2f} | {data[num]['p90']:5.2f} | {data[num]['p90']:5.2f} | {data[num]['p100']:5.2f}\n"
106+
)
107+
108+
109+
if __name__ == "__main__":
110+
with open("/Users/iris.ho/Github/backpressure/final/local_original_1.5.txt", "w") as f:
111+
main(f, local=True)

0 commit comments

Comments
 (0)