From f10b0fe7700e2d4193432715007459dc44991dc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Weso=C5=82owski?= Date: Fri, 18 Nov 2022 15:21:28 +0100 Subject: [PATCH] Now the EMRServerlessStartJobOperator has the countdown and check_interval_seconds parameters. --- airflow/emr_serverless/hooks/emr.py | 7 +++++-- airflow/emr_serverless/operators/emr.py | 13 ++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/airflow/emr_serverless/hooks/emr.py b/airflow/emr_serverless/hooks/emr.py index 0dcdca2..d65c1c4 100644 --- a/airflow/emr_serverless/hooks/emr.py +++ b/airflow/emr_serverless/hooks/emr.py @@ -24,6 +24,9 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +DEFAULT_COUNTDOWN = 25 * 60 +DEFAULT_CHECK_INTERVAL_SECONDS = 60 + class EmrServerlessHook(AwsBaseHook): """ @@ -54,8 +57,8 @@ def waiter( failure_states: Set, object_type: str, action: str, - countdown: int = 25 * 60, - check_interval_seconds: int = 60, + countdown: int = DEFAULT_COUNTDOWN, + check_interval_seconds: int = DEFAULT_CHECK_INTERVAL_SECONDS, ) -> None: """ Will run the sensor until it turns True. diff --git a/airflow/emr_serverless/operators/emr.py b/airflow/emr_serverless/operators/emr.py index 7546f31..609c8e7 100644 --- a/airflow/emr_serverless/operators/emr.py +++ b/airflow/emr_serverless/operators/emr.py @@ -19,7 +19,12 @@ from typing import TYPE_CHECKING, Dict, Optional, Sequence from uuid import uuid4 -from emr_serverless.hooks.emr import EmrServerlessHook +from emr_serverless.hooks.emr import ( + EmrServerlessHook, + DEFAULT_COUNTDOWN, + DEFAULT_CHECK_INTERVAL_SECONDS, +) + from emr_serverless.sensors.emr import ( EmrServerlessApplicationSensor, EmrServerlessJobSensor, @@ -158,6 +163,8 @@ def __init__( config: Optional[dict] = None, wait_for_completion: bool = True, aws_conn_id: str = "aws_default", + countdown: int = DEFAULT_COUNTDOWN, + check_interval_seconds: int = DEFAULT_CHECK_INTERVAL_SECONDS, **kwargs, ): self.aws_conn_id = aws_conn_id @@ -167,6 +174,8 @@ def __init__( self.configuration_overrides = configuration_overrides self.wait_for_completion = wait_for_completion self.config = config or {} + self.countdown = countdown + self.check_interval_seconds = check_interval_seconds super().__init__(**kwargs) self.client_request_token = client_request_token or str(uuid4()) @@ -221,6 +230,8 @@ def execute(self, context: "Context") -> Dict: failure_states=EmrServerlessJobSensor.FAILURE_STATES, object_type="job", action="run", + countdown=self.countdown, + check_interval_seconds=self.check_interval_seconds ) return response["jobRunId"]