|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import os |
| 6 | +import threading |
6 | 7 | from concurrent import futures |
7 | 8 | from http.server import ThreadingHTTPServer |
8 | | -from typing import Any, Callable, Coroutine, Optional, TypeVar, overload |
| 9 | +from typing import Any, Callable, Coroutine, List, Optional, TypeVar, overload |
9 | 10 | from urllib.parse import urlsplit |
10 | 11 |
|
11 | 12 | from typing_extensions import ParamSpec, TypeAlias |
|
31 | 32 | "Status", |
32 | 33 | "all", |
33 | 34 | "any", |
| 35 | + "batch", |
34 | 36 | "call", |
35 | 37 | "function", |
36 | 38 | "gather", |
|
44 | 46 | T = TypeVar("T") |
45 | 47 |
|
46 | 48 | _registry: Optional[Registry] = None |
47 | | - |
| 49 | +_workers: List[Callable[None, None]] = [] |
| 50 | +_threads: List[threading.Thread] = [] |
48 | 51 |
|
49 | 52 | def default_registry(): |
50 | 53 | global _registry |
@@ -89,10 +92,35 @@ def run(init: Optional[Callable[P, None]] = None, *args: P.args, **kwargs: P.kwa |
89 | 92 | parsed_url = urlsplit("//" + address) |
90 | 93 | server_address = (parsed_url.hostname or "", parsed_url.port or 0) |
91 | 94 | server = ThreadingHTTPServer(server_address, Dispatch(default_registry())) |
| 95 | + |
| 96 | + for worker in _workers: |
| 97 | + def entrypoint(): |
| 98 | + try: |
| 99 | + worker() |
| 100 | + finally: |
| 101 | + server.shutdown() |
| 102 | + _threads.append(threading.Thread(target=entrypoint)) |
| 103 | + |
| 104 | + for thread in _threads: |
| 105 | + thread.start() |
| 106 | + |
92 | 107 | try: |
93 | 108 | if init is not None: |
94 | 109 | init(*args, **kwargs) |
95 | 110 | server.serve_forever() |
96 | 111 | finally: |
97 | 112 | server.shutdown() |
98 | 113 | server.server_close() |
| 114 | + |
| 115 | + for thread in _threads: |
| 116 | + thread.join() |
| 117 | + |
| 118 | +def batch() -> Batch: |
| 119 | + """Create a new batch object.""" |
| 120 | + return default_registry().batch() |
| 121 | + |
| 122 | + |
| 123 | +def worker(fn: Callable[None, None]) -> Callable[None, None]: |
| 124 | + """Decorator declaring workers that will be started when dipatch.run is called.""" |
| 125 | + _workers.append(fn) |
| 126 | + return fn |
0 commit comments