|
14 | 14 |
|
15 | 15 | import orjson |
16 | 16 | from channels import DEFAULT_CHANNEL_LAYER |
17 | | -from channels.db import database_sync_to_async |
18 | 17 | from channels.layers import InMemoryChannelLayer, get_channel_layer |
19 | 18 | from reactpy import use_callback, use_effect, use_memo, use_ref, use_state |
20 | 19 | from reactpy import use_connection as _use_connection |
|
34 | 33 | SyncPostprocessor, |
35 | 34 | UserData, |
36 | 35 | ) |
37 | | -from reactpy_django.utils import django_query_postprocessor, generate_obj_name, get_pk |
| 36 | +from reactpy_django.utils import django_query_postprocessor, ensure_async, generate_obj_name, get_pk |
38 | 37 |
|
39 | 38 | if TYPE_CHECKING: |
40 | 39 | from collections.abc import Awaitable, Sequence |
@@ -138,19 +137,13 @@ async def execute_query() -> None: |
138 | 137 | """The main running function for `use_query`""" |
139 | 138 | try: |
140 | 139 | # Run the query |
141 | | - if asyncio.iscoroutinefunction(query): |
142 | | - new_data = await query(**kwargs) |
143 | | - else: |
144 | | - new_data = await database_sync_to_async(query, thread_sensitive=thread_sensitive)(**kwargs) |
| 140 | + new_data = await ensure_async(query, thread_sensitive=thread_sensitive)(**kwargs) |
145 | 141 |
|
146 | 142 | # Run the postprocessor |
147 | 143 | if postprocessor: |
148 | | - if asyncio.iscoroutinefunction(postprocessor): |
149 | | - new_data = await postprocessor(new_data, **postprocessor_kwargs) |
150 | | - else: |
151 | | - new_data = await database_sync_to_async(postprocessor, thread_sensitive=thread_sensitive)( |
152 | | - new_data, **postprocessor_kwargs |
153 | | - ) |
| 144 | + new_data = await ensure_async(postprocessor, thread_sensitive=thread_sensitive)( |
| 145 | + new_data, **postprocessor_kwargs |
| 146 | + ) |
154 | 147 |
|
155 | 148 | # Log any errors and set the error state |
156 | 149 | except Exception as e: |
@@ -240,12 +233,7 @@ def use_mutation( |
240 | 233 | async def execute_mutation(exec_args, exec_kwargs) -> None: |
241 | 234 | # Run the mutation |
242 | 235 | try: |
243 | | - if asyncio.iscoroutinefunction(mutation): |
244 | | - should_refetch = await mutation(*exec_args, **exec_kwargs) |
245 | | - else: |
246 | | - should_refetch = await database_sync_to_async(mutation, thread_sensitive=thread_sensitive)( |
247 | | - *exec_args, **exec_kwargs |
248 | | - ) |
| 236 | + should_refetch = await ensure_async(mutation, thread_sensitive=thread_sensitive)(*exec_args, **exec_kwargs) |
249 | 237 |
|
250 | 238 | # Log any errors and set the error state |
251 | 239 | except Exception as e: |
@@ -444,10 +432,8 @@ async def _get_user_data(user: AbstractUser, default_data: None | dict, save_def |
444 | 432 | for key, value in default_data.items(): |
445 | 433 | if key not in data: |
446 | 434 | new_value: Any = value |
447 | | - if asyncio.iscoroutinefunction(value): |
448 | | - new_value = await value() |
449 | | - elif callable(value): |
450 | | - new_value = value() |
| 435 | + if callable(value): |
| 436 | + new_value = await ensure_async(value)() |
451 | 437 | data[key] = new_value |
452 | 438 | changed = True |
453 | 439 | if changed: |
|
0 commit comments