Skip to content

Commit 00f293f

Browse files
committed
WIP
1 parent 0d80e67 commit 00f293f

File tree

1 file changed

+226
-3
lines changed

1 file changed

+226
-3
lines changed

temporalio/client.py

Lines changed: 226 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from google.protobuf.internal.containers import MessageMap
4444
from typing_extensions import Concatenate, Required, Self, TypedDict
4545

46+
import temporalio.api.activity.v1
4647
import temporalio.api.common.v1
4748
import temporalio.api.enums.v1
4849
import temporalio.api.errordetails.v1
@@ -3041,6 +3042,101 @@ class ActivityExecutionDescription:
30413042
raw_info: Any
30423043
"""Raw proto response."""
30433044

3045+
@classmethod
3046+
async def _from_raw_info(
3047+
cls,
3048+
info: temporalio.api.activity.v1.ActivityExecutionInfo,
3049+
data_converter: temporalio.converter.DataConverter,
3050+
) -> Self:
3051+
"""Create from raw proto activity info."""
3052+
return cls(
3053+
activity_id=info.activity_id,
3054+
run_id=info.run_id,
3055+
activity_type=(
3056+
info.activity_type.name if info.HasField("activity_type") else ""
3057+
),
3058+
status=(
3059+
temporalio.common.ActivityExecutionStatus(info.status)
3060+
if info.status
3061+
else temporalio.common.ActivityExecutionStatus.RUNNING
3062+
),
3063+
run_state=(
3064+
temporalio.common.PendingActivityState(info.run_state)
3065+
if info.run_state
3066+
else None
3067+
),
3068+
heartbeat_details=(
3069+
await data_converter.decode(info.heartbeat_details.payloads)
3070+
if info.HasField("heartbeat_details")
3071+
else []
3072+
),
3073+
last_heartbeat_time=(
3074+
info.last_heartbeat_time.ToDatetime(tzinfo=timezone.utc)
3075+
if info.HasField("last_heartbeat_time")
3076+
else None
3077+
),
3078+
last_started_time=(
3079+
info.last_started_time.ToDatetime(tzinfo=timezone.utc)
3080+
if info.HasField("last_started_time")
3081+
else None
3082+
),
3083+
attempt=info.attempt,
3084+
maximum_attempts=info.maximum_attempts,
3085+
scheduled_time=(
3086+
info.scheduled_time.ToDatetime(tzinfo=timezone.utc)
3087+
if info.HasField("scheduled_time")
3088+
else datetime.min
3089+
),
3090+
expiration_time=(
3091+
info.expiration_time.ToDatetime(tzinfo=timezone.utc)
3092+
if info.HasField("expiration_time")
3093+
else datetime.min
3094+
),
3095+
last_failure=(
3096+
cast(
3097+
Optional[Exception],
3098+
await data_converter.decode_failure(info.last_failure),
3099+
)
3100+
if info.HasField("last_failure")
3101+
else None
3102+
),
3103+
last_worker_identity=info.last_worker_identity,
3104+
current_retry_interval=(
3105+
info.current_retry_interval.ToTimedelta()
3106+
if info.HasField("current_retry_interval")
3107+
else None
3108+
),
3109+
last_attempt_complete_time=(
3110+
info.last_attempt_complete_time.ToDatetime(tzinfo=timezone.utc)
3111+
if info.HasField("last_attempt_complete_time")
3112+
else None
3113+
),
3114+
next_attempt_schedule_time=(
3115+
info.next_attempt_schedule_time.ToDatetime(tzinfo=timezone.utc)
3116+
if info.HasField("next_attempt_schedule_time")
3117+
else None
3118+
),
3119+
task_queue=(
3120+
info.activity_options.task_queue.name
3121+
if info.HasField("activity_options")
3122+
and info.activity_options.HasField("task_queue")
3123+
else ""
3124+
),
3125+
paused=info.HasField("pause_info"),
3126+
input=(
3127+
await data_converter.decode(info.input.payloads)
3128+
if info.HasField("input")
3129+
else []
3130+
),
3131+
state_transition_count=info.state_transition_count,
3132+
search_attributes=temporalio.converter.decode_search_attributes(
3133+
info.search_attributes
3134+
),
3135+
eager_execution_requested=info.eager_execution_requested,
3136+
canceled_reason=info.canceled_reason,
3137+
raw_info=info,
3138+
)
3139+
30443140

30453141
@dataclass(frozen=True)
30463142
class ActivityIDReference:
@@ -3292,7 +3388,15 @@ async def cancel(
32923388
rpc_metadata: Headers used on the RPC call.
32933389
rpc_timeout: Optional RPC deadline to set for the RPC call.
32943390
"""
3295-
raise NotImplementedError
3391+
await self._client._impl.cancel_activity(
3392+
CancelActivityInput(
3393+
activity_id=self._id,
3394+
run_id=self._run_id,
3395+
reason=reason,
3396+
rpc_metadata=rpc_metadata,
3397+
rpc_timeout=rpc_timeout,
3398+
)
3399+
)
32963400

32973401
async def terminate(
32983402
self,
@@ -3312,7 +3416,15 @@ async def terminate(
33123416
rpc_metadata: Headers used on the RPC call.
33133417
rpc_timeout: Optional RPC deadline to set for the RPC call.
33143418
"""
3315-
raise NotImplementedError
3419+
await self._client._impl.terminate_activity(
3420+
TerminateActivityInput(
3421+
activity_id=self._id,
3422+
run_id=self._run_id,
3423+
reason=reason,
3424+
rpc_metadata=rpc_metadata,
3425+
rpc_timeout=rpc_timeout,
3426+
)
3427+
)
33163428

33173429
async def describe(
33183430
self,
@@ -3329,7 +3441,14 @@ async def describe(
33293441
Returns:
33303442
Activity execution description.
33313443
"""
3332-
raise NotImplementedError
3444+
return await self._client._impl.describe_activity(
3445+
DescribeActivityInput(
3446+
activity_id=self._id,
3447+
run_id=self._run_id,
3448+
rpc_metadata=rpc_metadata,
3449+
rpc_timeout=rpc_timeout,
3450+
)
3451+
)
33333452

33343453
# TODO:
33353454
# update_options
@@ -6053,6 +6172,38 @@ class TerminateWorkflowInput:
60536172
rpc_timeout: Optional[timedelta]
60546173

60556174

6175+
@dataclass
6176+
class CancelActivityInput:
6177+
"""Input for :py:meth:`OutboundInterceptor.cancel_activity`."""
6178+
6179+
activity_id: str
6180+
run_id: str
6181+
reason: Optional[str]
6182+
rpc_metadata: Mapping[str, Union[str, bytes]]
6183+
rpc_timeout: Optional[timedelta]
6184+
6185+
6186+
@dataclass
6187+
class TerminateActivityInput:
6188+
"""Input for :py:meth:`OutboundInterceptor.terminate_activity`."""
6189+
6190+
activity_id: str
6191+
run_id: str
6192+
reason: Optional[str]
6193+
rpc_metadata: Mapping[str, Union[str, bytes]]
6194+
rpc_timeout: Optional[timedelta]
6195+
6196+
6197+
@dataclass
6198+
class DescribeActivityInput:
6199+
"""Input for :py:meth:`OutboundInterceptor.describe_activity`."""
6200+
6201+
activity_id: str
6202+
run_id: str
6203+
rpc_metadata: Mapping[str, Union[str, bytes]]
6204+
rpc_timeout: Optional[timedelta]
6205+
6206+
60566207
@dataclass
60576208
class StartWorkflowUpdateInput:
60586209
"""Input for :py:meth:`OutboundInterceptor.start_workflow_update`."""
@@ -6391,6 +6542,22 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
63916542
"""Called for every :py:meth:`WorkflowHandle.terminate` call."""
63926543
await self.next.terminate_workflow(input)
63936544

6545+
### Activity calls
6546+
6547+
async def cancel_activity(self, input: CancelActivityInput) -> None:
6548+
"""Called for every :py:meth:`ActivityHandle.cancel` call."""
6549+
await self.next.cancel_activity(input)
6550+
6551+
async def terminate_activity(self, input: TerminateActivityInput) -> None:
6552+
"""Called for every :py:meth:`ActivityHandle.terminate` call."""
6553+
await self.next.terminate_activity(input)
6554+
6555+
async def describe_activity(
6556+
self, input: DescribeActivityInput
6557+
) -> ActivityExecutionDescription:
6558+
"""Called for every :py:meth:`ActivityHandle.describe` call."""
6559+
return await self.next.describe_activity(input)
6560+
63946561
async def start_workflow_update(
63956562
self, input: StartWorkflowUpdateInput
63966563
) -> WorkflowUpdateHandle[Any]:
@@ -6842,6 +7009,62 @@ async def terminate_workflow(self, input: TerminateWorkflowInput) -> None:
68427009
req, retry=True, metadata=input.rpc_metadata, timeout=input.rpc_timeout
68437010
)
68447011

7012+
async def cancel_activity(self, input: CancelActivityInput) -> None:
7013+
"""Cancel a standalone activity."""
7014+
await self._client.workflow_service.request_cancel_activity_execution(
7015+
temporalio.api.workflowservice.v1.RequestCancelActivityExecutionRequest(
7016+
namespace=self._client.namespace,
7017+
activity_id=input.activity_id,
7018+
run_id=input.run_id,
7019+
identity=self._client.identity,
7020+
request_id=str(uuid.uuid4()),
7021+
reason=input.reason or "",
7022+
),
7023+
retry=True,
7024+
metadata=input.rpc_metadata,
7025+
timeout=input.rpc_timeout,
7026+
)
7027+
7028+
async def terminate_activity(self, input: TerminateActivityInput) -> None:
7029+
"""Terminate a standalone activity."""
7030+
await self._client.workflow_service.terminate_activity_execution(
7031+
temporalio.api.workflowservice.v1.TerminateActivityExecutionRequest(
7032+
namespace=self._client.namespace,
7033+
activity_id=input.activity_id,
7034+
run_id=input.run_id,
7035+
reason=input.reason or "",
7036+
identity=self._client.identity,
7037+
),
7038+
retry=True,
7039+
metadata=input.rpc_metadata,
7040+
timeout=input.rpc_timeout,
7041+
)
7042+
7043+
async def describe_activity(
7044+
self, input: DescribeActivityInput
7045+
) -> ActivityExecutionDescription:
7046+
"""Describe a standalone activity."""
7047+
resp = await self._client.workflow_service.describe_activity_execution(
7048+
temporalio.api.workflowservice.v1.DescribeActivityExecutionRequest(
7049+
namespace=self._client.namespace,
7050+
activity_id=input.activity_id,
7051+
run_id=input.run_id,
7052+
include_input=True,
7053+
),
7054+
retry=True,
7055+
metadata=input.rpc_metadata,
7056+
timeout=input.rpc_timeout,
7057+
)
7058+
return await ActivityExecutionDescription._from_raw_info(
7059+
resp.info,
7060+
self._client.data_converter.with_context(
7061+
WorkflowSerializationContext(
7062+
namespace=self._client.namespace,
7063+
workflow_id=input.activity_id, # Using activity_id as workflow_id for standalone activities
7064+
)
7065+
),
7066+
)
7067+
68457068
async def start_workflow_update(
68467069
self, input: StartWorkflowUpdateInput
68477070
) -> WorkflowUpdateHandle[Any]:

0 commit comments

Comments
 (0)