diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py new file mode 100644 index 0000000000000..1334e99069f3b --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datetime import datetime + +from airflow.api_fastapi.core_api.base import BaseModel + + +class DagResponse(BaseModel): + """Schema for DAG response.""" + + dag_id: str + is_paused: bool + bundle_name: str | None + bundle_version: str | None + relative_fileloc: str | None + owners: str | None + tags: list[str] + next_dagrun: datetime | None diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index aeef4d092b194..a076592d6471a 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -24,6 +24,7 @@ assets, connections, dag_runs, + dags, health, hitl, task_instances, @@ -43,6 +44,7 @@ authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) +authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"]) authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) authenticated_router.include_router( task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py new file mode 100644 index 0000000000000..9061b4862161c --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from fastapi import APIRouter, HTTPException, status + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.dags import DagResponse +from airflow.models.dag import DagModel + +router = APIRouter() + + +@router.get( + "/{dag_id}", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, + }, +) +def get_dag( + dag_id: str, + session: SessionDep, +) -> DagResponse: + """Get a DAG.""" + dag_model: DagModel | None = session.get(DagModel, dag_id) + if not dag_model: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"The Dag with dag_id: `{dag_id}` was not found", + }, + ) + + return DagResponse( + dag_id=dag_model.dag_id, + is_paused=dag_model.is_paused, + bundle_name=dag_model.bundle_name, + bundle_version=dag_model.bundle_version, + relative_fileloc=dag_model.relative_fileloc, + owners=dag_model.owners, + tags=sorted(tag.name for tag in dag_model.tags), + next_dagrun=dag_model.next_dagrun, + ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py index f4f2d967e0254..2cbe2e3007b3f 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py @@ -39,9 +39,11 @@ ModifyDeferredTaskKwargsToJsonValue, RemoveUpstreamMapIndexesField, ) +from airflow.api_fastapi.execution_api.versions.v2026_04_13 import AddDagEndpoint bundle = VersionBundle( HeadVersion(), + Version("2026-04-13", AddDagEndpoint), Version( "2026-03-31", MakeDagRunStartDateNullable, diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py new file mode 100644 index 0000000000000..95da513d7bc38 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_04_13.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from cadwyn import VersionChange, endpoint + + +class AddDagEndpoint(VersionChange): + """Add the `/dags/{dag_id}` endpoint.""" + + description = __doc__ + + instructions_to_migrate_to_previous_version = (endpoint("/dags/{dag_id}", ["GET"]).didnt_exist,) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py new file mode 100644 index 0000000000000..78b8f74a1d6ea --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import ANY + +import pytest +from sqlalchemy import update + +from airflow.models import DagModel +from airflow.providers.standard.operators.empty import EmptyOperator + +from tests_common.test_utils.db import clear_db_runs + +pytestmark = pytest.mark.db_test + + +class TestDag: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize( + ("state", "expected"), + [ + pytest.param(True, True), + pytest.param(False, False), + ], + ) + def test_get_dag(self, client, session, dag_maker, state, expected): + """Test getting a DAG.""" + + dag_id = "test_get_dag" + next_dagrun = datetime(2026, 4, 13, tzinfo=timezone.utc) + + with dag_maker(dag_id=dag_id, session=session, serialized=True, tags=["z_tag", "a_tag"]): + EmptyOperator(task_id="test_task") + + session.execute( + update(DagModel) + .where(DagModel.dag_id == dag_id) + .values( + is_paused=state, + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + next_dagrun=next_dagrun, + ) + ) + + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}", + ) + + assert response.status_code == 200 + assert response.json() == { + "dag_id": dag_id, + "is_paused": expected, + "bundle_name": "dag_maker", + "bundle_version": "bundle-version", + "relative_fileloc": "dags/example.py", + "owners": "owner_1", + "tags": ["a_tag", "z_tag"], + "next_dagrun": "2026-04-13T00:00:00Z", + } + + def test_dag_not_found(self, client, session, dag_maker): + """Test Dag not found""" + + dag_id = "test_get_dag" + + response = client.get( + f"/execution/dags/{dag_id}", + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "The Dag with dag_id: `test_get_dag` was not found", + "reason": "not_found", + } + } + + def test_get_dag_defaults(self, client, session, dag_maker): + """Test getting a DAG with default model values.""" + + dag_id = "test_get_dag_defaults" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}", + ) + + assert response.status_code == 200 + assert response.json() == { + "dag_id": dag_id, + "is_paused": False, + "bundle_name": "dag_maker", + "bundle_version": None, + "relative_fileloc": "test_dags.py", + "owners": "airflow", + "tags": [], + "next_dagrun": ANY, + } diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py new file mode 100644 index 0000000000000..7269344394474 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2026_03_31/test_dags.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +pytestmark = pytest.mark.db_test + + +@pytest.fixture +def old_ver_client(client): + client.headers["Airflow-API-Version"] = "2026-03-31" + return client + + +def test_dag_endpoint_not_available_in_previous_version(old_ver_client): + response = old_ver_client.get("/execution/dags/test_dag") + + assert response.status_code == 404 diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 3dc57345f6fe0..c46aceda7949b 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1911,6 +1911,7 @@ def get_type_names(union_type): "GetAssetEventByAssetAlias", "GetDagRun", "GetDagRunState", + "GetDag", "GetDRCount", "GetTaskBreadcrumbs", "GetTaskRescheduleStartDate", @@ -1935,6 +1936,7 @@ def get_type_names(union_type): in_task_runner_but_not_in_dag_processing_process = { "AssetResult", "AssetEventsResult", + "DagResult", "DagRunResult", "DagRunStateResult", "DRCount", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index 3d77880427929..802a34192e352 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1317,6 +1317,7 @@ def get_type_names(union_type): "ResendLoggingFD", "CreateHITLDetailPayload", "SetRenderedMapIndex", + "GetDag", } in_task_but_not_in_trigger_runner = { @@ -1336,6 +1337,7 @@ def get_type_names(union_type): "PreviousDagRunResult", "PreviousTIResult", "HITLDetailRequestResult", + "DagResult", } supervisor_diff = ( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index f7106bb4aa597..90374f76be50f 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -45,6 +45,7 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + DagResponse, DagRun, DagRunStateResponse, DagRunType, @@ -772,6 +773,18 @@ def get_previous( return PreviousDagRunResult(dag_run=resp.json()) +class DagsOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get(self, dag_id: str) -> DagResponse: + """Get a DAG via the API server.""" + resp = self.client.get(f"dags/{dag_id}") + return DagResponse.model_validate_json(resp.read()) + + class HITLOperations: """ Operations related to Human in the loop. Require Airflow 3.1+. @@ -1012,6 +1025,12 @@ def hitl(self): """Operations related to HITL Responses.""" return HITLOperations(self) + @lru_cache() # type: ignore[misc] + @property + def dags(self) -> DagsOperations: + """Operations related to DAGs.""" + return DagsOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 617e2a23934b4..b6c08e9d76c82 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -27,7 +27,7 @@ from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel -API_VERSION: Final[str] = "2026-03-31" +API_VERSION: Final[str] = "2026-04-13" class AssetAliasReferenceAssetEventDagRun(BaseModel): @@ -78,6 +78,21 @@ class ConnectionResponse(BaseModel): extra: Annotated[str | None, Field(title="Extra")] = None +class DagResponse(BaseModel): + """ + Schema for DAG response. + """ + + dag_id: Annotated[str, Field(title="Dag Id")] + is_paused: Annotated[bool, Field(title="Is Paused")] + bundle_name: Annotated[str | None, Field(title="Bundle Name")] = None + bundle_version: Annotated[str | None, Field(title="Bundle Version")] = None + relative_fileloc: Annotated[str | None, Field(title="Relative Fileloc")] = None + owners: Annotated[str | None, Field(title="Owners")] = None + tags: Annotated[list[str], Field(title="Tags")] + next_dagrun: Annotated[AwareDatetime | None, Field(title="Next Dagrun")] = None + + class DagRunAssetReference(BaseModel): """ DagRun serializer for asset responses. diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 15755e640d97e..2a9a9bbd4eb02 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -71,6 +71,7 @@ AssetResponse, BundleInfo, ConnectionResponse, + DagResponse, DagRun, DagRunStateResponse, HITLDetailRequest, @@ -694,6 +695,21 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") +class DagResult(DagResponse): + type: Literal["DagResult"] = "DagResult" + + @classmethod + def from_api_response(cls, dag_response: DagResponse) -> DagResult: + """ + Create result class from API Response. + + API Response is autogenerated from the API schema, so we need to convert it to Result + for communication between the Supervisor and the task process since it needs a + discriminator field. + """ + return cls(**dag_response.model_dump(exclude_defaults=True), type="DagResult") + + ToTask = Annotated[ AssetResult | AssetEventsResult @@ -701,6 +717,7 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest | DagRunResult | DagRunStateResult | DRCount + | DagResult | ErrorResponse | PrevSuccessfulDagRunResult | PreviousTIResult @@ -1018,6 +1035,11 @@ class MaskSecret(BaseModel): type: Literal["MaskSecret"] = "MaskSecret" +class GetDag(BaseModel): + dag_id: str + type: Literal["GetDag"] = "GetDag" + + ToSupervisor = Annotated[ DeferTask | DeleteXCom @@ -1029,6 +1051,7 @@ class MaskSecret(BaseModel): | GetDagRun | GetDagRunState | GetDRCount + | GetDag | GetPrevSuccessfulDagRun | GetPreviousDagRun | GetPreviousTI diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 9894d4fff3153..1dfefee54047c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -66,6 +66,7 @@ AssetResult, ConnectionResult, CreateHITLDetailPayload, + DagResult, DagRunResult, DagRunStateResult, DeferTask, @@ -77,6 +78,7 @@ GetAssetEventByAsset, GetAssetEventByAssetAlias, GetConnection, + GetDag, GetDagRun, GetDagRunState, GetDRCount, @@ -1477,6 +1479,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) + elif isinstance(msg, GetDag): + dag = self.client.dags.get( + dag_id=msg.dag_id, + ) + resp = DagResult.from_api_response(dag) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 611c4fc28ec19..aa5a5a08ad426 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -73,10 +73,12 @@ from airflow.sdk.execution_time.comms import ( AssetEventDagRunReferenceResult, CommsDecoder, + DagResult, DagRunStateResult, DeferTask, DRCount, ErrorResponse, + GetDag, GetDagRunState, GetDRCount, GetPreviousDagRun, @@ -670,6 +672,16 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: return response.state + @staticmethod + def get_dag(dag_id: str) -> DagResult: + """Return the DAG with the given dag_id.""" + response = SUPERVISOR_COMMS.send(msg=GetDag(dag_id=dag_id)) + + if TYPE_CHECKING: + assert isinstance(response, DagResult) + + return response + @property def log_url(self) -> str: run_id = quote(self.run_id) diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index 71c24474806e0..4ce43c57c117b 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -39,6 +39,7 @@ from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey from airflow.sdk.definitions.context import Context from airflow.sdk.definitions.mappedoperator import MappedOperator + from airflow.sdk.execution_time.comms import DagResult Operator: TypeAlias = BaseOperator | MappedOperator @@ -183,6 +184,9 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: ... + @staticmethod + def get_dag(dag_id: str) -> DagResult: ... + # Public alias for RuntimeTaskInstanceProtocol class TaskInstance(RuntimeTaskInstanceProtocol): diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 6347d611a6a89..7d960b76570a8 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -19,7 +19,7 @@ import json import pickle -from datetime import datetime +from datetime import datetime, timezone as dt_timezone from typing import TYPE_CHECKING from unittest import mock @@ -37,6 +37,7 @@ AssetEventsResponse, AssetResponse, ConnectionResponse, + DagResponse, DagRunState, DagRunStateResponse, HITLDetailRequest, @@ -1566,3 +1567,88 @@ def test_cache_miss_on_different_parameters(self): assert ctx1 is not ctx2 assert info.misses == 2 assert info.currsize == 2 + + +class TestDagsOperations: + def test_get(self): + """Test that the client can get a dag.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag": + return httpx.Response( + status_code=200, + json={ + "dag_id": "test_dag", + "is_paused": False, + "bundle_name": "dags-folder", + "bundle_version": "bundle-version", + "relative_fileloc": "dags/example.py", + "owners": "owner_1", + "tags": ["a_tag", "z_tag"], + "next_dagrun": "2026-04-13T00:00:00Z", + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get(dag_id="test_dag") + + assert result == DagResponse( + dag_id="test_dag", + is_paused=False, + bundle_name="dags-folder", + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + tags=["a_tag", "z_tag"], + next_dagrun=datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + ) + + def test_get_not_found(self): + """Test that getting a missing dag raises a server response error.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/missing_dag": + return httpx.Response( + status_code=404, + json={ + "detail": { + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", + } + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError) as exc_info: + client.dags.get(dag_id="missing_dag") + + assert exc_info.value.response.status_code == 404 + assert exc_info.value.detail == { + "detail": { + "message": "The Dag with dag_id: `missing_dag` was not found", + "reason": "not_found", + } + } + + def test_get_server_error(self): + """Test that a server error while getting a dag.""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag": + return httpx.Response( + status_code=500, + headers=[("content-Type", "application/json")], + json={ + "reason": "internal_server_error", + "message": "Internal Server Error", + }, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + + with pytest.raises(ServerResponseError): + client.dags.get(dag_id="test_dag") diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 66eda6f5b87f4..ef6cd19b8d7ae 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -30,7 +30,7 @@ import time from contextlib import nullcontext from dataclasses import dataclass, field -from datetime import datetime +from datetime import datetime, timezone as dt_timezone from operator import attrgetter from random import randint from textwrap import dedent @@ -71,6 +71,7 @@ CommsDecoder, ConnectionResult, CreateHITLDetailPayload, + DagResult, DagRunResult, DagRunStateResult, DeferTask, @@ -83,6 +84,7 @@ GetAssetEventByAsset, GetAssetEventByAssetAlias, GetConnection, + GetDag, GetDagRun, GetDagRunState, GetDRCount, @@ -2499,6 +2501,37 @@ class RequestTestCase: }, test_id="get_task_breadcrumbs", ), + RequestTestCase( + message=GetDag(dag_id="test_dag"), + expected_body={ + "dag_id": "test_dag", + "is_paused": False, + "bundle_name": "dags-folder", + "bundle_version": "bundle-version", + "relative_fileloc": "dags/example.py", + "owners": "owner_1", + "tags": ["a_tag", "z_tag"], + "next_dagrun": datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + "type": "DagResult", + }, + client_mock=ClientMock( + method_path="dags.get", + kwargs={ + "dag_id": "test_dag", + }, + response=DagResult( + dag_id="test_dag", + is_paused=False, + bundle_name="dags-folder", + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + tags=["a_tag", "z_tag"], + next_dagrun=datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + ), + ), + test_id="get_dag", + ), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 47de56c384fa9..0eab6a50afcf6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -80,11 +80,13 @@ AssetEventsResult, BundleInfo, ConnectionResult, + DagResult, DagRunStateResult, DeferTask, DRCount, ErrorResponse, GetConnection, + GetDag, GetDagRunState, GetDRCount, GetPreviousDagRun, @@ -2942,6 +2944,29 @@ def execute(self, context): if hasattr(call.kwargs.get("msg"), "rendered_fields") ) + def test_get_dag(self, mock_supervisor_comms): + """Test that get_dag sends the correct request and returns the dag.""" + mock_supervisor_comms.send.return_value = DagResult( + dag_id="test_dag", + is_paused=False, + bundle_name="dags-folder", + bundle_version="bundle-version", + relative_fileloc="dags/example.py", + owners="owner_1", + tags=["a_tag", "z_tag"], + next_dagrun=datetime(2026, 4, 13, tzinfo=dt_timezone.utc), + ) + + response = RuntimeTaskInstance.get_dag( + dag_id="test_dag", + ) + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDag(dag_id="test_dag"), + ) + assert response.dag_id == "test_dag" + assert response.is_paused is False + class TestXComAfterTaskExecution: @pytest.mark.parametrize(