Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict
Expand Down Expand Up @@ -149,13 +150,17 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
client=client,
config_overrides=self.waiter_config_overrides,
)
await async_wait(
waiter,
self.waiter_delay,
self.attempts,
self.waiter_args,
self.failure_message,
self.status_message,
self.status_queries,
)
yield TriggerEvent({"status": "success", self.return_key: self.return_value})
try:
await async_wait(
waiter,
self.waiter_delay,
self.attempts,
self.waiter_args,
self.failure_message,
self.status_message,
self.status_queries,
)
except AirflowException as e:
yield TriggerEvent({"status": "error", "message": str(e), self.return_key: self.return_value})
Comment thread
vincbeck marked this conversation as resolved.
else:
yield TriggerEvent({"status": "success", self.return_key: self.return_value})
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
waiter_args={"JobName": job_name, "RunId": run_id},
failure_message="AWS Glue job failed.",
status_message="Status of AWS Glue job is",
status_queries=["JobRun.JobRunState"],
status_queries=["JobRun.JobRunState", "JobRun.ErrorMessage"],
return_key="run_id",
return_value=run_id,
waiter_delay=waiter_delay,
Expand Down
19 changes: 19 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger

if TYPE_CHECKING:
Expand Down Expand Up @@ -125,3 +126,21 @@ async def test_run(self, wait_mock: MagicMock):
assert isinstance(res.payload, dict)
assert res.payload["status"] == "success"
assert res.payload["hello"] == "world"

@pytest.mark.asyncio
@mock.patch(
"airflow.providers.amazon.aws.triggers.base.async_wait",
side_effect=AirflowException("AWS Glue job failed.\nTerminal failure"),
)
async def test_run_error_yields_event(self, wait_mock: MagicMock):
self.trigger.return_key = "hello"
self.trigger.return_value = "world"

generator = self.trigger.run()
res: TriggerEvent = await generator.asend(None)

wait_mock.assert_called_once()
assert isinstance(res.payload, dict)
assert res.payload["status"] == "error"
assert "AWS Glue job failed." in res.payload["message"]
assert res.payload["hello"] == "world"
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
GlueJobCompleteTrigger,
)
from airflow.providers.common.compat.sdk import AirflowException
from airflow.triggers.base import TriggerEvent

from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
Expand Down Expand Up @@ -85,10 +84,12 @@ async def test_wait_job_failed(self, mock_async_conn, mock_get_waiter):
waiter_delay=10,
)
generator = trigger.run()
event = await generator.asend(None)

with pytest.raises(AirflowException):
await generator.asend(None)
assert_expected_waiter_type(mock_get_waiter, "job_complete")
assert event.payload["status"] == "error"
assert "message" in event.payload
assert event.payload["run_id"] == "JobRunId"

def test_serialization(self):
trigger = GlueJobCompleteTrigger(
Expand Down
Loading