Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def delete_expired_executions(
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = cast(CursorResult, session.execute(delete_stmt))
session.commit()
total_deleted += result.rowcount
total_deleted += result.rowcount or 0

# If we deleted fewer than the batch size, we're done
if len(execution_ids) < batch_size:
Expand Down Expand Up @@ -334,7 +334,7 @@ def delete_executions_by_app(
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = cast(CursorResult, session.execute(delete_stmt))
session.commit()
total_deleted += result.rowcount
total_deleted += result.rowcount or 0

# If we deleted fewer than the batch size, we're done
if len(execution_ids) < batch_size:
Expand Down Expand Up @@ -393,7 +393,7 @@ def delete_executions_by_ids(
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
result = cast(CursorResult, session.execute(stmt))
session.commit()
return result.rowcount
return result.rowcount or 0

@override
def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
Expand Down
4 changes: 2 additions & 2 deletions api/repositories/sqlalchemy_api_workflow_run_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def delete_runs_by_ids(
result = cast(CursorResult, session.execute(stmt))
session.commit()

deleted_count = result.rowcount
deleted_count = result.rowcount or 0
logger.info("Deleted %s workflow runs by IDs", deleted_count)
return deleted_count

Expand Down Expand Up @@ -357,7 +357,7 @@ def delete_runs_by_app(
result = cast(CursorResult, session.execute(delete_stmt))
session.commit()

batch_deleted = result.rowcount
batch_deleted = result.rowcount or 0
total_deleted += batch_deleted

logger.info("Deleted batch of %s workflow runs for app %s", batch_deleted, app_id)
Expand Down
25 changes: 16 additions & 9 deletions api/tests/unit_tests/models/test_snippet.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import json
from types import SimpleNamespace
from unittest.mock import Mock

import pytest

from models.model import Account, Tag
from models.snippet import CustomizedSnippet
from models.workflow import Workflow


def test_graph_dict_returns_empty_without_workflow_id() -> None:
Expand All @@ -14,8 +15,8 @@ def test_graph_dict_returns_empty_without_workflow_id() -> None:


def test_graph_dict_loads_published_workflow_graph(monkeypatch: pytest.MonkeyPatch) -> None:
workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"id": "llm-1"}], "edges": []}))
session = SimpleNamespace(get=Mock(return_value=workflow))
workflow = Workflow(graph=json.dumps({"nodes": [{"id": "llm-1"}], "edges": []}))
session = Mock(get=Mock(return_value=workflow))
monkeypatch.setattr("models.snippet.db.session", session)
snippet = CustomizedSnippet(workflow_id="workflow-1")

Expand All @@ -24,7 +25,7 @@ def test_graph_dict_loads_published_workflow_graph(monkeypatch: pytest.MonkeyPat


def test_graph_dict_returns_empty_when_workflow_missing(monkeypatch: pytest.MonkeyPatch) -> None:
session = SimpleNamespace(get=Mock(return_value=None))
session = Mock(get=Mock(return_value=None))
monkeypatch.setattr("models.snippet.db.session", session)
snippet = CustomizedSnippet(workflow_id="missing-workflow")

Expand All @@ -39,8 +40,10 @@ def test_input_fields_list_parses_json_or_returns_empty() -> None:


def test_tags_returns_query_results_or_empty(monkeypatch: pytest.MonkeyPatch) -> None:
tags = [SimpleNamespace(id="tag-1")]
session = SimpleNamespace(scalars=Mock(return_value=SimpleNamespace(all=Mock(return_value=tags))))
tag = Tag()
tag.id = "tag-1"
tags = [tag]
session = Mock(scalars=Mock(return_value=Mock(all=Mock(return_value=tags))))
monkeypatch.setattr("models.snippet.db.session", session)
snippet = CustomizedSnippet(id="snippet-1", tenant_id="tenant-1")

Expand All @@ -51,9 +54,13 @@ def test_tags_returns_query_results_or_empty(monkeypatch: pytest.MonkeyPatch) ->


def test_account_properties_and_author_name(monkeypatch: pytest.MonkeyPatch) -> None:
account = SimpleNamespace(id="account-1", name="Ada")
updated_account = SimpleNamespace(id="account-2", name="Grace")
session = SimpleNamespace(
account = Account()
account.id = "account-1"
account.name = "Ada"
updated_account = Account()
updated_account.id = "account-2"
updated_account.name = "Grace"
session = Mock(
get=Mock(side_effect=lambda _model, account_id: account if account_id == "account-1" else updated_account)
)
monkeypatch.setattr("models.snippet.db.session", session)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from __future__ import annotations

from collections.abc import Sequence
from datetime import UTC, datetime
from unittest.mock import MagicMock

from core.repositories.factory import OrderConfig
from graphon.entities import WorkflowNodeExecution
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)


class _ConcreteNodeRepo(DifyAPISQLAlchemyWorkflowNodeExecutionRepository):
def save(self, execution: WorkflowNodeExecution) -> None:
pass

def save_execution_data(self, execution: WorkflowNodeExecution) -> None:
pass

def get_by_workflow_execution(
self,
workflow_execution_id: str,
order_config: OrderConfig | None = None,
) -> Sequence[WorkflowNodeExecution]:
return []


def _make_repo() -> tuple[_ConcreteNodeRepo, MagicMock]:
session = MagicMock()
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
return _ConcreteNodeRepo(session_maker), session


def test_delete_expired_executions_rowcount_none_returns_zero() -> None:
repo, session = _make_repo()
# select returns one ID (< default batch_size=1000), loop exits after first batch
select_result = MagicMock()
select_result.scalars.return_value.all.return_value = ["exec-1"]
delete_result = MagicMock()
delete_result.rowcount = None
session.execute.side_effect = [select_result, delete_result]

assert repo.delete_expired_executions("tenant-1", datetime(2024, 1, 1, tzinfo=UTC)) == 0


def test_delete_executions_by_app_rowcount_none_returns_zero() -> None:
repo, session = _make_repo()
select_result = MagicMock()
select_result.scalars.return_value.all.return_value = ["exec-1"]
delete_result = MagicMock()
delete_result.rowcount = None
session.execute.side_effect = [select_result, delete_result]

assert repo.delete_executions_by_app("tenant-1", "app-1") == 0


def test_delete_executions_by_ids_rowcount_none_returns_zero() -> None:
repo, session = _make_repo()
delete_result = MagicMock()
delete_result.rowcount = None
session.execute.return_value = delete_result

assert repo.delete_executions_by_ids(["exec-1", "exec-2"]) == 0
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,21 @@

from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock

from graphon.entities import WorkflowExecution
from graphon.nodes.human_input.entities import FormDefinition, ParagraphInputConfig, UserActionConfig
from graphon.nodes.human_input.enums import FormInputType
from models.human_input import RecipientType
from repositories.sqlalchemy_api_workflow_run_repository import _build_human_input_required_reason
from repositories.sqlalchemy_api_workflow_run_repository import (
DifyAPISQLAlchemyWorkflowRunRepository,
_build_human_input_required_reason,
)


class _ConcreteRunRepo(DifyAPISQLAlchemyWorkflowRunRepository):
def save(self, execution: WorkflowExecution) -> None:
pass


def _build_form_model() -> SimpleNamespace:
Expand Down Expand Up @@ -62,3 +72,30 @@ def test_build_human_input_required_reason_falls_back_to_console_token() -> None
assert reason.node_id == "node-1"
assert reason.actions[0].id == "approve"
assert not hasattr(reason, "form_token")


def _make_run_repo() -> tuple[_ConcreteRunRepo, MagicMock]:
session = MagicMock()
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = session
return _ConcreteRunRepo(session_maker), session


def test_delete_runs_by_ids_rowcount_none_returns_zero() -> None:
repo, session = _make_run_repo()
delete_result = MagicMock()
delete_result.rowcount = None
session.execute.return_value = delete_result

assert repo.delete_runs_by_ids(["run-1", "run-2"]) == 0


def test_delete_runs_by_app_rowcount_none_returns_zero() -> None:
repo, session = _make_run_repo()
# select returns one ID (< default batch_size), loop exits after first batch
session.scalars.return_value.all.return_value = ["run-1"]
delete_result = MagicMock()
delete_result.rowcount = None
session.execute.return_value = delete_result

assert repo.delete_runs_by_app("tenant-1", "app-1") == 0
Loading