|
1 | 1 | import json |
2 | 2 | import inspect |
3 | 3 | import functools |
4 | | -from typing import Any, Dict, List, Union, TypeVar, Callable, Optional |
| 4 | +from typing import Any, Dict, List, TypeVar, Callable, Optional |
5 | 5 | from dataclasses import dataclass |
6 | 6 |
|
7 | 7 | T = TypeVar("T") |
@@ -44,53 +44,58 @@ def __init__(self, name: str): |
44 | 44 | # Register this app in the global registry |
45 | 45 | _app_registry.register_app(self) |
46 | 46 |
|
47 | | - def action(self, name_or_handler: Optional[Union[str, Callable[..., Any]]] = None) -> Callable[..., Any]: |
48 | | - """Decorator to register an action with the app""" |
49 | | - if name_or_handler is None: |
50 | | - # This is the @app.action() case, which should return the decorator |
51 | | - def decorator(f: Callable[..., Any]) -> Callable[..., Any]: |
52 | | - return self._register_action(f.__name__, f) |
53 | | - return decorator |
54 | | - elif callable(name_or_handler): |
55 | | - # This is the @app.action case (handler passed directly) |
56 | | - return self._register_action(name_or_handler.__name__, name_or_handler) |
57 | | - else: |
58 | | - # This is the @app.action("name") case (name_or_handler is a string) |
59 | | - def decorator(f: Callable[..., Any]) -> Callable[..., Any]: |
60 | | - return self._register_action(name_or_handler, f) # name_or_handler is the name string here |
61 | | - return decorator |
| 47 | + def action(self, name: str) -> Callable[..., Any]: |
| 48 | + """ |
| 49 | + Decorator to register an action with the app |
| 50 | + |
| 51 | + Usage: |
| 52 | + @app.action("action-name") |
| 53 | + def my_handler(ctx: KernelContext): |
| 54 | + # ... |
| 55 | + |
| 56 | + @app.action("action-with-payload") |
| 57 | + def my_handler(ctx: KernelContext, payload: dict): |
| 58 | + # ... |
| 59 | + """ |
| 60 | + def decorator(handler: Callable[..., Any]) -> Callable[..., Any]: |
| 61 | + return self._register_action(name, handler) |
| 62 | + return decorator |
62 | 63 |
|
63 | 64 | def _register_action(self, name: str, handler: Callable[..., Any]) -> Callable[..., Any]: |
64 | 65 | """Internal method to register an action""" |
| 66 | + # Validate handler signature |
| 67 | + sig = inspect.signature(handler) |
| 68 | + param_count = len(sig.parameters) |
| 69 | + |
| 70 | + if param_count == 0: |
| 71 | + raise TypeError("Action handler must accept at least the context parameter") |
| 72 | + elif param_count > 2: |
| 73 | + raise TypeError("Action handler can only accept context and payload parameters") |
| 74 | + |
| 75 | + param_names = list(sig.parameters.keys()) |
65 | 76 |
|
66 | 77 | @functools.wraps(handler) |
67 | 78 | def wrapper(*args: Any, **kwargs: Any) -> Any: |
68 | | - # Determine if the original handler accepts context as first argument |
69 | | - sig = inspect.signature(handler) |
70 | | - param_names = list(sig.parameters.keys()) |
71 | | - param_count = len(param_names) |
72 | | - |
| 79 | + # Ensure the first argument is the context |
| 80 | + if not args or not isinstance(args[0], KernelContext): |
| 81 | + raise TypeError("First argument to action handler must be a KernelContext") |
| 82 | + |
| 83 | + ctx = args[0] |
| 84 | + |
73 | 85 | if param_count == 1: |
74 | | - actual_input = None |
75 | | - # The handler only takes input |
76 | | - if len(args) > 0: # Prioritize args if context was implicitly passed |
77 | | - # If context (args[0]) and input (args[1]) were provided, or just input (args[0]) |
78 | | - actual_input = args[1] if len(args) > 1 else args[0] |
79 | | - elif kwargs: |
80 | | - # Attempt to find the single expected kwarg |
81 | | - if param_names: # Should always be true if param_count == 1 |
82 | | - param_name = param_names[0] |
83 | | - if param_name in kwargs: |
84 | | - actual_input = kwargs[param_name] |
85 | | - elif kwargs: # Fallback if name doesn't match but kwargs exist |
86 | | - actual_input = next(iter(kwargs.values())) |
87 | | - elif kwargs: # param_names is empty but kwargs exist (unlikely for param_count==1) |
88 | | - actual_input = next(iter(kwargs.values())) |
89 | | - # If no args/kwargs, actual_input remains None, handler might raise error or accept None |
90 | | - return handler(actual_input) |
91 | | - else: # param_count == 0 or param_count > 1 |
92 | | - # Handler takes context and input (or more), or no args |
93 | | - return handler(*args, **kwargs) |
| 86 | + # Handler takes only context |
| 87 | + return handler(ctx) |
| 88 | + else: # param_count == 2 |
| 89 | + # Handler takes context and payload |
| 90 | + if len(args) >= 2: |
| 91 | + return handler(ctx, args[1]) |
| 92 | + else: |
| 93 | + # Try to find payload in kwargs |
| 94 | + payload_name = param_names[1] |
| 95 | + if payload_name in kwargs: |
| 96 | + return handler(ctx, kwargs[payload_name]) |
| 97 | + else: |
| 98 | + raise TypeError(f"Missing required payload parameter '{payload_name}'") |
94 | 99 |
|
95 | 100 | action = KernelAction(name=name, handler=wrapper) |
96 | 101 | self.actions[name] = action |
|
0 commit comments