From 3b464092ca5d67664021545ba7a50dfa7f9d149c Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Tue, 21 Oct 2025 15:48:39 +0100 Subject: [PATCH 01/12] Add GetDagState endpoint to execution_api --- .../execution_api/datamodels/dags.py | 26 ++++++ .../execution_api/routes/__init__.py | 2 + .../api_fastapi/execution_api/routes/dags.py | 57 +++++++++++++ .../execution_api/versions/head/test_dags.py | 79 +++++++++++++++++++ task-sdk/src/airflow/sdk/api/client.py | 18 +++++ .../airflow/sdk/api/datamodels/_generated.py | 8 ++ .../src/airflow/sdk/execution_time/comms.py | 12 +++ .../airflow/sdk/execution_time/supervisor.py | 5 ++ .../airflow/sdk/execution_time/task_runner.py | 12 +++ task-sdk/tests/task_sdk/api/test_client.py | 19 +++++ .../execution_time/test_task_runner.py | 15 ++++ 11 files changed, 253 insertions(+) create mode 100644 airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py create mode 100644 airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py new file mode 100644 index 0000000000000..a00225fea0bbc --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from airflow.api_fastapi.core_api.base import BaseModel + + +class DagStateResponse(BaseModel): + """Schema for DAG State response.""" + + is_paused: bool diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index aeef4d092b194..a076592d6471a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -24,6 +24,7 @@ assets, connections, dag_runs, + dags, health, hitl, task_instances, @@ -43,6 +44,7 @@ authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) +authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"]) authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) authenticated_router.include_router( task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py new file mode 100644 index 0000000000000..9b10393217c41 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, HTTPException, status + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.dags import DagStateResponse +from airflow.models.dag import DagModel + +router = APIRouter() + + +log = logging.getLogger(__name__) + + +@router.get( + "/{dag_id}/state", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, + }, +) +def get_dag_state( + dag_id: str, + session: SessionDep, +) -> DagStateResponse: + """Get a DAG Run State.""" + dag_model: DagModel = session.get(DagModel, dag_id) + if not dag_model: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"The Dag with dag_id: `{dag_id}` was not found", + }, + ) + + is_paused = False if dag_model.is_paused is None else dag_model.is_paused + + return DagStateResponse(is_paused=is_paused) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py new file mode 100644 index 0000000000000..b17e6cd1056cb --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from airflow.models import DagModel +from airflow.providers.standard.operators.empty import EmptyOperator + +from tests_common.test_utils.db import clear_db_runs + +pytestmark = pytest.mark.db_test + + +class TestDagState: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize( + "state, expected", + [ + (True, True), + (False, False), + (None, False), + ], + ) + def test_dag_is_paused(self, state, expected, client, session, dag_maker): + """Test DagState is active or paused""" + + dag_id = "test_dag_is_paused" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"is_paused": state}) + + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}/state", + ) + + assert response.status_code == 200 + assert response.json() == {"is_paused": expected} + + def test_dag_not_found(self, client, session, dag_maker): + """Test Dag not found""" + + dag_id = "test_dag_is_paused" + + response = client.get( + f"/execution/dags/{dag_id}/state", + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "The Dag with dag_id: `test_dag_is_paused` was not found", + "reason": "not_found", + } + } diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index f7106bb4aa597..9a19b89ada96a 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -49,6 +49,7 @@ DagRunStateResponse, DagRunType, HITLDetailRequest, + DagStateResponse, HITLDetailResponse, HITLUser, InactiveAssetsResponse, @@ -772,6 +773,18 @@ def get_previous( return PreviousDagRunResult(dag_run=resp.json()) +class DagsOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get_state(self, dag_id: str) -> DagStateResponse: + """Get the state of a Dag via the API server.""" + resp = self.client.get(f"dags/{dag_id}/state") + return DagStateResponse.model_validate_json(resp.read()) + + class HITLOperations: """ Operations related to Human in the loop. Require Airflow 3.1+. @@ -1012,6 +1025,11 @@ def hitl(self): """Operations related to HITL Responses.""" return HITLOperations(self) + @lru_cache() # type: ignore[misc] + @property + def dags(self): + return DagsOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 617e2a23934b4..caaaab7cdd2e3 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -132,6 +132,14 @@ class DagRunType(str, Enum): ASSET_MATERIALIZATION = "asset_materialization" +class DagStateResponse(BaseModel): + """ + Schema for DAG State response. + """ + + is_paused: Annotated[bool, Field(title="Is Paused")] + + class HITLUser(BaseModel): """ Schema for a Human-in-the-loop users. diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 15755e640d97e..e0557848e7d59 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -694,6 +694,11 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") +class DagStateResult(BaseModel): + is_paused: bool + type: Literal["DagStateResult"] = "DagStateResult" + + ToTask = Annotated[ AssetResult | AssetEventsResult @@ -701,6 +706,7 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest | DagRunResult | DagRunStateResult | DRCount + | DagStateResult | ErrorResponse | PrevSuccessfulDagRunResult | PreviousTIResult @@ -1018,6 +1024,11 @@ class MaskSecret(BaseModel): type: Literal["MaskSecret"] = "MaskSecret" +class GetDagState(BaseModel): + dag_id: str + type: Literal["GetDagState"] = "GetDagState" + + ToSupervisor = Annotated[ DeferTask | DeleteXCom @@ -1029,6 +1040,7 @@ class MaskSecret(BaseModel): | GetDagRun | GetDagRunState | GetDRCount + | GetDagState | GetPrevSuccessfulDagRun | GetPreviousDagRun | GetPreviousTI diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 9894d4fff3153..6d55de532ff60 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -79,6 +79,7 @@ GetConnection, GetDagRun, GetDagRunState, + GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -1477,6 +1478,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) + elif isinstance(msg, GetDagState): + resp = self.client.dags.get_state( + dag_id=msg.dag_id, + ) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 611c4fc28ec19..ab9b11af74647 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -74,10 +74,12 @@ AssetEventDagRunReferenceResult, CommsDecoder, DagRunStateResult, + DagStateResult, DeferTask, DRCount, ErrorResponse, GetDagRunState, + GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -670,6 +672,16 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: return response.state + @staticmethod + def get_dag_state(dag_id: str) -> DagStateResult: + """Return the state of the Dag run with the given Run ID.""" + response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id)) + + if TYPE_CHECKING: + assert isinstance(response, DagStateResult) + + return response + @property def log_url(self) -> str: run_id = quote(self.run_id) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 6347d611a6a89..974f76552787b 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -40,6 +40,7 @@ DagRunState, DagRunStateResponse, HITLDetailRequest, + DagStateResponse, HITLDetailResponse, HITLUser, TerminalTIState, @@ -1566,3 +1567,21 @@ def test_cache_miss_on_different_parameters(self): assert ctx1 is not ctx2 assert info.misses == 2 assert info.currsize == 2 + + +class TestDagsOperations: + def test_get_state(self): + """Test that the client can get the state of a dag run""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag/state": + return httpx.Response( + status_code=200, + json={"is_paused": False}, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get_state(dag_id="test_dag") + + assert result == DagStateResponse(is_paused=False) diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 47de56c384fa9..329831a0e94bc 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -81,11 +81,13 @@ BundleInfo, ConnectionResult, DagRunStateResult, + DagStateResult, DeferTask, DRCount, ErrorResponse, GetConnection, GetDagRunState, + GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -2942,6 +2944,19 @@ def execute(self, context): if hasattr(call.kwargs.get("msg"), "rendered_fields") ) + def test_get_dag_state(self, mock_supervisor_comms): + """Test that get_dag_state sends the correct request and returns the state.""" + mock_supervisor_comms.send.return_value = DagStateResult(is_paused=False) + + response = RuntimeTaskInstance.get_dag_state( + dag_id="test_dag", + ) + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDagState(dag_id="test_dag"), + ) + assert response.is_paused is False + class TestXComAfterTaskExecution: @pytest.mark.parametrize( From b1cb126c068daa807593ba398a52ea092b7cf27c Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Thu, 23 Oct 2025 11:27:06 +0100 Subject: [PATCH 02/12] Fixup tests --- .../tests/task_sdk/execution_time/test_supervisor.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 66eda6f5b87f4..407b424cf7eff 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -73,6 +73,7 @@ CreateHITLDetailPayload, DagRunResult, DagRunStateResult, + DagStateResult, DeferTask, DeleteVariable, DeleteXCom, @@ -85,6 +86,7 @@ GetConnection, GetDagRun, GetDagRunState, + GetDagState, GetDRCount, GetHITLDetailResponse, GetPreviousDagRun, @@ -2499,6 +2501,16 @@ class RequestTestCase: }, test_id="get_task_breadcrumbs", ), + RequestTestCase( + message=GetDagState(dag_id="test_dag"), + expected_body={"is_paused": False, "type": "DagStateResult"}, + client_mock=ClientMock( + method_path="dags.get_state", + args=("test_dag",), + response=DagStateResult(is_paused=False), + ), + test_id="get_dag_run_state", + ), ] From d7c31c76f83b1ff8677716fb9ad0920f63b9616f Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Fri, 24 Oct 2025 22:03:44 +0100 Subject: [PATCH 03/12] Fixup tests --- airflow-core/tests/unit/jobs/test_triggerer_job.py | 2 ++ .../tests/task_sdk/execution_time/test_supervisor.py | 9 ++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 3d77880427929..645c31f394056 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1317,6 +1317,7 @@ def get_type_names(union_type): "ResendLoggingFD", "CreateHITLDetailPayload", "SetRenderedMapIndex", + "GetDagState", } in_task_but_not_in_trigger_runner = { @@ -1336,6 +1337,7 @@ def get_type_names(union_type): "PreviousDagRunResult", "PreviousTIResult", "HITLDetailRequestResult", + "DagStateResult", } supervisor_diff = ( diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 407b424cf7eff..e4a9b263acf64 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1450,7 +1450,8 @@ class RequestTestCase: client_mock=ClientMock( method_path="connections.get", args=("test_conn",), - response=ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), # type: ignore[call-arg] + response=ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), + # type: ignore[call-arg] ), expected_body={ "conn_id": "test_conn", @@ -2506,10 +2507,12 @@ class RequestTestCase: expected_body={"is_paused": False, "type": "DagStateResult"}, client_mock=ClientMock( method_path="dags.get_state", - args=("test_dag",), + kwargs={ + "dag_id": "test_dag", + }, response=DagStateResult(is_paused=False), ), - test_id="get_dag_run_state", + test_id="get_dag_state", ), ] From a1a7c434af3ef3073521783fb4b9666153745cf2 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Wed, 7 Jan 2026 20:19:50 +0000 Subject: [PATCH 04/12] Fix static checks --- .../execution_api/versions/head/test_dags.py | 10 +++++----- task-sdk/src/airflow/sdk/api/client.py | 2 +- task-sdk/tests/task_sdk/api/test_client.py | 2 +- .../tests/task_sdk/execution_time/test_supervisor.py | 3 +-- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py index b17e6cd1056cb..e03760697ac21 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -35,14 +35,14 @@ def teardown_method(self): clear_db_runs() @pytest.mark.parametrize( - "state, expected", + ("state", "expected"), [ - (True, True), - (False, False), - (None, False), + pytest.param(True, True), + pytest.param(False, False), + pytest.param(None, False), ], ) - def test_dag_is_paused(self, state, expected, client, session, dag_maker): + def test_dag_is_paused(self, client, session, dag_maker, state, expected): """Test DagState is active or paused""" dag_id = "test_dag_is_paused" diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 9a19b89ada96a..82f6510012989 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -48,8 +48,8 @@ DagRun, DagRunStateResponse, DagRunType, - HITLDetailRequest, DagStateResponse, + HITLDetailRequest, HITLDetailResponse, HITLUser, InactiveAssetsResponse, diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 974f76552787b..d3abc3972e467 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -39,8 +39,8 @@ ConnectionResponse, DagRunState, DagRunStateResponse, - HITLDetailRequest, DagStateResponse, + HITLDetailRequest, HITLDetailResponse, HITLUser, TerminalTIState, diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index e4a9b263acf64..f54efe9e1d84b 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -1450,8 +1450,7 @@ class RequestTestCase: client_mock=ClientMock( method_path="connections.get", args=("test_conn",), - response=ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), - # type: ignore[call-arg] + response=ConnectionResult(conn_id="test_conn", conn_type="mysql", schema="mysql"), # type: ignore[call-arg] ), expected_body={ "conn_id": "test_conn", From e6642207966fb1bbcb5cc8c2753c9bfde6dad9c1 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Thu, 8 Jan 2026 21:09:54 +0000 Subject: [PATCH 05/12] Fixup tests --- airflow-core/tests/unit/dag_processing/test_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 3dc57345f6fe0..3c6ba810f8e47 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1911,6 +1911,7 @@ def get_type_names(union_type): "GetAssetEventByAssetAlias", "GetDagRun", "GetDagRunState", + "GetDagState", "GetDRCount", "GetTaskBreadcrumbs", "GetTaskRescheduleStartDate", From 9a6f82861886fede17ed89d73c32f2b631b920af Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Fri, 13 Mar 2026 15:53:39 +0000 Subject: [PATCH 06/12] Fixup tests --- airflow-core/tests/unit/dag_processing/test_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 3c6ba810f8e47..066e2df2dfd4a 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1936,6 +1936,7 @@ def get_type_names(union_type): in_task_runner_but_not_in_dag_processing_process = { "AssetResult", "AssetEventsResult", + "DagStateResult", "DagRunResult", "DagRunStateResult", "DRCount", From b6fe52cfad6a9c4b837d061b9c30e67bc1338caf Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Fri, 13 Mar 2026 16:54:19 +0000 Subject: [PATCH 07/12] Fixup return type --- task-sdk/src/airflow/sdk/execution_time/comms.py | 15 +++++++++++++-- .../src/airflow/sdk/execution_time/supervisor.py | 8 +++++++- .../src/airflow/sdk/execution_time/task_runner.py | 1 + 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index e0557848e7d59..ffff6333906b5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -73,6 +73,7 @@ ConnectionResponse, DagRun, DagRunStateResponse, + DagStateResponse, HITLDetailRequest, InactiveAssetsResponse, PreviousTIResponse, @@ -694,10 +695,20 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") -class DagStateResult(BaseModel): - is_paused: bool +class DagStateResult(DagStateResponse): type: Literal["DagStateResult"] = "DagStateResult" + @classmethod + def from_api_response(cls, dg_state_response: DagStateResponse) -> DagStateResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we need to convert it to Result + for communication between the Supervisor and the task process since it needs a + discriminator field. + """ + return cls(**dg_state_response.model_dump(exclude_defaults=True), type="DagStateResult") + ToTask = Annotated[ AssetResult diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 6d55de532ff60..f087f1767c5ca 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -52,6 +52,7 @@ from airflow.sdk.api.datamodels._generated import ( AssetResponse, ConnectionResponse, + DagStateResponse, TaskInstance, TaskInstanceState, TaskStatesResponse, @@ -68,6 +69,7 @@ CreateHITLDetailPayload, DagRunResult, DagRunStateResult, + DagStateResult, DeferTask, DeleteVariable, DeleteXCom, @@ -1479,9 +1481,13 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) elif isinstance(msg, GetDagState): - resp = self.client.dags.get_state( + dg_state = self.client.dags.get_state( dag_id=msg.dag_id, ) + if isinstance(dg_state, DagStateResponse): + resp = DagStateResult.from_api_response(dg_state) + else: + resp = dg_state else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index ab9b11af74647..fba7edc33c96a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -676,6 +676,7 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: def get_dag_state(dag_id: str) -> DagStateResult: """Return the state of the Dag run with the given Run ID.""" response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id)) + print(response) if TYPE_CHECKING: assert isinstance(response, DagStateResult) From 3483ff6ec7662d19f3bf3cffef3f86048d4bd6d8 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Fri, 13 Mar 2026 18:02:55 +0000 Subject: [PATCH 08/12] Fixup tests and mypy --- .../src/airflow/api_fastapi/execution_api/routes/dags.py | 2 +- .../unit/api_fastapi/execution_api/versions/head/test_dags.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py index 9b10393217c41..9762b7b981713 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -42,7 +42,7 @@ def get_dag_state( session: SessionDep, ) -> DagStateResponse: """Get a DAG Run State.""" - dag_model: DagModel = session.get(DagModel, dag_id) + dag_model: DagModel | None = session.get(DagModel, dag_id) if not dag_model: raise HTTPException( status.HTTP_404_NOT_FOUND, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py index e03760697ac21..4eb3c349b7d11 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -18,6 +18,7 @@ from __future__ import annotations import pytest +from sqlalchemy import update from airflow.models import DagModel from airflow.providers.standard.operators.empty import EmptyOperator @@ -39,7 +40,6 @@ def teardown_method(self): [ pytest.param(True, True), pytest.param(False, False), - pytest.param(None, False), ], ) def test_dag_is_paused(self, client, session, dag_maker, state, expected): @@ -50,7 +50,7 @@ def test_dag_is_paused(self, client, session, dag_maker, state, expected): with dag_maker(dag_id=dag_id, session=session, serialized=True): EmptyOperator(task_id="test_task") - session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"is_paused": state}) + session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(is_paused=state)) session.commit() From 29929345298d4061992e1edc255308ace82903dd Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Fri, 13 Mar 2026 21:23:43 +0000 Subject: [PATCH 09/12] Add cadwyn migration --- .../execution_api/versions/__init__.py | 2 ++ .../execution_api/versions/v2026_04_13.py | 28 +++++++++++++++++++ .../airflow/sdk/api/datamodels/_generated.py | 2 +- 3 files changed, 31 insertions(+), 1 deletion(-) create mode 100644 airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index f4f2d967e0254..7e2be9b652b07 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -39,9 +39,11 @@ ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField, ) +from airflow.api_fastapi.execution_api.versions.v2026_04_13 import AddDagStateEndpoint bundle = VersionBundle( HeadVersion(), + Version("2026-04-13", AddDagStateEndpoint), Version( "2026-03-31", MakeDagRunStartDateNullable, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py new file mode 100644 index 0000000000000..f1192da12df40 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import VersionChange, endpoint + + +class AddDagStateEndpoint(VersionChange): + """Add the `/dags/{dag_id}/state` endpoint.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = (endpoint("/dags/{dag_id}/state", ["GET"]).didnt_exist,) diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index caaaab7cdd2e3..b67f84d482b97 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel -API_VERSION: Final[str] = "2026-03-31" +API_VERSION: Final[str] = "2026-04-13" class AssetAliasReferenceAssetEventDagRun(BaseModel): From 30393dba6fe7ff2104f1376fc2593bbc58a072a4 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Fri, 13 Mar 2026 21:52:31 +0000 Subject: [PATCH 10/12] Add tests --- .../versions/v2026_03_31/test_dags.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py new file mode 100644 index 0000000000000..8003b31c6651a --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + client.headers["Airflow-API-Version"] = "2026-03-31" + return client + + +def test_dag_state_endpoint_not_available_in_previous_version(old_ver_client): + response = old_ver_client.get("/execution/dags/test_dag/state") + + assert response.status_code == 404 From f494557d637382dc2bcee4e64d8e1272b9cc8355 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Tue, 17 Mar 2026 20:52:50 +0000 Subject: [PATCH 11/12] Resolve comments round1 --- .../airflow/api_fastapi/execution_api/routes/dags.py | 11 ++--------- task-sdk/src/airflow/sdk/api/client.py | 3 ++- task-sdk/src/airflow/sdk/execution_time/supervisor.py | 6 +----- .../src/airflow/sdk/execution_time/task_runner.py | 3 +-- task-sdk/src/airflow/sdk/types.py | 4 ++++ 5 files changed, 10 insertions(+), 17 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py index 9762b7b981713..4ca74d6fcedb9 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -17,8 +17,6 @@ from __future__ import annotations -import logging - from fastapi import APIRouter, HTTPException, status from airflow.api_fastapi.common.db.common import SessionDep @@ -28,9 +26,6 @@ router = APIRouter() -log = logging.getLogger(__name__) - - @router.get( "/{dag_id}/state", responses={ @@ -41,7 +36,7 @@ def get_dag_state( dag_id: str, session: SessionDep, ) -> DagStateResponse: - """Get a DAG Run State.""" + """Get the state of a DAG.""" dag_model: DagModel | None = session.get(DagModel, dag_id) if not dag_model: raise HTTPException( @@ -52,6 +47,4 @@ def get_dag_state( }, ) - is_paused = False if dag_model.is_paused is None else dag_model.is_paused - - return DagStateResponse(is_paused=is_paused) + return DagStateResponse(is_paused=dag_model.is_paused) diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 82f6510012989..cd2dec35d7d2b 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -1027,7 +1027,8 @@ def hitl(self): @lru_cache() # type: ignore[misc] @property - def dags(self): + def dags(self) -> DagsOperations: + """Operations related to DAGs.""" return DagsOperations(self) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index f087f1767c5ca..8afb5a1fff033 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -52,7 +52,6 @@ from airflow.sdk.api.datamodels._generated import ( AssetResponse, ConnectionResponse, - DagStateResponse, TaskInstance, TaskInstanceState, TaskStatesResponse, @@ -1484,10 +1483,7 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dg_state = self.client.dags.get_state( dag_id=msg.dag_id, ) - if isinstance(dg_state, DagStateResponse): - resp = DagStateResult.from_api_response(dg_state) - else: - resp = dg_state + resp = DagStateResult.from_api_response(dg_state) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index fba7edc33c96a..edf929568eeb2 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -674,9 +674,8 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: @staticmethod def get_dag_state(dag_id: str) -> DagStateResult: - """Return the state of the Dag run with the given Run ID.""" + """Return the state of the DAG with the given dag_id.""" response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id)) - print(response) if TYPE_CHECKING: assert isinstance(response, DagStateResult) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 71c24474806e0..a0c6dd7ee0326 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -39,6 +39,7 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.execution_time.comms import DagStateResult Operator: TypeAlias = BaseOperator | MappedOperator @@ -183,6 +184,9 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: ... + @staticmethod + def get_dag_state(dag_id: str) -> DagStateResult: ... + # Public alias for RuntimeTaskInstanceProtocol class TaskInstance(RuntimeTaskInstanceProtocol): From 6c2a3cf58b8aa6068ba4fe76ddf5a7da52c8ee57 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Wed, 18 Mar 2026 02:13:24 +0000 Subject: [PATCH 12/12] Replace DAG state endpoint with DAG details endpoint --- .../execution_api/datamodels/dags.py | 13 ++- .../api_fastapi/execution_api/routes/dags.py | 21 +++-- .../execution_api/versions/__init__.py | 4 +- .../execution_api/versions/v2026_04_13.py | 6 +- .../execution_api/versions/head/test_dags.py | 71 +++++++++++++--- .../versions/v2026_03_31/test_dags.py | 4 +- .../unit/dag_processing/test_processor.py | 4 +- .../tests/unit/jobs/test_triggerer_job.py | 4 +- task-sdk/src/airflow/sdk/api/client.py | 10 +-- .../airflow/sdk/api/datamodels/_generated.py | 23 +++-- .../src/airflow/sdk/execution_time/comms.py | 18 ++-- .../airflow/sdk/execution_time/supervisor.py | 10 +-- .../airflow/sdk/execution_time/task_runner.py | 12 +-- task-sdk/src/airflow/sdk/types.py | 4 +- task-sdk/tests/task_sdk/api/test_client.py | 83 +++++++++++++++++-- .../execution_time/test_supervisor.py | 35 ++++++-- .../execution_time/test_task_runner.py | 24 ++++-- 17 files changed, 258 insertions(+), 88 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py index a00225fea0bbc..1334e99069f3b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py @@ -17,10 +17,19 @@ from __future__ import annotations +from datetime import datetime + from airflow.api_fastapi.core_api.base import BaseModel -class DagStateResponse(BaseModel): - """Schema for DAG State response.""" +class DagResponse(BaseModel): + """Schema for DAG response.""" + dag_id: str is_paused: bool + bundle_name: str | None + bundle_version: str | None + relative_fileloc: str | None + owners: str | None + tags: list[str] + next_dagrun: datetime | None diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py index 4ca74d6fcedb9..9061b4862161c 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -20,23 +20,23 @@ from fastapi import APIRouter, HTTPException, status from airflow.api_fastapi.common.db.common import SessionDep -from airflow.api_fastapi.execution_api.datamodels.dags import DagStateResponse +from airflow.api_fastapi.execution_api.datamodels.dags import DagResponse from airflow.models.dag import DagModel router = APIRouter() @router.get( - "/{dag_id}/state", + "/{dag_id}", responses={ status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, }, ) -def get_dag_state( +def get_dag( dag_id: str, session: SessionDep, -) -> DagStateResponse: - """Get the state of a DAG.""" +) -> DagResponse: + """Get a DAG.""" dag_model: DagModel | None = session.get(DagModel, dag_id) if not dag_model: raise HTTPException( @@ -47,4 +47,13 @@ def get_dag_state( }, ) - return DagStateResponse(is_paused=dag_model.is_paused) + return DagResponse( + dag_id=dag_model.dag_id, + is_paused=dag_model.is_paused, + bundle_name=dag_model.bundle_name, + bundle_version=dag_model.bundle_version, + relative_fileloc=dag_model.relative_fileloc, + owners=dag_model.owners, + tags=sorted(tag.name for tag in dag_model.tags), + next_dagrun=dag_model.next_dagrun, + ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index 7e2be9b652b07..2cbe2e3007b3f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -39,11 +39,11 @@ ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField, ) -from airflow.api_fastapi.execution_api.versions.v2026_04_13 import AddDagStateEndpoint +from airflow.api_fastapi.execution_api.versions.v2026_04_13 import AddDagEndpoint bundle = VersionBundle( HeadVersion(), - Version("2026-04-13", AddDagStateEndpoint), + Version("2026-04-13", AddDagEndpoint), Version( "2026-03-31", MakeDagRunStartDateNullable, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py index f1192da12df40..95da513d7bc38 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py @@ -20,9 +20,9 @@ from cadwyn import VersionChange, endpoint -class AddDagStateEndpoint(VersionChange): - """Add the `/dags/{dag_id}/state` endpoint.""" +class AddDagEndpoint(VersionChange): + """Add the `/dags/{dag_id}` endpoint.""" description = __doc__ - instructions_to_migrate_to_previous_version = (endpoint("/dags/{dag_id}/state", ["GET"]).didnt_exist,) + instructions_to_migrate_to_previous_version = (endpoint("/dags/{dag_id}", ["GET"]).didnt_exist,) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py index 4eb3c349b7d11..78b8f74a1d6ea 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -17,6 +17,9 @@ from __future__ import annotations +from datetime import datetime, timezone +from unittest.mock import ANY + import pytest from sqlalchemy import update @@ -28,7 +31,7 @@ pytestmark = pytest.mark.db_test -class TestDagState: +class TestDag: def setup_method(self): clear_db_runs() @@ -42,38 +45,84 @@ def teardown_method(self): pytest.param(False, False), ], ) - def test_dag_is_paused(self, client, session, dag_maker, state, expected): - """Test DagState is active or paused""" + def test_get_dag(self, client, session, dag_maker, state, expected): + """Test getting a DAG.""" - dag_id = "test_dag_is_paused" + dag_id = "test_get_dag" + next_dagrun = datetime(2026, 4, 13, tzinfo=timezone.utc) - with dag_maker(dag_id=dag_id, session=session, serialized=True): + with dag_maker(dag_id=dag_id, session=session, serialized=True, tags=["z_tag", "a_tag"]): EmptyOperator(task_id="test_task") - session.execute(update(DagModel).where(DagModel.dag_id == dag_id).values(is_paused=state)) + session.execute( + update(DagModel) + .where(DagModel.dag_id == dag_id) + .values( + is_paused=state, + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + next_dagrun=next_dagrun, + ) + ) session.commit() response = client.get( - f"/execution/dags/{dag_id}/state", + f"/execution/dags/{dag_id}", ) assert response.status_code == 200 - assert response.json() == {"is_paused": expected} + assert response.json() == { + "dag_id": dag_id, + "is_paused": expected, + "bundle_name": "dag_maker", + "bundle_version": "bundle-version", + "relative_fileloc": "dags/example.py", + "owners": "owner_1", + "tags": ["a_tag", "z_tag"], + "next_dagrun": "2026-04-13T00:00:00Z", + } def test_dag_not_found(self, client, session, dag_maker): """Test Dag not found""" - dag_id = "test_dag_is_paused" + dag_id = "test_get_dag" response = client.get( - f"/execution/dags/{dag_id}/state", + f"/execution/dags/{dag_id}", ) assert response.status_code == 404 assert response.json() == { "detail": { - "message": "The Dag with dag_id: `test_dag_is_paused` was not found", + "message": "The Dag with dag_id: `test_get_dag` was not found", "reason": "not_found", } } + + def test_get_dag_defaults(self, client, session, dag_maker): + """Test getting a DAG with default model values.""" + + dag_id = "test_get_dag_defaults" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}", + ) + + assert response.status_code == 200 + assert response.json() == { + "dag_id": dag_id, + "is_paused": False, + "bundle_name": "dag_maker", + "bundle_version": None, + "relative_fileloc": "test_dags.py", + "owners": "airflow", + "tags": [], + "next_dagrun": ANY, + } diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py index 8003b31c6651a..7269344394474 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py @@ -28,7 +28,7 @@ def old_ver_client(client): return client -def test_dag_state_endpoint_not_available_in_previous_version(old_ver_client): - response = old_ver_client.get("/execution/dags/test_dag/state") +def test_dag_endpoint_not_available_in_previous_version(old_ver_client): + response = old_ver_client.get("/execution/dags/test_dag") assert response.status_code == 404 diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 066e2df2dfd4a..c46aceda7949b 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1911,7 +1911,7 @@ def get_type_names(union_type): "GetAssetEventByAssetAlias", "GetDagRun", "GetDagRunState", - "GetDagState", + "GetDag", "GetDRCount", "GetTaskBreadcrumbs", "GetTaskRescheduleStartDate", @@ -1936,7 +1936,7 @@ def get_type_names(union_type): in_task_runner_but_not_in_dag_processing_process = { "AssetResult", "AssetEventsResult", - "DagStateResult", + "DagResult", "DagRunResult", "DagRunStateResult", "DRCount", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 645c31f394056..802a34192e352 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1317,7 +1317,7 @@ def get_type_names(union_type): "ResendLoggingFD", "CreateHITLDetailPayload", "SetRenderedMapIndex", - "GetDagState", + "GetDag", } in_task_but_not_in_trigger_runner = { @@ -1337,7 +1337,7 @@ def get_type_names(union_type): "PreviousDagRunResult", "PreviousTIResult", "HITLDetailRequestResult", - "DagStateResult", + "DagResult", } supervisor_diff = ( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index cd2dec35d7d2b..90374f76be50f 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -45,10 +45,10 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + DagResponse, DagRun, DagRunStateResponse, DagRunType, - DagStateResponse, HITLDetailRequest, HITLDetailResponse, HITLUser, @@ -779,10 +779,10 @@ class DagsOperations: def __init__(self, client: Client): self.client = client - def get_state(self, dag_id: str) -> DagStateResponse: - """Get the state of a Dag via the API server.""" - resp = self.client.get(f"dags/{dag_id}/state") - return DagStateResponse.model_validate_json(resp.read()) + def get(self, dag_id: str) -> DagResponse: + """Get a DAG via the API server.""" + resp = self.client.get(f"dags/{dag_id}") + return DagResponse.model_validate_json(resp.read()) class HITLOperations: diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index b67f84d482b97..b6c08e9d76c82 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -78,6 +78,21 @@ class ConnectionResponse(BaseModel): extra: Annotated[str | None, Field(title="Extra")] = None +class DagResponse(BaseModel): + """ + Schema for DAG response. + """ + + dag_id: Annotated[str, Field(title="Dag Id")] + is_paused: Annotated[bool, Field(title="Is Paused")] + bundle_name: Annotated[str | None, Field(title="Bundle Name")] = None + bundle_version: Annotated[str | None, Field(title="Bundle Version")] = None + relative_fileloc: Annotated[str | None, Field(title="Relative Fileloc")] = None + owners: Annotated[str | None, Field(title="Owners")] = None + tags: Annotated[list[str], Field(title="Tags")] + next_dagrun: Annotated[AwareDatetime | None, Field(title="Next Dagrun")] = None + + class DagRunAssetReference(BaseModel): """ DagRun serializer for asset responses. @@ -132,14 +147,6 @@ class DagRunType(str, Enum): ASSET_MATERIALIZATION = "asset_materialization" -class DagStateResponse(BaseModel): - """ - Schema for DAG State response. - """ - - is_paused: Annotated[bool, Field(title="Is Paused")] - - class HITLUser(BaseModel): """ Schema for a Human-in-the-loop users. diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index ffff6333906b5..2a9a9bbd4eb02 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -71,9 +71,9 @@ AssetResponse, BundleInfo, ConnectionResponse, + DagResponse, DagRun, DagRunStateResponse, - DagStateResponse, HITLDetailRequest, InactiveAssetsResponse, PreviousTIResponse, @@ -695,11 +695,11 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") -class DagStateResult(DagStateResponse): - type: Literal["DagStateResult"] = "DagStateResult" +class DagResult(DagResponse): + type: Literal["DagResult"] = "DagResult" @classmethod - def from_api_response(cls, dg_state_response: DagStateResponse) -> DagStateResult: + def from_api_response(cls, dag_response: DagResponse) -> DagResult: """ Create result class from API Response. @@ -707,7 +707,7 @@ def from_api_response(cls, dg_state_response: DagStateResponse) -> DagStateResul for communication between the Supervisor and the task process since it needs a discriminator field. """ - return cls(**dg_state_response.model_dump(exclude_defaults=True), type="DagStateResult") + return cls(**dag_response.model_dump(exclude_defaults=True), type="DagResult") ToTask = Annotated[ @@ -717,7 +717,7 @@ def from_api_response(cls, dg_state_response: DagStateResponse) -> DagStateResul | DagRunResult | DagRunStateResult | DRCount - | DagStateResult + | DagResult | ErrorResponse | PrevSuccessfulDagRunResult | PreviousTIResult @@ -1035,9 +1035,9 @@ class MaskSecret(BaseModel): type: Literal["MaskSecret"] = "MaskSecret" -class GetDagState(BaseModel): +class GetDag(BaseModel): dag_id: str - type: Literal["GetDagState"] = "GetDagState" + type: Literal["GetDag"] = "GetDag" ToSupervisor = Annotated[ @@ -1051,7 +1051,7 @@ class GetDagState(BaseModel): | GetDagRun | GetDagRunState | GetDRCount - | GetDagState + | GetDag | GetPrevSuccessfulDagRun | GetPreviousDagRun | GetPreviousTI diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 8afb5a1fff033..1dfefee54047c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -66,9 +66,9 @@ AssetResult, ConnectionResult, CreateHITLDetailPayload, + DagResult, DagRunResult, DagRunStateResult, - DagStateResult, DeferTask, DeleteVariable, DeleteXCom, @@ -78,9 +78,9 @@ GetAssetEventByAsset, GetAssetEventByAssetAlias, GetConnection, + GetDag, GetDagRun, GetDagRunState, - GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -1479,11 +1479,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) - elif isinstance(msg, GetDagState): - dg_state = self.client.dags.get_state( + elif isinstance(msg, GetDag): + dag = self.client.dags.get( dag_id=msg.dag_id, ) - resp = DagStateResult.from_api_response(dg_state) + resp = DagResult.from_api_response(dag) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index edf929568eeb2..aa5a5a08ad426 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -73,13 +73,13 @@ from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, CommsDecoder, + DagResult, DagRunStateResult, - DagStateResult, DeferTask, DRCount, ErrorResponse, + GetDag, GetDagRunState, - GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -673,12 +673,12 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: return response.state @staticmethod - def get_dag_state(dag_id: str) -> DagStateResult: - """Return the state of the DAG with the given dag_id.""" - response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id)) + def get_dag(dag_id: str) -> DagResult: + """Return the DAG with the given dag_id.""" + response = SUPERVISOR_COMMS.send(msg=GetDag(dag_id=dag_id)) if TYPE_CHECKING: - assert isinstance(response, DagStateResult) + assert isinstance(response, DagResult) return response diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index a0c6dd7ee0326..4ce43c57c117b 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -39,7 +39,7 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator - from airflow.sdk.execution_time.comms import DagStateResult + from airflow.sdk.execution_time.comms import DagResult Operator: TypeAlias = BaseOperator | MappedOperator @@ -185,7 +185,7 @@ def get_dr_count( def get_dagrun_state(dag_id: str, run_id: str) -> str: ... @staticmethod - def get_dag_state(dag_id: str) -> DagStateResult: ... + def get_dag(dag_id: str) -> DagResult: ... # Public alias for RuntimeTaskInstanceProtocol diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index d3abc3972e467..7d960b76570a8 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -19,7 +19,7 @@ import json import pickle -from datetime import datetime +from datetime import datetime, timezone as dt_timezone from typing import TYPE_CHECKING from unittest import mock @@ -37,9 +37,9 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + DagResponse, DagRunState, DagRunStateResponse, - DagStateResponse, HITLDetailRequest, HITLDetailResponse, HITLUser, @@ -1570,18 +1570,85 @@ def test_cache_miss_on_different_parameters(self): class TestDagsOperations: - def test_get_state(self): - """Test that the client can get the state of a dag run""" + def test_get(self): + """Test that the client can get a dag.""" def handle_request(request: httpx.Request) -> httpx.Response: - if request.url.path == "/dags/test_dag/state": + if request.url.path == "/dags/test_dag": return httpx.Response( status_code=200, - json={"is_paused": False}, + json={ + "dag_id": "test_dag", + "is_paused": False, + "bundle_name": "dags-folder", + "bundle_version": "bundle-version", + "relative_fileloc": "dags/example.py", + "owners": "owner_1", + "tags": ["a_tag", "z_tag"], + "next_dagrun": "2026-04-13T00:00:00Z", + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get(dag_id="test_dag") + + assert result == DagResponse( + dag_id="test_dag", + is_paused=False, + bundle_name="dags-folder", + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + tags=["a_tag", "z_tag"], + next_dagrun=datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + ) + + def test_get_not_found(self): + """Test that getting a missing dag raises a server response error.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/missing_dag": + return httpx.Response( + status_code=404, + json={ + "detail": { + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", + } + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError) as exc_info: + client.dags.get(dag_id="missing_dag") + + assert exc_info.value.response.status_code == 404 + assert exc_info.value.detail == { + "detail": { + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", + } + } + + def test_get_server_error(self): + """Test that a server error while getting a dag.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag": + return httpx.Response( + status_code=500, + headers=[("content-Type", "application/json")], + json={ + "reason": "internal_server_error", + "message": "Internal Server Error", + }, ) return httpx.Response(status_code=200) client = make_client(transport=httpx.MockTransport(handle_request)) - result = client.dags.get_state(dag_id="test_dag") - assert result == DagStateResponse(is_paused=False) + with pytest.raises(ServerResponseError): + client.dags.get(dag_id="test_dag") diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index f54efe9e1d84b..ef6cd19b8d7ae 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -30,7 +30,7 @@ import time from contextlib import nullcontext from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone as dt_timezone from operator import attrgetter from random import randint from textwrap import dedent @@ -71,9 +71,9 @@ CommsDecoder, ConnectionResult, CreateHITLDetailPayload, + DagResult, DagRunResult, DagRunStateResult, - DagStateResult, DeferTask, DeleteVariable, DeleteXCom, @@ -84,9 +84,9 @@ GetAssetEventByAsset, GetAssetEventByAssetAlias, GetConnection, + GetDag, GetDagRun, GetDagRunState, - GetDagState, GetDRCount, GetHITLDetailResponse, GetPreviousDagRun, @@ -2502,16 +2502,35 @@ class RequestTestCase: test_id="get_task_breadcrumbs", ), RequestTestCase( - message=GetDagState(dag_id="test_dag"), - expected_body={"is_paused": False, "type": "DagStateResult"}, + message=GetDag(dag_id="test_dag"), + expected_body={ + "dag_id": "test_dag", + "is_paused": False, + "bundle_name": "dags-folder", + "bundle_version": "bundle-version", + "relative_fileloc": "dags/example.py", + "owners": "owner_1", + "tags": ["a_tag", "z_tag"], + "next_dagrun": datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + "type": "DagResult", + }, client_mock=ClientMock( - method_path="dags.get_state", + method_path="dags.get", kwargs={ "dag_id": "test_dag", }, - response=DagStateResult(is_paused=False), + response=DagResult( + dag_id="test_dag", + is_paused=False, + bundle_name="dags-folder", + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + tags=["a_tag", "z_tag"], + next_dagrun=datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + ), ), - test_id="get_dag_state", + test_id="get_dag", ), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 329831a0e94bc..0eab6a50afcf6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -80,14 +80,14 @@ AssetEventsResult, BundleInfo, ConnectionResult, + DagResult, DagRunStateResult, - DagStateResult, DeferTask, DRCount, ErrorResponse, GetConnection, + GetDag, GetDagRunState, - GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -2944,17 +2944,27 @@ def execute(self, context): if hasattr(call.kwargs.get("msg"), "rendered_fields") ) - def test_get_dag_state(self, mock_supervisor_comms): - """Test that get_dag_state sends the correct request and returns the state.""" - mock_supervisor_comms.send.return_value = DagStateResult(is_paused=False) + def test_get_dag(self, mock_supervisor_comms): + """Test that get_dag sends the correct request and returns the dag.""" + mock_supervisor_comms.send.return_value = DagResult( + dag_id="test_dag", + is_paused=False, + bundle_name="dags-folder", + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + tags=["a_tag", "z_tag"], + next_dagrun=datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + ) - response = RuntimeTaskInstance.get_dag_state( + response = RuntimeTaskInstance.get_dag( dag_id="test_dag", ) mock_supervisor_comms.send.assert_called_once_with( - msg=GetDagState(dag_id="test_dag"), + msg=GetDag(dag_id="test_dag"), ) + assert response.dag_id == "test_dag" assert response.is_paused is False