diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 6dd13b485bcbb1..b2c8bda0581c65 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -243,72 +243,71 @@ def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str, if not credential: raise NotFound("Credential not found.") exist_page_ids = [] - with sessionmaker(db.engine).begin() as session: - # import notion in the exist dataset - if query.dataset_id: - dataset = DatasetService.get_dataset(query.dataset_id) - if not dataset: - raise NotFound("Dataset not found.") - if dataset.data_source_type != "notion_import": - raise ValueError("Dataset is not notion type.") - - documents = session.scalars( - select(Document).where( - Document.dataset_id == query.dataset_id, - Document.tenant_id == current_tenant_id, - Document.data_source_type == "notion_import", - Document.enabled.is_(True), - ) - ).all() - if documents: - for document in documents: - data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info["notion_page_id"]) - # get all authorized pages - from core.datasource.datasource_manager import DatasourceManager - - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id="langgenius/notion_datasource/notion_datasource", - datasource_name="notion_datasource", - tenant_id=current_tenant_id, - datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, - ) - datasource_provider_service = DatasourceProviderService() - if credential: - datasource_runtime.runtime.credentials = credential - datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) - online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( - datasource_runtime.get_online_document_pages( - user_id=current_user.id, - datasource_parameters={}, - provider_type=datasource_runtime.datasource_provider_type(), + # import notion in the exist dataset + if query.dataset_id: + dataset = DatasetService.get_dataset(query.dataset_id, db.session) + if not dataset: + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") + + documents = db.session.scalars( + select(Document).where( + Document.dataset_id == query.dataset_id, + Document.tenant_id == current_tenant_id, + Document.data_source_type == "notion_import", + Document.enabled.is_(True), ) + ).all() + if documents: + for document in documents: + data_source_info = json.loads(document.data_source_info) + exist_page_ids.append(data_source_info["notion_page_id"]) + # get all authorized pages + from core.datasource.datasource_manager import DatasourceManager + + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id="langgenius/notion_datasource/notion_datasource", + datasource_name="notion_datasource", + tenant_id=current_tenant_id, + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + datasource_provider_service = DatasourceProviderService() + if credential: + datasource_runtime.runtime.credentials = credential + datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime) + online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = ( + datasource_runtime.get_online_document_pages( + user_id=current_user.id, + datasource_parameters={}, + provider_type=datasource_runtime.datasource_provider_type(), ) - try: - pages = [] - workspace_info = {} - for message in online_document_result: - result = message.result - for info in result: - workspace_info = { - "workspace_id": info.workspace_id, - "workspace_name": info.workspace_name, - "workspace_icon": info.workspace_icon, + ) + try: + pages = [] + workspace_info = {} + for message in online_document_result: + result = message.result + for info in result: + workspace_info = { + "workspace_id": info.workspace_id, + "workspace_name": info.workspace_name, + "workspace_icon": info.workspace_icon, + } + for page in info.pages: + page_info = { + "page_id": page.page_id, + "page_name": page.page_name, + "type": page.type, + "parent_id": page.parent_id, + "is_bound": page.page_id in exist_page_ids, + "page_icon": page.page_icon, } - for page in info.pages: - page_info = { - "page_id": page.page_id, - "page_name": page.page_name, - "type": page.type, - "parent_id": page.parent_id, - "is_bound": page.page_id in exist_page_ids, - "page_icon": page.page_icon, - } - pages.append(page_info) - except Exception as e: - raise e - notion_info = [{**workspace_info, "pages": pages}] if workspace_info else [] - return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200 + pages.append(page_info) + except Exception as e: + raise e + notion_info = [{**workspace_info, "pages": pages}] if workspace_info else [] + return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200 @console_ns.route("/notion/pages///preview") @@ -401,11 +400,11 @@ class DataSourceNotionDatasetSyncApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_CREATE_AND_MANAGEMENT) def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]: dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") - documents = DocumentService.get_document_by_dataset_id(dataset_id_str) + documents = DocumentService.get_document_by_dataset_id(dataset_id_str, db.session) for document in documents: document_indexing_sync_task.delay(dataset_id_str, document.id) return {"result": "success"}, 200 @@ -421,11 +420,11 @@ class DataSourceNotionDocumentSyncApi(Resource): def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]: dataset_id_str = str(dataset_id) document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if document is None: raise NotFound("Document not found.") document_indexing_sync_task.delay(dataset_id_str, document_id_str) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 70ce54830c7c5f..dbf546532ee726 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -561,6 +561,7 @@ def post(self, current_tenant_id: str, current_user: Account): provider=payload.provider, external_knowledge_api_id=payload.external_knowledge_api_id, external_knowledge_id=payload.external_knowledge_id, + session=db.session, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -598,7 +599,7 @@ class DatasetApi(Resource): @with_current_tenant_id def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") try: @@ -618,7 +619,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): provider_id = ModelProviderID(dataset.embedding_model_provider) data["embedding_model_provider"] = str(provider_id) if data.get("permission") == "partial_members": - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) + part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session) data.update({"partial_member_list": part_users_list}) # check embedding setting @@ -661,7 +662,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_EDIT) def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -680,10 +681,10 @@ def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not dify_config.RBAC_ENABLED: DatasetPermissionService.check_permission( - current_user, dataset, payload.permission, payload.partial_member_list + current_user, dataset, payload.permission, payload.partial_member_list, db.session ) - dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -698,12 +699,14 @@ def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID) tenant_id = current_tenant_id if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM: - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list) + DatasetPermissionService.update_partial_member_list( + tenant_id, dataset_id_str, payload.partial_member_list, db.session + ) # clear partial member list when permission is only_me or all_team_members elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}: - DatasetPermissionService.clear_partial_member_list(dataset_id_str) + DatasetPermissionService.clear_partial_member_list(dataset_id_str, db.session) - partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) + partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session) result_data.update({"partial_member_list": partial_member_list}) return result_data, 200 @@ -722,8 +725,8 @@ def delete(self, current_user: Account, dataset_id: UUID): raise Forbidden() try: - if DatasetService.delete_dataset(dataset_id_str, current_user): - DatasetPermissionService.clear_partial_member_list(dataset_id_str) + if DatasetService.delete_dataset(dataset_id_str, current_user, db.session): + DatasetPermissionService.clear_partial_member_list(dataset_id_str, db.session) return "", 204 else: raise NotFound("Dataset not found.") @@ -748,7 +751,7 @@ class DatasetUseCheckApi(Resource): def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) + dataset_is_using = DatasetService.dataset_use_check(dataset_id_str, db.session) return {"is_using": dataset_is_using}, 200 @@ -769,7 +772,7 @@ class DatasetQueryApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY) def get(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -910,7 +913,7 @@ class DatasetRelatedAppListApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY) def get(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -919,7 +922,7 @@ def get(self, current_user: Account, dataset_id: UUID): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - app_dataset_joins = DatasetService.get_related_apps(dataset.id) + app_dataset_joins = DatasetService.get_related_apps(dataset.id, db.session) related_apps = [] for app_dataset_join in app_dataset_joins: @@ -1094,7 +1097,7 @@ class DatasetEnableApiApi(Resource): def post(self, dataset_id: UUID, status: str): dataset_id_str = str(dataset_id) - DatasetService.update_dataset_api_status(dataset_id_str, status == "enable") + DatasetService.update_dataset_api_status(dataset_id_str, status == "enable", db.session) return {"result": "success"}, 200 @@ -1163,10 +1166,10 @@ class DatasetErrorDocs(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY) def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") - results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) + results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str, db.session) return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200 @@ -1190,7 +1193,7 @@ class DatasetPermissionUserListApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY) def get(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") try: @@ -1198,7 +1201,7 @@ def get(self, current_user: Account, dataset_id: UUID): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) + partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session) return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200 @@ -1220,7 +1223,8 @@ class DatasetAutoDisableLogApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY) def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") - return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200 + auto_disable_logs = DatasetService.get_dataset_auto_disable_logs(dataset_id_str, db.session) + return dump_response(AutoDisableLogsResponse, auto_disable_logs), 200 diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 499558dc4dde67..37f64f54d24c6c 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -181,7 +181,7 @@ class DocumentResource(Resource): def get_document( self, dataset_id: str, document_id: str, current_user: Account, current_tenant_id: str ) -> Document: - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -190,7 +190,7 @@ def get_document( except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id, document_id, session=db.session) if not document: raise NotFound("Document not found.") @@ -201,7 +201,7 @@ def get_document( return document def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account) -> Sequence[Document]: - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -210,7 +210,7 @@ def get_batch_documents(self, dataset_id: str, batch: str, current_user: Account except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - documents = DocumentService.get_batch_documents(dataset_id, batch) + documents = DocumentService.get_batch_documents(dataset_id, batch, db.session) if not documents: raise NotFound("Documents not found.") @@ -241,7 +241,7 @@ def get(self, current_user: Account): # get the latest process rule document = db.get_or_404(Document, document_id) - dataset = DatasetService.get_dataset(document.dataset_id) + dataset = DatasetService.get_dataset(document.dataset_id, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -317,7 +317,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): ) except (ArgumentTypeError, ValueError, Exception): fetch = False - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -421,7 +421,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -444,8 +444,10 @@ def post(self, current_user: Account, dataset_id: UUID): DocumentService.document_create_args_validate(knowledge_config) try: - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) - dataset = DatasetService.get_dataset(dataset_id_str) + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, knowledge_config, current_user, session=db.session + ) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -464,7 +466,7 @@ def post(self, current_user: Account, dataset_id: UUID): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_EDIT) def delete(self, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") # check user's model setting @@ -472,7 +474,7 @@ def delete(self, dataset_id: UUID): try: document_ids = request.args.getlist("document_id") - DocumentService.delete_documents(dataset, document_ids) + DocumentService.delete_documents(dataset, document_ids, db.session) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") @@ -531,6 +533,7 @@ def post(self, current_tenant_id: str, current_user: Account): tenant_id=current_tenant_id, knowledge_config=knowledge_config, account=current_user, + session=db.session, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -867,7 +870,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, d if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": - dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str, db.session) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} response = { "id": document.id, @@ -901,7 +904,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, d "need_summary": document.need_summary if document.need_summary is not None else False, } else: - dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str, db.session) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} response = { "id": document.id, @@ -950,7 +953,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, d def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") # check user's model setting @@ -959,7 +962,7 @@ def delete(self, current_tenant_id: str, current_user: Account, dataset_id: UUID document = self.get_document(dataset_id_str, document_id_str, current_user, current_tenant_id) try: - DocumentService.delete_document(document) + DocumentService.delete_document(document, db.session) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") @@ -983,7 +986,7 @@ class DocumentDownloadApi(DocumentResource): def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID) -> dict[str, Any]: # Reuse the shared permission/tenant checks implemented in DocumentResource. document = self.get_document(str(dataset_id), str(document_id), current_user, current_tenant_id) - return {"url": DocumentService.get_document_download_url(document)} + return {"url": DocumentService.get_document_download_url(document, db.session)} @console_ns.route("/datasets//documents/download-zip") @@ -1013,6 +1016,7 @@ def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): document_ids=document_ids, tenant_id=current_tenant_id, current_user=current_user, + session=db.session, ) # Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route. @@ -1161,7 +1165,7 @@ def patch( self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable", "archive", "un_archive"] ): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -1178,7 +1182,7 @@ def patch( document_ids = request.args.getlist("document_id") try: - DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user, db.session) except services.errors.document.DocumentIndexingError as e: raise InvalidActionError(str(e)) except ValueError as e: @@ -1202,11 +1206,11 @@ def patch(self, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) # 404 if document not found if document is None: @@ -1218,7 +1222,7 @@ def patch(self, dataset_id: UUID, document_id: UUID): try: # pause document - DocumentService.pause_document(document) + DocumentService.pause_document(document, db.session) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot pause completed document.") @@ -1237,10 +1241,10 @@ def patch(self, dataset_id: UUID, document_id: UUID): """recover document.""" dataset_id_str = str(dataset_id) document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) # 404 if document not found if document is None: @@ -1251,7 +1255,7 @@ def patch(self, dataset_id: UUID, document_id: UUID): raise ArchivedDocumentImmutableError() try: # pause document - DocumentService.recover_document(document) + DocumentService.recover_document(document, db.session) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Document is not in paused status.") @@ -1271,13 +1275,13 @@ def post(self, dataset_id: UUID): """retry document.""" payload = DocumentRetryPayload.model_validate(console_ns.payload or {}) dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) retry_documents = [] if not dataset: raise NotFound("Dataset not found.") for document_id in payload.document_ids: try: - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id, session=db.session) # 404 if document not found if document is None: @@ -1295,7 +1299,7 @@ def post(self, dataset_id: UUID): logger.exception("Failed to retry document, document id: %s", document_id) continue # retry document - DocumentService.retry_document(dataset_id_str, retry_documents) + DocumentService.retry_document(dataset_id_str, retry_documents, db.session) return "", 204 @@ -1313,14 +1317,14 @@ def post(self, current_user: Account, dataset_id: UUID, document_id: UUID): # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, db.session) if not dataset: raise NotFound("Dataset not found.") - DatasetService.check_dataset_operator_permission(current_user, dataset) + DatasetService.check_dataset_operator_permission(current_user, dataset, session=db.session) payload = DocumentRenamePayload.model_validate(console_ns.payload or {}) try: - document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name) + document = DocumentService.rename_document(str(dataset_id), str(document_id), payload.name, db.session) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") @@ -1338,11 +1342,11 @@ class WebsiteDocumentSyncApi(DocumentResource): def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID): """sync website document.""" dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") document_id_str = str(document_id) - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") if document.tenant_id != current_tenant_id: @@ -1353,7 +1357,7 @@ def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID): if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document - DocumentService.sync_website_document(dataset_id_str, document) + DocumentService.sync_website_document(dataset_id_str, document, db.session) return {"result": "success"}, 200 @@ -1373,10 +1377,10 @@ def get(self, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") log = db.session.scalar( @@ -1431,7 +1435,7 @@ def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) # Get dataset - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -1465,7 +1469,7 @@ def post(self, current_user: Account, dataset_id: UUID): raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.") # Verify all documents exist and belong to the dataset - documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list) + documents = DocumentService.get_documents_by_ids(dataset_id_str, document_list, db.session) if len(documents) != len(document_list): found_ids = {doc.id for doc in documents} @@ -1481,6 +1485,7 @@ def post(self, current_user: Account, dataset_id: UUID): DocumentService.update_documents_need_summary( dataset_id=dataset_id_str, document_ids=document_ids_to_update, + session=db.session, need_summary=True, ) @@ -1531,7 +1536,7 @@ def get(self, current_user: Account, dataset_id: UUID, document_id: UUID): document_id_str = str(document_id) # Get dataset - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -1547,6 +1552,7 @@ def get(self, current_user: Account, dataset_id: UUID, document_id: UUID): result = SummaryIndexService.get_document_summary_status_detail( document_id=document_id_str, dataset_id=dataset_id_str, + session=db.session, ) return result, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 5ba115ff491ce4..1a6f4c7a71244e 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -176,7 +176,7 @@ class DatasetDocumentSegmentListApi(Resource): def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): dataset_id_str = str(dataset_id) document_id_str = str(document_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") @@ -185,7 +185,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, d except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") @@ -286,14 +286,14 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, d def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_ids = request.args.getlist("segment_id") @@ -305,7 +305,7 @@ def delete(self, current_user: Account, dataset_id: UUID, document_id: UUID): DatasetService.check_dataset_permission(dataset, current_user, db.session) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - SegmentService.delete_segments(segment_ids, document, dataset) + SegmentService.delete_segments(segment_ids, document, dataset, db.session) return "", 204 @@ -331,11 +331,11 @@ def patch( action: Literal["enable", "disable"], ): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check user's model setting @@ -371,7 +371,7 @@ def patch( if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") try: - SegmentService.update_segments_status(segment_ids, action, dataset, document) + SegmentService.update_segments_status(segment_ids, action, dataset, document, db.session) except Exception as e: raise InvalidActionError(str(e)) return dump_response(SimpleResultResponse, {"result": "success"}), 200 @@ -394,12 +394,12 @@ class DatasetDocumentSegmentAddApi(Resource): def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") if not current_user.is_dataset_editor: @@ -428,7 +428,7 @@ def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, payload = SegmentCreatePayload.model_validate(console_ns.payload or {}) payload_dict = payload.model_dump(exclude_none=True) SegmentService.segment_create_args_validate(payload_dict, document) - segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset)) + segment = type_cast(DocumentSegment, SegmentService.create_segment(payload_dict, document, dataset, db.session)) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str) response = { "data": segment_response_with_summary(segment, summary.summary_content if summary else None), @@ -455,14 +455,14 @@ def patch( ): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -504,7 +504,11 @@ def patch( # Update segment (summary update with change detection is handled in SegmentService.update_segment) segment = SegmentService.update_segment( - SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset + SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), + segment, + document, + dataset, + db.session, ) summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id_str) response = { @@ -527,14 +531,14 @@ def delete( ): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check segment @@ -553,7 +557,7 @@ def delete( DatasetService.check_dataset_permission(dataset, current_user, db.session) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - SegmentService.delete_segment(segment, document, dataset) + SegmentService.delete_segment(segment, document, dataset, db.session) return "", 204 @@ -576,12 +580,12 @@ class DatasetDocumentSegmentBatchImportApi(Resource): def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, document_id: UUID): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") @@ -651,12 +655,12 @@ def post( ): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check segment @@ -693,7 +697,7 @@ def post( # validate args try: payload = ChildChunkCreatePayload.model_validate(console_ns.payload or {}) - child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset, db.session) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200 @@ -709,14 +713,14 @@ def post( def get(self, current_tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: UUID): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check segment @@ -766,14 +770,14 @@ def patch( ): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check segment @@ -795,7 +799,7 @@ def patch( # validate args payload = ChildChunkBatchUpdatePayload.model_validate(console_ns.payload or {}) try: - child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset) + child_chunks = SegmentService.update_child_chunks(payload.chunks, segment, document, dataset, db.session) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return dump_response(ChildChunkBatchUpdateResponse, {"data": child_chunks}), 200 @@ -825,14 +829,14 @@ def delete( ): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check segment @@ -866,7 +870,7 @@ def delete( except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) try: - SegmentService.delete_child_chunk(child_chunk, dataset) + SegmentService.delete_child_chunk(child_chunk, dataset, db.session) except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) return "", 204 @@ -893,14 +897,14 @@ def patch( ): # check dataset dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if not dataset: raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id_str = str(document_id) - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check segment @@ -936,7 +940,9 @@ def patch( # validate args try: payload = ChildChunkUpdatePayload.model_validate(console_ns.payload or {}) - child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset) + child_chunk = SegmentService.update_child_chunk( + payload.content, child_chunk, segment, document, dataset, db.session + ) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) return dump_response(ChildChunkDetailResponse, {"data": child_chunk}), 200 diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 99a61807a4d634..7a3c746b80fe85 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -377,7 +377,7 @@ class ExternalKnowledgeHitTestingApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_PIPELINE_TEST) def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index 82c30fc7ffbedf..6464e435fc27b9 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -85,7 +85,7 @@ def get_and_validate_dataset( dataset_id: str, current_user: Account | None = None, current_tenant_id: str | None = None ) -> Dataset: current_user, _ = resolve_account_fallback(current_user, current_tenant_id) - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, db.session) if dataset is None: raise NotFound("Dataset not found.") diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 90ce263dfe534a..8802fcf2814c6c 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -61,7 +61,7 @@ def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): metadata_args = MetadataArgs.model_validate(console_ns.payload or {}) dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -81,7 +81,7 @@ def post(self, current_tenant_id: str, current_user: Account, dataset_id: UUID): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_CREATE_AND_MANAGEMENT) def get(self, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") metadata = MetadataService.get_dataset_metadatas(db.session(), dataset) @@ -105,7 +105,7 @@ def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -125,7 +125,7 @@ def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID, def delete(self, current_user: Account, dataset_id: UUID, metadata_id: UUID): dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -162,7 +162,7 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_EDIT) def post(self, current_user: Account, dataset_id: UUID, action: Literal["enable", "disable"]): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -191,7 +191,7 @@ class DocumentMetadataEditApi(Resource): @rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_EDIT) def post(self, current_user: Account, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py index 0f277b6a4cc256..a373c8b1a41a6b 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_datasets.py @@ -1,6 +1,5 @@ from flask_restx import Resource from pydantic import BaseModel -from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden import services @@ -66,19 +65,19 @@ def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWit yaml_content=payload.yaml_content, ) try: - with Session(db.engine, expire_on_commit=False) as session: - rag_pipeline_dsl_service = RagPipelineDslService(session) - import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( - tenant_id=current_tenant_id, - rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, - ) - session.commit() + rag_pipeline_dsl_service = RagPipelineDslService(db.session) + import_info = rag_pipeline_dsl_service.create_rag_pipeline_dataset( + tenant_id=current_tenant_id, + rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity, + ) if rag_pipeline_dataset_create_entity.permission == "partial_members": DatasetPermissionService.update_partial_member_list( current_tenant_id, import_info["dataset_id"], rag_pipeline_dataset_create_entity.partial_member_list, + db.session, ) + db.session.commit() except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -111,5 +110,6 @@ def post(self, current_tenant_id: str, current_user: Account) -> JsonResponseWit permission=DatasetPermissionEnum.ONLY_ME, partial_member_list=None, ), + session=db.session, ) return dump_response(DatasetDetailResponse, dataset), 201 diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index bfb7a045082712..e903e92e7a63b8 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -519,6 +519,7 @@ def post(self, tenant_id): embedding_model_name=payload.embedding_model, retrieval_model=payload.retrieval_model, summary_index_setting=payload.summary_index_setting, + session=db.session, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -561,7 +562,7 @@ class DatasetApi(DatasetApiResource): ) def get(self, _, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") try: @@ -597,7 +598,7 @@ def get(self, _, dataset_id: UUID): retrieval_model_dict["search_method"] = "keyword_search" if data.get("permission") == "partial_members": - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) + part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session) data.update({"partial_member_list": part_users_list}) return _dump_service_dataset_with_partial_members(data), 200 @@ -635,7 +636,7 @@ def get(self, _, dataset_id: UUID): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, _, dataset_id: UUID): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -676,9 +677,10 @@ def patch(self, _, dataset_id: UUID): dataset, str(payload.permission) if payload.permission else None, payload.partial_member_list, + db.session, ) - dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -688,12 +690,14 @@ def patch(self, _, dataset_id: UUID): tenant_id = current_user.current_tenant_id if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM: - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list) + DatasetPermissionService.update_partial_member_list( + tenant_id, dataset_id_str, payload.partial_member_list, db.session + ) # clear partial member list when permission is only_me or all_team_members elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}: - DatasetPermissionService.clear_partial_member_list(dataset_id_str) + DatasetPermissionService.clear_partial_member_list(dataset_id_str, db.session) - partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) + partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session) result_data.update({"partial_member_list": partial_member_list}) return _dump_service_dataset_with_partial_members(result_data), 200 @@ -746,8 +750,8 @@ def delete(self, _, dataset_id: UUID): dataset_id_str = str(dataset_id) try: - if DatasetService.delete_dataset(dataset_id_str, current_user): - DatasetPermissionService.clear_partial_member_list(dataset_id_str) + if DatasetService.delete_dataset(dataset_id_str, current_user, db.session): + DatasetPermissionService.clear_partial_member_list(dataset_id_str, db.session) return "", 204 else: raise NotFound("Dataset not found.") @@ -812,7 +816,7 @@ def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable" InvalidActionError: If the action is invalid or cannot be performed. """ dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") @@ -831,7 +835,7 @@ def patch(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable" document_ids = data.get("document_ids", []) try: - DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user, db.session) except services.errors.document.DocumentIndexingError as e: raise InvalidActionError(str(e)) except ValueError as e: diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 9bae862814a498..49ccb1bd55c156 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -400,6 +400,7 @@ def _create_document_by_text(tenant_id: str, dataset_id: UUID) -> tuple[Mapping[ account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", + session=db.session, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -459,6 +460,7 @@ def _update_document_by_text(tenant_id: str, dataset_id: UUID, document_id: UUID account=current_user, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", + session=db.session, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -756,6 +758,7 @@ def post(self, tenant_id, dataset_id: UUID): account=dataset.created_by_account, dataset_process_rule=dataset_process_rule, created_from="api", + session=db.session, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -832,6 +835,7 @@ def _update_document_by_file(tenant_id: str, dataset_id: UUID, document_id: UUID account=dataset.created_by_account, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, created_from="api", + session=db.session, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -1002,6 +1006,7 @@ def post(self, tenant_id, dataset_id: UUID): document_ids=[str(document_id) for document_id in payload.document_ids], tenant_id=str(tenant_id), current_user=current_user, + session=db.session, ) with ExitStack() as stack: @@ -1058,7 +1063,7 @@ def get(self, tenant_id, dataset_id: UUID, batch: str): if not dataset: raise NotFound("Dataset not found.") # get documents - documents = DocumentService.get_batch_documents(dataset_id_str, batch) + documents = DocumentService.get_batch_documents(dataset_id_str, batch, db.session) if not documents: raise NotFound("Documents not found.") documents_status = [] @@ -1134,7 +1139,7 @@ class DocumentDownloadApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def get(self, tenant_id, dataset_id: UUID, document_id: UUID): dataset = self.get_dataset(str(dataset_id), str(tenant_id)) - document = DocumentService.get_document(dataset.id, str(document_id)) + document = DocumentService.get_document(dataset.id, str(document_id), session=db.session) if not document: raise NotFound("Document not found.") @@ -1142,7 +1147,7 @@ def get(self, tenant_id, dataset_id: UUID, document_id: UUID): if document.tenant_id != str(tenant_id): raise Forbidden("No permission.") - return {"url": DocumentService.get_document_download_url(document)} + return {"url": DocumentService.get_document_download_url(document, db.session)} @service_api_ns.route("/datasets//documents/") @@ -1190,7 +1195,7 @@ def get(self, tenant_id, dataset_id: UUID, document_id: UUID): dataset = self.get_dataset(dataset_id_str, tenant_id) - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") @@ -1215,7 +1220,7 @@ def get(self, tenant_id, dataset_id: UUID, document_id: UUID): if metadata == "only": response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details} elif metadata == "without": - dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str, db.session) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { @@ -1250,7 +1255,7 @@ def get(self, tenant_id, dataset_id: UUID, document_id: UUID): "need_summary": document.need_summary if document.need_summary is not None else False, } else: - dataset_process_rules = DatasetService.get_process_rules(dataset_id_str) + dataset_process_rules = DatasetService.get_process_rules(dataset_id_str, db.session) document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {} data_source_info = document.data_source_detail_dict response = { @@ -1345,7 +1350,7 @@ def delete(self, tenant_id, dataset_id: UUID, document_id: UUID): if not dataset: raise ValueError("Dataset does not exist.") - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) # 404 if document not found if document is None: @@ -1357,7 +1362,7 @@ def delete(self, tenant_id, dataset_id: UUID, document_id: UUID): try: # delete document - DocumentService.delete_document(document) + DocumentService.delete_document(document, db.session) except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 3bb39f0cd4f504..aec3b06a91eed9 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -81,7 +81,7 @@ def post(self, tenant_id, dataset_id: UUID): metadata_args = MetadataArgs.model_validate(service_api_ns.payload or {}) dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -116,7 +116,7 @@ def post(self, tenant_id, dataset_id: UUID): def get(self, tenant_id, dataset_id: UUID): """Get all metadata for a dataset.""" dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") metadata = MetadataService.get_dataset_metadatas(db.session(), dataset) @@ -154,7 +154,7 @@ def patch(self, tenant_id, dataset_id: UUID, metadata_id: UUID): dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -189,7 +189,7 @@ def delete(self, tenant_id, dataset_id: UUID, metadata_id: UUID): """Delete metadata.""" dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -257,7 +257,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): def post(self, tenant_id, dataset_id: UUID, action: Literal["enable", "disable"]): """Enable or disable built-in metadata field.""" dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) @@ -303,7 +303,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): def post(self, tenant_id, dataset_id: UUID): """Update metadata for multiple documents.""" dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) + dataset = DatasetService.get_dataset(dataset_id_str, db.session) if dataset is None: raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user, db.session) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index dd8d7c76632c7e..7b0e31952c7d42 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -175,7 +175,7 @@ def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): raise NotFound("Dataset not found.") document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") if document.indexing_status != "completed": @@ -210,7 +210,9 @@ def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): for args_item in segment_items: SegmentService.segment_create_args_validate(args_item, document) - segments = cast(list[DocumentSegment], SegmentService.multi_create_segment(segment_items, document, dataset)) + segments = cast( + list[DocumentSegment], SegmentService.multi_create_segment(segment_items, document, dataset, db.session) + ) segment_ids = [segment.id for segment in segments] summaries: dict[str, str | None] = {} if segment_ids: @@ -267,7 +269,7 @@ def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID): raise NotFound("Dataset not found.") document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") # check embedding model setting @@ -349,15 +351,17 @@ def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id DatasetService.check_dataset_model_setting(dataset) document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") - SegmentService.delete_segment(segment, document, dataset) + SegmentService.delete_segment(segment, document, dataset, db.session) return "", 204 @service_api_ns.doc( @@ -395,7 +399,7 @@ def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: DatasetService.check_dataset_model_setting(dataset) document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -416,13 +420,15 @@ def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: raise ProviderNotInitializeError(ex.description) segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") payload = SegmentUpdatePayload.model_validate(service_api_ns.payload or {}) - updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset) + updated_segment = SegmentService.update_segment(payload.segment, segment, document, dataset, db.session) summary = SummaryIndexService.get_segment_summary(segment_id=updated_segment.id, dataset_id=dataset_id_str) response = { "data": segment_response_with_summary(updated_segment, summary.summary_content if summary else None), @@ -469,12 +475,14 @@ def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: U DatasetService.check_dataset_model_setting(dataset) document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") @@ -533,13 +541,15 @@ def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") @@ -564,7 +574,7 @@ def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: payload = ChildChunkCreatePayload.model_validate(service_api_ns.payload or {}) try: - child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset) + child_chunk = SegmentService.create_child_chunk(payload.content, segment, document, dataset, db.session) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) @@ -607,13 +617,15 @@ def get(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: U document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") @@ -677,13 +689,15 @@ def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id document_id_str = str(document_id) # check document - document = DocumentService.get_document(dataset.id, document_id_str) + document = DocumentService.get_document(dataset.id, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_id_str = str(segment_id) # check segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") @@ -694,7 +708,7 @@ def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id child_chunk_id_str = str(child_chunk_id) # check child chunk child_chunk = SegmentService.get_child_chunk_by_id( - child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id + child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id, session=db.session ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -704,7 +718,7 @@ def delete(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id raise NotFound("Child chunk not found.") try: - SegmentService.delete_child_chunk(child_chunk, dataset) + SegmentService.delete_child_chunk(child_chunk, dataset, db.session) except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) @@ -751,13 +765,15 @@ def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: document_id_str = str(document_id) # get document - document = DocumentService.get_document(dataset_id_str, document_id_str) + document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session) if not document: raise NotFound("Document not found.") segment_id_str = str(segment_id) # get segment - segment = SegmentService.get_segment_by_id(segment_id=segment_id_str, tenant_id=current_tenant_id) + segment = SegmentService.get_segment_by_id( + segment_id=segment_id_str, tenant_id=current_tenant_id, session=db.session + ) if not segment: raise NotFound("Segment not found.") @@ -768,7 +784,7 @@ def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: child_chunk_id_str = str(child_chunk_id) # get child chunk child_chunk = SegmentService.get_child_chunk_by_id( - child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id + child_chunk_id=child_chunk_id_str, tenant_id=current_tenant_id, session=db.session ) if not child_chunk: raise NotFound("Child chunk not found.") @@ -781,7 +797,9 @@ def patch(self, tenant_id: str, dataset_id: UUID, document_id: UUID, segment_id: payload = ChildChunkUpdatePayload.model_validate(service_api_ns.payload or {}) try: - child_chunk = SegmentService.update_child_chunk(payload.content, child_chunk, segment, document, dataset) + child_chunk = SegmentService.update_child_chunk( + payload.content, child_chunk, segment, document, dataset, db.session + ) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index be538455afb34e..140d4e6a2a66fc 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -9,6 +9,7 @@ ) from core.entities.agent_entities import PlanningStrategy from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from extensions.ext_database import db from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService @@ -256,7 +257,7 @@ def extract_dataset_config_for_legacy_compatibility(cls, tenant_id: str, app_mod @classmethod def is_dataset_exists(cls, tenant_id: str, dataset_id: str) -> bool: # verify if the dataset ID exists - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, db.session) if not dataset: return False diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 255740b86a1820..cafc95d035ad8e 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -144,7 +144,7 @@ def generate( DocumentService.check_document_creation_limits(len(datasource_info_list), features) for datasource_info in datasource_info_list: - position = DocumentService.get_documents_position(dataset.id) + position = DocumentService.get_documents_position(dataset.id, session) document = self._build_document( tenant_id=pipeline.tenant_id, dataset_id=dataset.id, diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 0bd904811a0c8d..520ba7b85b33d3 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -45,7 +45,7 @@ def query( embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION + embedding_provider_name, embedding_model_name, db.session, CollectionBindingType.ANNOTATION ) dataset = Dataset( diff --git a/api/models/dataset.py b/api/models/dataset.py index 998bc02ee85bb4..58fec01d675265 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -14,7 +14,7 @@ import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column +from sqlalchemy.orm import Mapped, Session, mapped_column, scoped_session from configs import dify_config from core.rag.entities import ParentMode, Rule @@ -1670,7 +1670,7 @@ class Pipeline(TypeBase): init=False, ) - def retrieve_dataset(self, session: Session): + def retrieve_dataset(self, session: Session | scoped_session): return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id)) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 2dd8c533828162..4dbcf372bb004e 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -13,11 +13,10 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator from redis.exceptions import LockNotOwnedError from sqlalchemy import ColumnElement, delete, exists, func, select, update -from sqlalchemy.orm import Session, scoped_session, sessionmaker +from sqlalchemy.orm import Session, scoped_session from werkzeug.exceptions import Forbidden, NotFound from configs import dify_config -from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager @@ -109,6 +108,13 @@ logger = logging.getLogger(__name__) +def _session_for_helpers(session: scoped_session | Session) -> Session: + """Return a concrete SQLAlchemy session for helpers that do not accept scoped_session.""" + if isinstance(session, scoped_session): + return session() + return session + + class ProcessRulesDict(TypedDict): mode: ProcessRuleMode rules: dict[str, Any] @@ -353,9 +359,9 @@ def get_datasets( return datasets.items, datasets.total @staticmethod - def get_process_rules(dataset_id) -> ProcessRulesDict: + def get_process_rules(dataset_id, session: scoped_session | Session) -> ProcessRulesDict: # get the latest process rule - dataset_process_rule = db.session.execute( + dataset_process_rule = session.execute( select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == dataset_id) .order_by(DatasetProcessRule.created_at.desc()) @@ -411,9 +417,11 @@ def create_empty_dataset( embedding_model_name: str | None = None, retrieval_model: RetrievalModel | None = None, summary_index_setting: dict[str, Any] | None = None, + *, + session: scoped_session | Session, ): # check if dataset name already exists - if db.session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)): + if session.scalar(select(Dataset).where(Dataset.name == name, Dataset.tenant_id == tenant_id).limit(1)): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -459,8 +467,8 @@ def create_empty_dataset( dataset.provider = provider if summary_index_setting is not None: dataset.summary_index_setting = summary_index_setting - db.session.add(dataset) - db.session.flush() + session.add(dataset) + session.flush() if provider == "external" and external_knowledge_api_id: external_knowledge_api = ExternalDatasetService.get_external_knowledge_api( @@ -477,9 +485,9 @@ def create_empty_dataset( external_knowledge_id=external_knowledge_id, created_by=account.id, ) - db.session.add(external_knowledge_binding) + session.add(external_knowledge_binding) - db.session.commit() + session.commit() enterprise_rbac_service.try_sync_creator_access_policy_member_bindings( tenant_id, account.id, @@ -492,10 +500,11 @@ def create_empty_dataset( def create_empty_rag_pipeline_dataset( tenant_id: str, rag_pipeline_dataset_create_entity: RagPipelineDatasetCreateEntity, + session: scoped_session | Session, ): if rag_pipeline_dataset_create_entity.name: # check if dataset name already exists - if db.session.scalar( + if session.scalar( select(Dataset) .where(Dataset.name == rag_pipeline_dataset_create_entity.name, Dataset.tenant_id == tenant_id) .limit(1) @@ -505,7 +514,7 @@ def create_empty_rag_pipeline_dataset( ) else: # generate a random name as Untitled 1 2 3 ... - datasets = db.session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() + datasets = session.scalars(select(Dataset).where(Dataset.tenant_id == tenant_id)).all() names = [dataset.name for dataset in datasets] rag_pipeline_dataset_create_entity.name = generate_incremental_name( names, @@ -519,8 +528,8 @@ def create_empty_rag_pipeline_dataset( description=rag_pipeline_dataset_create_entity.description, created_by=current_user.id, ) - db.session.add(pipeline) - db.session.flush() + session.add(pipeline) + session.flush() dataset = Dataset( tenant_id=tenant_id, @@ -534,13 +543,13 @@ def create_empty_rag_pipeline_dataset( maintainer=current_user.id, pipeline_id=pipeline.id, ) - db.session.add(dataset) - db.session.commit() + session.add(dataset) + session.commit() return dataset @staticmethod - def get_dataset(dataset_id) -> Dataset | None: - dataset: Dataset | None = db.session.get(Dataset, dataset_id) + def get_dataset(dataset_id, session: scoped_session | Session) -> Dataset | None: + dataset: Dataset | None = session.get(Dataset, dataset_id) return dataset @staticmethod @@ -622,7 +631,7 @@ def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, raise ValueError(ex.description) @staticmethod - def update_dataset(dataset_id, data, user): + def update_dataset(dataset_id, data, user, session: scoped_session | Session): """ Update dataset configuration and settings. @@ -639,7 +648,7 @@ def update_dataset(dataset_id, data, user): NoPermissionError: If user lacks permission to update the dataset """ # Retrieve and validate dataset existence - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, session) if not dataset: raise ValueError("Dataset not found") # check if dataset name is exists @@ -648,21 +657,22 @@ def update_dataset(dataset_id, data, user): tenant_id=dataset.tenant_id, dataset_id=dataset_id, name=data.get("name", dataset.name), + session=session, ): raise ValueError("Dataset name already exists") # Verify user has permission to update this dataset - DatasetService.check_dataset_permission(dataset, user, db.session) + DatasetService.check_dataset_permission(dataset, user, session) # Handle external dataset updates if dataset.provider == "external": - return DatasetService._update_external_dataset(dataset, data, user) + return DatasetService._update_external_dataset(dataset, data, user, session) else: - return DatasetService._update_internal_dataset(dataset, data, user) + return DatasetService._update_internal_dataset(dataset, data, user, session) @staticmethod - def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str): - dataset = db.session.scalar( + def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str, session: scoped_session | Session): + dataset = session.scalar( select(Dataset) .where( Dataset.id != dataset_id, @@ -674,7 +684,7 @@ def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str): return dataset is not None @staticmethod - def _update_external_dataset(dataset, data, user): + def _update_external_dataset(dataset, data, user, session: scoped_session | Session): """ Update external dataset configuration. @@ -718,18 +728,22 @@ def _update_external_dataset(dataset, data, user): # Update metadata fields dataset.updated_by = user.id if user else None dataset.updated_at = naive_utc_now() - db.session.add(dataset) + session.add(dataset) # Update external knowledge binding - DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) + DatasetService._update_external_knowledge_binding( + dataset.id, external_knowledge_id, external_knowledge_api_id, session + ) # Commit changes to database - db.session.commit() + session.commit() return dataset @staticmethod - def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id): + def _update_external_knowledge_binding( + dataset_id, external_knowledge_id, external_knowledge_api_id, session: scoped_session | Session + ): """ Update external knowledge binding configuration. @@ -738,25 +752,24 @@ def _update_external_knowledge_binding(dataset_id, external_knowledge_id, extern external_knowledge_id: External knowledge identifier external_knowledge_api_id: External knowledge API identifier """ - with sessionmaker(db.engine).begin() as session: - external_knowledge_binding = session.scalar( - select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1) - ) + external_knowledge_binding = session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == dataset_id).limit(1) + ) - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") - # Update binding if values have changed - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id - external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - session.add(external_knowledge_binding) + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + session.add(external_knowledge_binding) @staticmethod - def _update_internal_dataset(dataset, data, user): + def _update_internal_dataset(dataset, data, user, session: scoped_session | Session): """ Update internal dataset configuration. @@ -778,7 +791,7 @@ def _update_internal_dataset(dataset, data, user): filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} # Handle indexing technique changes and embedding model updates - action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) + action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data, session) # Add metadata fields filtered_data["updated_by"] = user.id @@ -794,14 +807,14 @@ def _update_internal_dataset(dataset, data, user): filtered_data["icon_info"] = data.get("icon_info") # Update dataset in database - db.session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data)) - db.session.commit() + session.execute(update(Dataset).where(Dataset.id == dataset.id).values(**filtered_data)) + session.commit() # Reload dataset to get updated values - db.session.refresh(dataset) + session.refresh(dataset) # update pipeline knowledge base node data - DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id) + DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id, session) # Trigger vector index task if indexing technique changed if action: @@ -822,14 +835,16 @@ def _update_internal_dataset(dataset, data, user): return dataset @staticmethod - def _update_pipeline_knowledge_base_node_data(dataset: Dataset, updata_user_id: str): + def _update_pipeline_knowledge_base_node_data( + dataset: Dataset, updata_user_id: str, session: scoped_session | Session + ): """ Update pipeline knowledge base node data. """ if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE: return - pipeline = db.session.get(Pipeline, dataset.pipeline_id) + pipeline = session.get(Pipeline, dataset.pipeline_id) if not pipeline: return @@ -887,25 +902,25 @@ def update_knowledge_nodes(workflow_graph: str) -> str: marked_name="", marked_comment="", ) - db.session.add(workflow) + session.add(workflow) # Update draft workflow if draft_workflow: updated_graph = update_knowledge_nodes(draft_workflow.graph) if updated_graph != draft_workflow.graph: draft_workflow.graph = updated_graph - db.session.add(draft_workflow) + session.add(draft_workflow) # Commit all changes in one transaction - db.session.commit() + session.commit() except Exception: logging.exception("Failed to update pipeline knowledge base node data") - db.session.rollback() + session.rollback() raise @staticmethod - def _handle_indexing_technique_change(dataset, data, filtered_data): + def _handle_indexing_technique_change(dataset, data, filtered_data, session: scoped_session | Session): """ Handle changes in indexing technique and configure embedding models accordingly. @@ -913,6 +928,7 @@ def _handle_indexing_technique_change(dataset, data, filtered_data): dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data + session: SQLAlchemy session used for embedding collection binding lookups Returns: str: Action to perform ('add', 'remove', 'update', or None) @@ -928,21 +944,24 @@ def _handle_indexing_technique_change(dataset, data, filtered_data): return "remove" elif data["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: # Configure embedding model for high quality mode - DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) + DatasetService._configure_embedding_model_for_high_quality(data, filtered_data, session) return "add" else: # Handle embedding model updates when indexing technique remains the same - return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data) + return DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, data, filtered_data, session + ) return None @staticmethod - def _configure_embedding_model_for_high_quality(data, filtered_data): + def _configure_embedding_model_for_high_quality(data, filtered_data, session: scoped_session | Session): """ Configure embedding model settings for high quality indexing. Args: data: Update data dictionary filtered_data: Filtered update data to modify + session: SQLAlchemy session used for embedding collection binding lookups """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: @@ -961,6 +980,7 @@ def _configure_embedding_model_for_high_quality(data, filtered_data): dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model_name, + session, ) filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: @@ -971,7 +991,9 @@ def _configure_embedding_model_for_high_quality(data, filtered_data): raise ValueError(ex.description) @staticmethod - def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): + def _handle_embedding_model_update_when_technique_unchanged( + dataset, data, filtered_data, session: scoped_session | Session + ): """ Handle embedding model updates when indexing technique remains the same. @@ -979,6 +1001,7 @@ def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filte dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data to modify + session: SQLAlchemy session used for embedding collection binding lookups Returns: str: Action to perform ('update' or None) @@ -993,7 +1016,7 @@ def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filte DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) return None else: - return DatasetService._update_embedding_model_settings(dataset, data, filtered_data) + return DatasetService._update_embedding_model_settings(dataset, data, filtered_data, session) @staticmethod def _preserve_existing_embedding_settings(dataset, filtered_data): @@ -1019,7 +1042,7 @@ def _preserve_existing_embedding_settings(dataset, filtered_data): del filtered_data["embedding_model"] @staticmethod - def _update_embedding_model_settings(dataset, data, filtered_data): + def _update_embedding_model_settings(dataset, data, filtered_data, session: scoped_session | Session): """ Update embedding model settings with new values. @@ -1027,6 +1050,7 @@ def _update_embedding_model_settings(dataset, data, filtered_data): dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data to modify + session: SQLAlchemy session used for embedding collection binding lookups Returns: str: Action to perform ('update' or None) @@ -1042,7 +1066,7 @@ def _update_embedding_model_settings(dataset, data, filtered_data): # Only update if values are different if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: - DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) + DatasetService._apply_new_embedding_settings(dataset, data, filtered_data, session) return "update" except LLMBadRequestError: raise ValueError( @@ -1053,7 +1077,7 @@ def _update_embedding_model_settings(dataset, data, filtered_data): return None @staticmethod - def _apply_new_embedding_settings(dataset, data, filtered_data): + def _apply_new_embedding_settings(dataset, data, filtered_data, session: scoped_session | Session): """ Apply new embedding model settings to the dataset. @@ -1061,6 +1085,7 @@ def _apply_new_embedding_settings(dataset, data, filtered_data): dataset: Current dataset object data: Update data dictionary filtered_data: Filtered update data to modify + session: SQLAlchemy session used for embedding collection binding lookups """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None @@ -1096,6 +1121,7 @@ def _apply_new_embedding_settings(dataset, data, filtered_data): dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model_name, + session, ) filtered_data["collection_binding_id"] = dataset_collection_binding.id @@ -1177,6 +1203,7 @@ def update_rag_pipeline_dataset_settings( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model_name, + session, ) dataset.collection_binding_id = dataset_collection_binding.id elif knowledge_configuration.indexing_technique == IndexTechniqueType.ECONOMY: @@ -1213,6 +1240,7 @@ def update_rag_pipeline_dataset_settings( dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model_name, + session, ) is_multimodal = DatasetService.check_is_multimodal_model( current_user.current_tenant_id, @@ -1276,6 +1304,7 @@ def update_rag_pipeline_dataset_settings( DatasetCollectionBindingService.get_dataset_collection_binding( embedding_model.provider, embedding_model_name, + session, ) ) dataset.collection_binding_id = dataset_collection_binding.id @@ -1305,24 +1334,24 @@ def update_rag_pipeline_dataset_settings( deal_dataset_index_update_task.delay(dataset.id, action) @staticmethod - def delete_dataset(dataset_id, user): - dataset = DatasetService.get_dataset(dataset_id) + def delete_dataset(dataset_id, user, session: scoped_session | Session): + dataset = DatasetService.get_dataset(dataset_id, session) if dataset is None: return False - DatasetService.check_dataset_permission(dataset, user, db.session) + DatasetService.check_dataset_permission(dataset, user, session) dataset_was_deleted.send(dataset) - db.session.delete(dataset) - db.session.commit() + session.delete(dataset) + session.commit() return True @staticmethod - def dataset_use_check(dataset_id) -> bool: + def dataset_use_check(dataset_id, session: scoped_session | Session) -> bool: stmt = select(exists().where(AppDatasetJoin.dataset_id == dataset_id)) - return db.session.execute(stmt).scalar_one() + return session.execute(stmt).scalar_one() @staticmethod def check_dataset_permission(dataset, user, session: scoped_session | Session): @@ -1347,7 +1376,9 @@ def check_dataset_permission(dataset, user, session: scoped_session | Session): raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod - def check_dataset_operator_permission(user: Account | None = None, dataset: Dataset | None = None): + def check_dataset_operator_permission( + user: Account | None = None, dataset: Dataset | None = None, *, session: scoped_session | Session + ): if not dataset: raise ValueError("Dataset not found") @@ -1362,7 +1393,7 @@ def check_dataset_operator_permission(user: Account | None = None, dataset: Data elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( dp.dataset_id == dataset.id - for dp in db.session.scalars( + for dp in session.scalars( select(DatasetPermission).where(DatasetPermission.account_id == user.id) ).all() ): @@ -1377,16 +1408,16 @@ def get_dataset_queries(dataset_id: str, page: int, per_page: int): return dataset_queries.items, dataset_queries.total @staticmethod - def get_related_apps(dataset_id: str): - return db.session.scalars( + def get_related_apps(dataset_id: str, session: scoped_session | Session): + return session.scalars( select(AppDatasetJoin) .where(AppDatasetJoin.dataset_id == dataset_id) .order_by(AppDatasetJoin.created_at.desc()) ).all() @staticmethod - def update_dataset_api_status(dataset_id: str, status: bool): - dataset = DatasetService.get_dataset(dataset_id) + def update_dataset_api_status(dataset_id: str, status: bool, session: scoped_session | Session): + dataset = DatasetService.get_dataset(dataset_id, session) if dataset is None: raise NotFound("Dataset not found.") dataset.enable_api = status @@ -1394,10 +1425,10 @@ def update_dataset_api_status(dataset_id: str, status: bool): raise ValueError("Current user or current user id not found") dataset.updated_by = current_user.id dataset.updated_at = naive_utc_now() - db.session.commit() + session.commit() @staticmethod - def get_dataset_auto_disable_logs(dataset_id: str) -> AutoDisableLogsDict: + def get_dataset_auto_disable_logs(dataset_id: str, session: scoped_session | Session) -> AutoDisableLogsDict: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None features = FeatureService.get_features(current_user.current_tenant_id, exclude_vector_space=True) @@ -1408,7 +1439,7 @@ def get_dataset_auto_disable_logs(dataset_id: str) -> AutoDisableLogsDict: } # get recent 30 days auto disable logs start_date = datetime.datetime.now() - datetime.timedelta(days=30) - dataset_auto_disable_logs = db.session.scalars( + dataset_auto_disable_logs = session.scalars( select(DatasetAutoDisableLog).where( DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.created_at >= start_date, @@ -1596,9 +1627,12 @@ def apply_display_status_filter(cls, query, status: str | None): } @staticmethod - def get_document(dataset_id: str, document_id: str | None = None) -> Document | None: + def get_document( + dataset_id: str, document_id: str | None = None, *, session: scoped_session | Session + ) -> Document | None: + """Fetch a document by id within a dataset using the caller-provided session.""" if document_id: - document = db.session.scalar( + document = session.scalar( select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1) ) return document @@ -1606,13 +1640,15 @@ def get_document(dataset_id: str, document_id: str | None = None) -> Document | return None @staticmethod - def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]: + def get_documents_by_ids( + dataset_id: str, document_ids: Sequence[str], session: scoped_session | Session + ) -> Sequence[Document]: """Fetch documents for a dataset in a single batch query.""" if not document_ids: return [] document_id_list: list[str] = [str(document_id) for document_id in document_ids] # Fetch all requested documents in one query to avoid N+1 lookups. - documents: Sequence[Document] = db.session.scalars( + documents: Sequence[Document] = session.scalars( select(Document).where( Document.dataset_id == dataset_id, Document.id.in_(document_id_list), @@ -1621,7 +1657,12 @@ def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequen return documents @staticmethod - def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], need_summary: bool = True) -> int: + def update_documents_need_summary( + dataset_id: str, + document_ids: Sequence[str], + session: scoped_session | Session, + need_summary: bool = True, + ) -> int: """ Update need_summary field for multiple documents. @@ -1631,6 +1672,7 @@ def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], Args: dataset_id: Dataset ID document_ids: List of document IDs to update + session: SQLAlchemy session used for the update need_summary: Value to set for need_summary field (default: True) Returns: @@ -1641,33 +1683,32 @@ def update_documents_need_summary(dataset_id: str, document_ids: Sequence[str], document_id_list: list[str] = [str(document_id) for document_id in document_ids] - with session_factory.create_session() as session: - result = session.execute( - update(Document) - .where( - Document.id.in_(document_id_list), - Document.dataset_id == dataset_id, - Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents - ) - .values(need_summary=need_summary) - .execution_options(synchronize_session=False) - ) - updated_count = result.rowcount # type: ignore[union-attr,attr-defined] - session.commit() - logger.info( - "Updated need_summary to %s for %d documents in dataset %s", - need_summary, - updated_count, - dataset_id, + result = session.execute( + update(Document) + .where( + Document.id.in_(document_id_list), + Document.dataset_id == dataset_id, + Document.doc_form != IndexStructureType.QA_INDEX, # Skip qa_model documents ) - return updated_count + .values(need_summary=need_summary) + .execution_options(synchronize_session=False) + ) + updated_count = result.rowcount # type: ignore[union-attr,attr-defined] + session.commit() + logger.info( + "Updated need_summary to %s for %d documents in dataset %s", + need_summary, + updated_count, + dataset_id, + ) + return updated_count @staticmethod - def get_document_download_url(document: Document) -> str: + def get_document_download_url(document: Document, session: scoped_session | Session) -> str: """ Return a signed download URL for an upload-file document. """ - upload_file = DocumentService._get_upload_file_for_upload_file_document(document) + upload_file = DocumentService._get_upload_file_for_upload_file_document(document, session) return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True) @staticmethod @@ -1721,15 +1762,16 @@ def prepare_document_batch_download_zip( document_ids: Sequence[str], tenant_id: str, current_user: Account, + session: scoped_session | Session, ) -> tuple[list[UploadFile], str]: """ Resolve upload files for batch ZIP downloads and generate a client-visible filename. """ - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, session) if not dataset: raise NotFound("Dataset not found.") try: - DatasetService.check_dataset_permission(dataset, current_user, db.session) + DatasetService.check_dataset_permission(dataset, current_user, session) except NoPermissionError as e: raise Forbidden(str(e)) @@ -1737,6 +1779,7 @@ def prepare_document_batch_download_zip( dataset_id=dataset_id, document_ids=document_ids, tenant_id=tenant_id, + session=session, ) upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids] download_name = DocumentService._generate_document_batch_download_zip_filename() @@ -1770,7 +1813,7 @@ def _get_upload_file_id_for_upload_file_document( return str(upload_file_id) @staticmethod - def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile: + def _get_upload_file_for_upload_file_document(document: Document, session: scoped_session | Session) -> UploadFile: """ Load the `UploadFile` row for an upload-file document. """ @@ -1779,7 +1822,9 @@ def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile: invalid_source_message="Document does not have an uploaded file to download.", missing_file_message="Uploaded file not found.", ) - upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), document.tenant_id, [upload_file_id]) + upload_files_by_id = FileService.get_upload_files_by_ids( + _session_for_helpers(session), document.tenant_id, [upload_file_id] + ) upload_file = upload_files_by_id.get(upload_file_id) if not upload_file: raise NotFound("Uploaded file not found.") @@ -1791,13 +1836,14 @@ def _get_upload_files_by_document_id_for_zip_download( dataset_id: str, document_ids: Sequence[str], tenant_id: str, + session: scoped_session | Session, ) -> dict[str, UploadFile]: """ Batch load upload files keyed by document id for ZIP downloads. """ document_id_list: list[str] = [str(document_id) for document_id in document_ids] - documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list) + documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list, session) documents_by_id: dict[str, Document] = {str(document.id): document for document in documents} missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys()) @@ -1818,7 +1864,9 @@ def _get_upload_files_by_document_id_for_zip_download( upload_file_ids.append(upload_file_id) upload_file_ids_by_document_id[document_id] = upload_file_id - upload_files_by_id = FileService.get_upload_files_by_ids(db.session(), tenant_id, upload_file_ids) + upload_files_by_id = FileService.get_upload_files_by_ids( + _session_for_helpers(session), tenant_id, upload_file_ids + ) missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys()) if missing_upload_file_ids: raise NotFound("Only uploaded-file documents can be downloaded as ZIP.") @@ -1829,14 +1877,14 @@ def _get_upload_files_by_document_id_for_zip_download( } @staticmethod - def get_document_by_id(document_id: str) -> Document | None: - document = db.session.get(Document, document_id) + def get_document_by_id(document_id: str, session: scoped_session | Session) -> Document | None: + document = session.get(Document, document_id) return document @staticmethod - def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]: - documents = db.session.scalars( + def get_document_by_ids(document_ids: list[str], session: scoped_session | Session) -> Sequence[Document]: + documents = session.scalars( select(Document).where( Document.id.in_(document_ids), Document.enabled == True, @@ -1847,8 +1895,8 @@ def get_document_by_ids(document_ids: list[str]) -> Sequence[Document]: return documents @staticmethod - def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]: - documents = db.session.scalars( + def get_document_by_dataset_id(dataset_id: str, session: scoped_session | Session) -> Sequence[Document]: + documents = session.scalars( select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, @@ -1858,8 +1906,8 @@ def get_document_by_dataset_id(dataset_id: str) -> Sequence[Document]: return documents @staticmethod - def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: - documents = db.session.scalars( + def get_working_documents_by_dataset_id(dataset_id: str, session: scoped_session | Session) -> Sequence[Document]: + documents = session.scalars( select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, @@ -1871,8 +1919,8 @@ def get_working_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: return documents @staticmethod - def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: - documents = db.session.scalars( + def get_error_documents_by_dataset_id(dataset_id: str, session: scoped_session | Session) -> Sequence[Document]: + documents = session.scalars( select(Document).where( Document.dataset_id == dataset_id, Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]), @@ -1881,9 +1929,9 @@ def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: return documents @staticmethod - def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]: + def get_batch_documents(dataset_id: str, batch: str, session: scoped_session | Session) -> Sequence[Document]: assert isinstance(current_user, Account) - documents = db.session.scalars( + documents = session.scalars( select(Document).where( Document.batch == batch, Document.dataset_id == dataset_id, @@ -1894,8 +1942,8 @@ def get_batch_documents(dataset_id: str, batch: str) -> Sequence[Document]: return documents @staticmethod - def get_document_file_detail(file_id: str): - file_detail = db.session.get(UploadFile, file_id) + def get_document_file_detail(file_id: str, session: scoped_session | Session): + file_detail = session.get(UploadFile, file_id) return file_detail @staticmethod @@ -1906,7 +1954,7 @@ def check_archived(document): return False @staticmethod - def delete_document(document): + def delete_document(document, session: scoped_session | Session): # trigger document_was_deleted signal file_id = None if document.data_source_type == DataSourceType.UPLOAD_FILE: @@ -1918,15 +1966,15 @@ def delete_document(document): document.id, dataset_id=document.dataset_id, doc_form=document.doc_form, file_id=file_id ) - db.session.delete(document) - db.session.commit() + session.delete(document) + session.commit() @staticmethod - def delete_documents(dataset: Dataset, document_ids: list[str]): + def delete_documents(dataset: Dataset, document_ids: list[str], session: scoped_session | Session): # Check if document_ids is not empty to avoid WHERE false condition if not document_ids or len(document_ids) == 0: return - documents = db.session.scalars(select(Document).where(Document.id.in_(document_ids))).all() + documents = session.scalars(select(Document).where(Document.id.in_(document_ids))).all() file_ids = [ document.data_source_info_dict.get("upload_file_id", "") for document in documents @@ -1936,8 +1984,8 @@ def delete_documents(dataset: Dataset, document_ids: list[str]): # Delete documents first, then dispatch cleanup task after commit # to avoid deadlock between main transaction and async task for document in documents: - db.session.delete(document) - db.session.commit() + session.delete(document) + session.commit() # Dispatch cleanup task after commit to avoid lock contention # Task cleans up segments, files, and vector indexes @@ -1945,14 +1993,14 @@ def delete_documents(dataset: Dataset, document_ids: list[str]): batch_clean_document_task.delay(document_ids, dataset.id, dataset.doc_form, file_ids) @staticmethod - def rename_document(dataset_id: str, document_id: str, name: str) -> Document: + def rename_document(dataset_id: str, document_id: str, name: str, session: scoped_session | Session) -> Document: assert isinstance(current_user, Account) - dataset = DatasetService.get_dataset(dataset_id) + dataset = DatasetService.get_dataset(dataset_id, session) if not dataset: raise ValueError("Dataset not found.") - document = DocumentService.get_document(dataset_id, document_id) + document = DocumentService.get_document(dataset_id, document_id, session=session) if not document: raise ValueError("Document not found.") @@ -1967,20 +2015,20 @@ def rename_document(dataset_id: str, document_id: str, name: str) -> Document: document.doc_metadata = doc_metadata document.name = name - db.session.add(document) + session.add(document) if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict: - db.session.execute( + session.execute( update(UploadFile) .where(UploadFile.id == document.data_source_info_dict["upload_file_id"]) .values(name=name) ) - db.session.commit() + session.commit() return document @staticmethod - def pause_document(document): + def pause_document(document, session: scoped_session | Session): if document.indexing_status not in { IndexingStatus.WAITING, IndexingStatus.PARSING, @@ -1995,14 +2043,14 @@ def pause_document(document): document.paused_by = current_user.id document.paused_at = naive_utc_now() - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() # set document paused flag indexing_cache_key = f"document_{document.id}_is_paused" redis_client.setnx(indexing_cache_key, "True") @staticmethod - def recover_document(document): + def recover_document(document, session: scoped_session | Session): if not document.is_paused: raise DocumentIndexingError() # update document to be recover @@ -2010,8 +2058,8 @@ def recover_document(document): document.paused_by = None document.paused_at = None - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() # delete paused flag indexing_cache_key = f"document_{document.id}_is_paused" redis_client.delete(indexing_cache_key) @@ -2019,7 +2067,7 @@ def recover_document(document): recover_document_indexing_task.delay(document.dataset_id, document.id) @staticmethod - def retry_document(dataset_id: str, documents: list[Document]): + def retry_document(dataset_id: str, documents: list[Document], session: scoped_session | Session): for document in documents: # add retry flag retry_indexing_cache_key = f"document_{document.id}_is_retried" @@ -2028,8 +2076,8 @@ def retry_document(dataset_id: str, documents: list[Document]): raise ValueError("Document is being retried, please try again later") # retry document indexing document.indexing_status = IndexingStatus.WAITING - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() redis_client.setex(retry_indexing_cache_key, 600, 1) # trigger async task @@ -2039,7 +2087,7 @@ def retry_document(dataset_id: str, documents: list[Document]): retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id) @staticmethod - def sync_website_document(dataset_id: str, document: Document): + def sync_website_document(dataset_id: str, document: Document, session: scoped_session | Session): # add sync flag sync_indexing_cache_key = f"document_{document.id}_is_sync" cache_result = redis_client.get(sync_indexing_cache_key) @@ -2051,16 +2099,16 @@ def sync_website_document(dataset_id: str, document: Document): if data_source_info: data_source_info["mode"] = "scrape" document.data_source_info = json.dumps(data_source_info, ensure_ascii=False) - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() redis_client.setex(sync_indexing_cache_key, 600, 1) sync_website_document_indexing_task.delay(dataset_id, document.id) @staticmethod - def get_documents_position(dataset_id): - document = db.session.scalar( + def get_documents_position(dataset_id, session: scoped_session | Session): + document = session.scalar( select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1) ) if document: @@ -2075,6 +2123,8 @@ def save_document_with_dataset_id( account: Account | Any, dataset_process_rule: DatasetProcessRule | None = None, created_from: str = DocumentCreatedFrom.WEB, + *, + session: scoped_session | Session, ) -> tuple[list[Document], str]: # check doc_form DatasetService.check_doc_form(dataset, knowledge_config.doc_form) @@ -2126,7 +2176,7 @@ def save_document_with_dataset_id( dataset.embedding_model = dataset_embedding_model dataset.embedding_model_provider = dataset_embedding_model_provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - dataset_embedding_model_provider, dataset_embedding_model + dataset_embedding_model_provider, dataset_embedding_model, session ) dataset.collection_binding_id = dataset_collection_binding.id if not dataset.retrieval_model: @@ -2146,7 +2196,9 @@ def save_document_with_dataset_id( documents = [] if knowledge_config.original_document_id: - document = DocumentService.update_document_with_dataset_id(dataset, knowledge_config, account) + document = DocumentService.update_document_with_dataset_id( + dataset, knowledge_config, account, session=session + ) documents.append(document) batch = document.batch else: @@ -2184,8 +2236,8 @@ def save_document_with_dataset_id( process_rule.mode, ) return [], "" - db.session.add(dataset_process_rule) - db.session.flush() + session.add(dataset_process_rule) + session.flush() else: # Fallback when no process_rule provided in knowledge_config: # 1) reuse dataset.latest_process_rule if present @@ -2198,13 +2250,13 @@ def save_document_with_dataset_id( rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) - db.session.add(dataset_process_rule) - db.session.flush() + session.add(dataset_process_rule) + session.flush() lock_name = f"add_document_lock_dataset_id_{dataset.id}" try: with redis_client.lock(lock_name, timeout=600): assert dataset_process_rule - position = DocumentService.get_documents_position(dataset.id) + position = DocumentService.get_documents_position(dataset.id, session) document_ids = [] duplicate_document_ids = [] if knowledge_config.data_source.info_list.data_source_type == "upload_file": @@ -2212,7 +2264,7 @@ def save_document_with_dataset_id( raise ValueError("File source info is required") upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids files = list( - db.session.scalars( + session.scalars( select(UploadFile).where( UploadFile.tenant_id == dataset.tenant_id, UploadFile.id.in_(upload_file_list), @@ -2224,7 +2276,7 @@ def save_document_with_dataset_id( file_names = [file.name for file in files] db_documents = list( - db.session.scalars( + session.scalars( select(Document).where( Document.dataset_id == dataset.id, Document.tenant_id == current_user.current_tenant_id, @@ -2249,7 +2301,7 @@ def save_document_with_dataset_id( document.data_source_info = json.dumps(data_source_info) document.batch = batch document.indexing_status = IndexingStatus.WAITING - db.session.add(document) + session.add(document) documents.append(document) duplicate_document_ids.append(document.id) continue @@ -2267,8 +2319,8 @@ def save_document_with_dataset_id( file.name, batch, ) - db.session.add(document) - db.session.flush() + session.add(document) + session.flush() document_ids.append(document.id) documents.append(document) position += 1 @@ -2279,7 +2331,7 @@ def save_document_with_dataset_id( exist_page_ids = [] exist_document = {} documents = list( - db.session.scalars( + session.scalars( select(Document).where( Document.dataset_id == dataset.id, Document.tenant_id == current_user.current_tenant_id, @@ -2319,8 +2371,8 @@ def save_document_with_dataset_id( truncated_page_name, batch, ) - db.session.add(document) - db.session.flush() + session.add(document) + session.flush() document_ids.append(document.id) documents.append(document) position += 1 @@ -2359,12 +2411,12 @@ def save_document_with_dataset_id( document_name, batch, ) - db.session.add(document) - db.session.flush() + session.add(document) + session.flush() document_ids.append(document.id) documents.append(document) position += 1 - db.session.commit() + session.commit() # trigger async task if document_ids: @@ -2486,8 +2538,8 @@ def save_document_with_dataset_id( # f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" # ) # return - # db.session.add(dataset_process_rule) - # db.session.commit() + # session.add(dataset_process_rule) + # session.commit() # lock_name = "add_document_lock_dataset_id_{}".format(dataset.id) # with redis_client.lock(lock_name, timeout=600): # position = DocumentService.get_documents_position(dataset.id) @@ -2497,7 +2549,7 @@ def save_document_with_dataset_id( # upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # for file_id in upload_file_list: # file = ( - # db.session.query(UploadFile) + # session.query(UploadFile) # .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) # .first() # ) @@ -2528,7 +2580,7 @@ def save_document_with_dataset_id( # document.data_source_info = json.dumps(data_source_info) # document.batch = batch # document.indexing_status = "waiting" - # db.session.add(document) + # session.add(document) # documents.append(document) # duplicate_document_ids.append(document.id) # continue @@ -2545,8 +2597,8 @@ def save_document_with_dataset_id( # file_name, # batch, # ) - # db.session.add(document) - # db.session.flush() + # session.add(document) + # session.flush() # document_ids.append(document.id) # documents.append(document) # position += 1 @@ -2602,8 +2654,8 @@ def save_document_with_dataset_id( # truncated_page_name, # batch, # ) - # db.session.add(document) - # db.session.flush() + # session.add(document) + # session.flush() # document_ids.append(document.id) # documents.append(document) # position += 1 @@ -2642,12 +2694,12 @@ def save_document_with_dataset_id( # document_name, # batch, # ) - # db.session.add(document) - # db.session.flush() + # session.add(document) + # session.flush() # document_ids.append(document.id) # documents.append(document) # position += 1 - # db.session.commit() + # session.commit() # # trigger async task # if document_ids: @@ -2728,11 +2780,11 @@ def build_document( return document @staticmethod - def get_tenant_documents_count(): + def get_tenant_documents_count(session: scoped_session | Session): assert isinstance(current_user, Account) documents_count = ( - db.session.scalar( + session.scalar( select(func.count(Document.id)).where( Document.completed_at.isnot(None), Document.enabled == True, @@ -2751,11 +2803,13 @@ def update_document_with_dataset_id( account: Account, dataset_process_rule: DatasetProcessRule | None = None, created_from: str = DocumentCreatedFrom.WEB, + *, + session: scoped_session | Session, ): assert isinstance(current_user, Account) DatasetService.check_dataset_model_setting(dataset) - document = DocumentService.get_document(dataset.id, document_data.original_document_id) + document = DocumentService.get_document(dataset.id, document_data.original_document_id, session=session) if document is None: raise NotFound("Document not found") if document.display_status != "available": @@ -2778,8 +2832,8 @@ def update_document_with_dataset_id( created_by=account.id, ) if dataset_process_rule is not None: - db.session.add(dataset_process_rule) - db.session.commit() + session.add(dataset_process_rule) + session.commit() document.dataset_process_rule_id = dataset_process_rule.id # update document data source if document_data.data_source: @@ -2790,7 +2844,7 @@ def update_document_with_dataset_id( raise ValueError("No file info list found.") upload_file_list = document_data.data_source.info_list.file_info_list.file_ids for file_id in upload_file_list: - file = db.session.scalar( + file = session.scalar( select(UploadFile) .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id) .limit(1) @@ -2810,7 +2864,7 @@ def update_document_with_dataset_id( notion_info_list = document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: workspace_id = notion_info.workspace_id - data_source_binding = db.session.scalar( + data_source_binding = session.scalar( select(DataSourceOauthBinding) .where( sa.and_( @@ -2861,22 +2915,24 @@ def update_document_with_dataset_id( document.updated_at = naive_utc_now() document.created_from = created_from document.doc_form = IndexStructureType(document_data.doc_form) - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() # update document segment - db.session.execute( + session.execute( update(DocumentSegment) .where(DocumentSegment.document_id == document.id) .values(status=SegmentStatus.RE_SEGMENT) ) - db.session.commit() + session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) return document @staticmethod - def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account): + def save_document_without_dataset_id( + tenant_id: str, knowledge_config: KnowledgeConfig, account: Account, session: scoped_session | Session + ): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None assert knowledge_config.data_source @@ -2911,6 +2967,7 @@ def save_document_without_dataset_id(tenant_id: str, knowledge_config: Knowledge dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( knowledge_config.embedding_model_provider, knowledge_config.embedding_model, + session, ) dataset_collection_binding_id = dataset_collection_binding.id if knowledge_config.retrieval_model: @@ -2939,16 +2996,18 @@ def save_document_without_dataset_id(tenant_id: str, knowledge_config: Knowledge is_multimodal=knowledge_config.is_multimodal, ) - db.session.add(dataset) - db.session.flush() + session.add(dataset) + session.flush() - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account) + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, knowledge_config, account, session=session + ) cut_length = 18 cut_name = documents[0].name[:cut_length] dataset.name = cut_name + "..." dataset.description = "useful for when you want to answer queries about the " + documents[0].name - db.session.commit() + session.commit() return dataset, documents, batch @@ -3057,7 +3116,11 @@ def estimate_args_validate(cls, args: dict[str, Any]): @staticmethod def batch_update_document_status( - dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user + dataset: Dataset, + document_ids: list[str], + action: Literal["enable", "disable", "archive", "un_archive"], + user, + session: scoped_session | Session, ): """ Batch update document status. @@ -3084,7 +3147,7 @@ def batch_update_document_status( # First pass: validate all documents and prepare updates for document_id in document_ids: - document = DocumentService.get_document(dataset.id, document_id) + document = DocumentService.get_document(dataset.id, document_id, session=session) if not document: continue @@ -3110,13 +3173,13 @@ def batch_update_document_status( for field, value in updates.items(): setattr(document, field, value) - db.session.add(document) + session.add(document) # Batch commit all changes - db.session.commit() + session.commit() except Exception as e: # Rollback on any error - db.session.rollback() + session.rollback() raise e # Execute async tasks and set Redis cache after successful commit # propagation_error is used to capture any errors for submitting async task execution @@ -3264,7 +3327,9 @@ def segment_create_args_validate(cls, args: dict[str, Any], document: Document): raise ValueError(f"Exceeded maximum attachment limit of {single_chunk_attachment_limit}") @classmethod - def create_segment(cls, args: dict[str, Any], document: Document, dataset: Dataset): + def create_segment( + cls, args: dict[str, Any], document: Document, dataset: Dataset, session: scoped_session | Session + ): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -3285,7 +3350,7 @@ def create_segment(cls, args: dict[str, Any], document: Document, dataset: Datas lock_name = f"add_segment_lock_document_id_{document.id}" try: with redis_client.lock(lock_name, timeout=600): - max_position = db.session.scalar( + max_position = session.scalar( select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id) ) segment_document = DocumentSegment( @@ -3307,12 +3372,12 @@ def create_segment(cls, args: dict[str, Any], document: Document, dataset: Datas segment_document.word_count += len(args["answer"]) segment_document.answer = args["answer"] - db.session.add(segment_document) + session.add(segment_document) # update document word count assert document.word_count is not None document.word_count += segment_document.word_count - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() if args["attachment_ids"]: for attachment_id in args["attachment_ids"]: @@ -3323,8 +3388,8 @@ def create_segment(cls, args: dict[str, Any], document: Document, dataset: Datas segment_id=segment_document.id, attachment_id=attachment_id, ) - db.session.add(binding) - db.session.commit() + session.add(binding) + session.commit() # save vector index try: @@ -3337,14 +3402,16 @@ def create_segment(cls, args: dict[str, Any], document: Document, dataset: Datas segment_document.disabled_at = naive_utc_now() segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) - db.session.commit() - segment = db.session.get(DocumentSegment, segment_document.id) + session.commit() + segment = session.get(DocumentSegment, segment_document.id) return segment except LockNotOwnedError: pass @classmethod - def multi_create_segment(cls, segments: list, document: Document, dataset: Dataset): + def multi_create_segment( + cls, segments: list, document: Document, dataset: Dataset, session: scoped_session | Session + ): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -3361,7 +3428,7 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - max_position = db.session.scalar( + max_position = session.scalar( select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document.id) ) pre_segment_data_list = [] @@ -3402,7 +3469,7 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas segment_document.answer = segment_item["answer"] segment_document.word_count += len(segment_item["answer"]) increment_word_count += segment_document.word_count - db.session.add(segment_document) + session.add(segment_document) segment_data_list.append(segment_document) position += 1 @@ -3414,7 +3481,7 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas # update document word count assert document.word_count is not None document.word_count += increment_word_count - db.session.add(document) + session.add(document) try: # save vector index VectorService.create_segments_vector( @@ -3427,13 +3494,20 @@ def multi_create_segment(cls, segments: list, document: Document, dataset: Datas segment_document.disabled_at = naive_utc_now() segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) - db.session.commit() + session.commit() return segment_data_list except LockNotOwnedError: pass @classmethod - def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, document: Document, dataset: Dataset): + def update_segment( + cls, + args: SegmentUpdateArgs, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + session: scoped_session | Session, + ): assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -3448,8 +3522,8 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum segment.enabled = action segment.disabled_at = naive_utc_now() segment.disabled_by = current_user.id - db.session.add(segment) - db.session.commit() + session.add(segment) + session.commit() # Set cache to prevent indexing the same segment multiple times redis_client.setex(indexing_cache_key, 600, 1) disable_segment_from_index_task.delay(segment.id) @@ -3477,13 +3551,13 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum segment.enabled = True segment.disabled_at = None segment.disabled_by = None - db.session.add(segment) - db.session.commit() + session.add(segment) + session.commit() # update document word count if word_count_change != 0: assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) - db.session.add(document) + session.add(document) # update segment index task if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # regenerate child chunks @@ -3507,7 +3581,7 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum else: raise ValueError("The knowledge base index technique is not high quality!") # get the process rule - processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id) + processing_rule = session.get(DatasetProcessRule, document.dataset_process_rule_id) if processing_rule: VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True @@ -3525,7 +3599,7 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum # Query existing summary from database from models.dataset import DocumentSegmentSummary - existing_summary = db.session.scalar( + existing_summary = session.scalar( select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, @@ -3583,9 +3657,9 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum if word_count_change != 0: assert document.word_count is not None document.word_count = max(0, document.word_count + word_count_change) - db.session.add(document) - db.session.add(segment) - db.session.commit() + session.add(document) + session.add(segment) + session.commit() if document.doc_form == IndexStructureType.PARENT_CHILD_INDEX and args.regenerate_child_chunks: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: @@ -3607,7 +3681,7 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum else: raise ValueError("The knowledge base index technique is not high quality!") # get the process rule - processing_rule = db.session.get(DatasetProcessRule, document.dataset_process_rule_id) + processing_rule = session.get(DatasetProcessRule, document.dataset_process_rule_id) if processing_rule: VectorService.generate_child_chunks( segment, document, dataset, embedding_model_instance, processing_rule, True @@ -3619,7 +3693,7 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: from models.dataset import DocumentSegmentSummary - existing_summary = db.session.scalar( + existing_summary = session.scalar( select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, @@ -3690,14 +3764,16 @@ def update_segment(cls, args: SegmentUpdateArgs, segment: DocumentSegment, docum segment.disabled_at = naive_utc_now() segment.status = SegmentStatus.ERROR segment.error = str(e) - db.session.commit() - new_segment = db.session.get(DocumentSegment, segment.id) + session.commit() + new_segment = session.get(DocumentSegment, segment.id) if not new_segment: raise ValueError("new_segment is not found") return new_segment @classmethod - def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: Dataset): + def delete_segment( + cls, segment: DocumentSegment, document: Document, dataset: Dataset, session: scoped_session | Session + ): indexing_cache_key = f"segment_{segment.id}_delete_indexing" cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: @@ -3712,7 +3788,7 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D child_node_ids = [] if segment.index_node_id: child_node_ids = list( - db.session.scalars( + session.scalars( select(ChildChunk.index_node_id).where( ChildChunk.segment_id == segment.id, ChildChunk.dataset_id == dataset.id, @@ -3724,20 +3800,22 @@ def delete_segment(cls, segment: DocumentSegment, document: Document, dataset: D [segment.index_node_id], dataset.id, document.id, [segment.id], child_node_ids ) - db.session.delete(segment) + session.delete(segment) # update document word count assert document.word_count is not None document.word_count -= segment.word_count - db.session.add(document) - db.session.commit() + session.add(document) + session.commit() @classmethod - def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): + def delete_segments( + cls, segment_ids: list, document: Document, dataset: Dataset, session: scoped_session | Session + ): assert current_user is not None # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - segments_info = db.session.execute( + segments_info = session.execute( select(DocumentSegment.index_node_id, DocumentSegment.id, DocumentSegment.word_count).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -3758,7 +3836,7 @@ def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset if index_node_ids: child_node_ids = [ nid - for nid in db.session.scalars( + for nid in session.scalars( select(ChildChunk.index_node_id).where( ChildChunk.segment_id.in_(segment_db_ids), ChildChunk.dataset_id == dataset.id, @@ -3778,15 +3856,20 @@ def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset else: document.word_count = max(0, document.word_count - total_words) - db.session.add(document) + session.add(document) # Delete database records - db.session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))) - db.session.commit() + session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))) + session.commit() @classmethod def update_segments_status( - cls, segment_ids: list, action: Literal["enable", "disable"], dataset: Dataset, document: Document + cls, + segment_ids: list, + action: Literal["enable", "disable"], + dataset: Dataset, + document: Document, + session: scoped_session | Session, ): assert current_user is not None @@ -3795,7 +3878,7 @@ def update_segments_status( return match action: case "enable": - segments = db.session.scalars( + segments = session.scalars( select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -3814,13 +3897,13 @@ def update_segments_status( segment.enabled = True segment.disabled_at = None segment.disabled_by = None - db.session.add(segment) + session.add(segment) real_deal_segment_ids.append(segment.id) - db.session.commit() + session.commit() enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) case "disable": - segments = db.session.scalars( + segments = session.scalars( select(DocumentSegment).where( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -3839,15 +3922,20 @@ def update_segments_status( segment.enabled = False segment.disabled_at = naive_utc_now() segment.disabled_by = current_user.id - db.session.add(segment) + session.add(segment) real_deal_segment_ids.append(segment.id) - db.session.commit() + session.commit() disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk( - cls, content: str, segment: DocumentSegment, document: Document, dataset: Dataset + cls, + content: str, + segment: DocumentSegment, + document: Document, + dataset: Dataset, + session: scoped_session | Session, ) -> ChildChunk: assert isinstance(current_user, Account) @@ -3855,7 +3943,7 @@ def create_child_chunk( with redis_client.lock(lock_name, timeout=20): index_node_id = str(uuid.uuid4()) index_node_hash = helper.generate_text_hash(content) - max_position = db.session.scalar( + max_position = session.scalar( select(func.max(ChildChunk.position)).where( ChildChunk.tenant_id == current_user.current_tenant_id, ChildChunk.dataset_id == dataset.id, @@ -3877,15 +3965,15 @@ def create_child_chunk( type=SegmentType.CUSTOMIZED, created_by=current_user.id, ) - db.session.add(child_chunk) + session.add(child_chunk) # save vector index try: VectorService.create_child_chunk_vector(child_chunk, dataset) except Exception as e: logger.exception("create child chunk index failed") - db.session.rollback() + session.rollback() raise ChildChunkIndexingError(str(e)) - db.session.commit() + session.commit() return child_chunk @@ -3896,9 +3984,10 @@ def update_child_chunks( segment: DocumentSegment, document: Document, dataset: Dataset, + session: scoped_session | Session, ) -> list[ChildChunk]: assert isinstance(current_user, Account) - child_chunks = db.session.scalars( + child_chunks = session.scalars( select(ChildChunk).where( ChildChunk.dataset_id == dataset.id, ChildChunk.document_id == document.id, @@ -3926,11 +4015,11 @@ def update_child_chunks( delete_child_chunks = list(child_chunks_map.values()) try: if update_child_chunks: - db.session.bulk_save_objects(update_child_chunks) + session.bulk_save_objects(update_child_chunks) if delete_child_chunks: for child_chunk in delete_child_chunks: - db.session.delete(child_chunk) + session.delete(child_chunk) if new_child_chunks_args: child_chunk_count = len(child_chunks) for position, args in enumerate(new_child_chunks_args, start=child_chunk_count + 1): @@ -3951,14 +4040,14 @@ def update_child_chunks( created_by=current_user.id, ) - db.session.add(child_chunk) - db.session.flush() + session.add(child_chunk) + session.flush() new_child_chunks.append(child_chunk) VectorService.update_child_chunk_vector(new_child_chunks, update_child_chunks, delete_child_chunks, dataset) - db.session.commit() + session.commit() except Exception as e: logger.exception("update child chunk index failed") - db.session.rollback() + session.rollback() raise ChildChunkIndexingError(str(e)) return sorted(new_child_chunks + update_child_chunks, key=lambda x: x.position) @@ -3970,6 +4059,7 @@ def update_child_chunk( segment: DocumentSegment, document: Document, dataset: Dataset, + session: scoped_session | Session, ) -> ChildChunk: assert current_user is not None @@ -3979,25 +4069,25 @@ def update_child_chunk( child_chunk.updated_by = current_user.id child_chunk.updated_at = naive_utc_now() child_chunk.type = SegmentType.CUSTOMIZED - db.session.add(child_chunk) + session.add(child_chunk) VectorService.update_child_chunk_vector([], [child_chunk], [], dataset) - db.session.commit() + session.commit() except Exception as e: logger.exception("update child chunk index failed") - db.session.rollback() + session.rollback() raise ChildChunkIndexingError(str(e)) return child_chunk @classmethod - def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset): - db.session.delete(child_chunk) + def delete_child_chunk(cls, child_chunk: ChildChunk, dataset: Dataset, session: scoped_session | Session): + session.delete(child_chunk) try: VectorService.delete_child_chunk_vector(child_chunk, dataset) except Exception as e: logger.exception("delete child chunk index failed") - db.session.rollback() + session.rollback() raise ChildChunkDeleteIndexError(str(e)) - db.session.commit() + session.commit() @classmethod def get_child_chunks( @@ -4021,9 +4111,11 @@ def get_child_chunks( return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod - def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> ChildChunk | None: + def get_child_chunk_by_id( + cls, child_chunk_id: str, tenant_id: str, session: scoped_session | Session + ) -> ChildChunk | None: """Get a child chunk by its ID.""" - result = db.session.scalar( + result = session.scalar( select(ChildChunk).where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).limit(1) ) return result if isinstance(result, ChildChunk) else None @@ -4057,9 +4149,11 @@ def get_segments( return paginated_segments.items, paginated_segments.total @classmethod - def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> DocumentSegment | None: + def get_segment_by_id( + cls, segment_id: str, tenant_id: str, session: scoped_session | Session + ) -> DocumentSegment | None: """Get a segment by its ID.""" - result = db.session.scalar( + result = session.scalar( select(DocumentSegment) .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) .limit(1) @@ -4071,6 +4165,7 @@ def get_segments_by_document_and_dataset( cls, document_id: str, dataset_id: str, + session: scoped_session | Session, status: str | None = None, enabled: bool | None = None, ) -> Sequence[DocumentSegment]: @@ -4097,15 +4192,15 @@ def get_segments_by_document_and_dataset( if enabled is not None: query = query.where(DocumentSegment.enabled == enabled) - return db.session.scalars(query).all() + return session.scalars(query).all() class DatasetCollectionBindingService: @classmethod def get_dataset_collection_binding( - cls, provider_name: str, model_name: str, collection_type: str = "dataset" + cls, provider_name: str, model_name: str, session: scoped_session | Session, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.scalar( + dataset_collection_binding = session.scalar( select(DatasetCollectionBinding) .where( DatasetCollectionBinding.provider_name == provider_name, @@ -4123,15 +4218,15 @@ def get_dataset_collection_binding( collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), type=collection_type, ) - db.session.add(dataset_collection_binding) - db.session.commit() + session.add(dataset_collection_binding) + session.commit() return dataset_collection_binding @classmethod def get_dataset_collection_binding_by_id_and_type( - cls, collection_binding_id: str, collection_type: str = "dataset" + cls, collection_binding_id: str, session: scoped_session | Session, collection_type: str = "dataset" ) -> DatasetCollectionBinding: - dataset_collection_binding = db.session.scalar( + dataset_collection_binding = session.scalar( select(DatasetCollectionBinding) .where( DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type @@ -4147,8 +4242,8 @@ def get_dataset_collection_binding_by_id_and_type( class DatasetPermissionService: @classmethod - def get_dataset_partial_member_list(cls, dataset_id): - user_list_query = db.session.scalars( + def get_dataset_partial_member_list(cls, dataset_id, session: scoped_session | Session): + user_list_query = session.scalars( select( DatasetPermission.account_id, ).where(DatasetPermission.dataset_id == dataset_id) @@ -4157,9 +4252,9 @@ def get_dataset_partial_member_list(cls, dataset_id): return user_list_query @classmethod - def update_partial_member_list(cls, tenant_id, dataset_id, user_list): + def update_partial_member_list(cls, tenant_id, dataset_id, user_list, session: scoped_session | Session): try: - db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) + session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) permissions = [] for user in user_list: permission = DatasetPermission( @@ -4169,14 +4264,16 @@ def update_partial_member_list(cls, tenant_id, dataset_id, user_list): ) permissions.append(permission) - db.session.add_all(permissions) - db.session.commit() + session.add_all(permissions) + session.commit() except Exception as e: - db.session.rollback() + session.rollback() raise e @classmethod - def check_permission(cls, user, dataset, requested_permission, requested_partial_member_list): + def check_permission( + cls, user, dataset, requested_permission, requested_partial_member_list, session: scoped_session | Session + ): if not user.is_dataset_editor: raise NoPermissionError("User does not have permission to edit this dataset.") @@ -4187,16 +4284,16 @@ def check_permission(cls, user, dataset, requested_permission, requested_partial if not requested_partial_member_list: raise ValueError("Partial member list is required when setting to partial members.") - local_member_list = cls.get_dataset_partial_member_list(dataset.id) + local_member_list = cls.get_dataset_partial_member_list(dataset.id, session) request_member_list = [user["user_id"] for user in requested_partial_member_list] if set(local_member_list) != set(request_member_list): raise ValueError("Dataset operators cannot change the dataset permissions.") @classmethod - def clear_partial_member_list(cls, dataset_id): + def clear_partial_member_list(cls, dataset_id, session: scoped_session | Session): try: - db.session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) - db.session.commit() + session.execute(delete(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id)) + session.commit() except Exception as e: - db.session.rollback() + session.rollback() raise e diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index d9cd65b2b39df4..4e83858ea0e366 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -107,7 +107,7 @@ def update_metadata_name( ).all() if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] - documents = DocumentService.get_document_by_ids(document_ids) + documents = DocumentService.get_document_by_ids(document_ids, session) for document in documents: if not document.doc_metadata: doc_metadata = {} @@ -145,7 +145,7 @@ def delete_metadata(session: Session, dataset_id: str, metadata_id: str): ).all() if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] - documents = DocumentService.get_document_by_ids(document_ids) + documents = DocumentService.get_document_by_ids(document_ids, session) for document in documents: if not document.doc_metadata: doc_metadata = {} @@ -179,7 +179,7 @@ def enable_built_in_field(session: Session, dataset: Dataset): try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) session.add(dataset) - documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) + documents = DocumentService.get_working_documents_by_dataset_id(dataset.id, session) if documents: for document in documents: if not document.doc_metadata: @@ -208,7 +208,7 @@ def disable_built_in_field(session: Session, dataset: Dataset): try: MetadataService.knowledge_base_metadata_lock_check(dataset.id, None) session.add(dataset) - documents = DocumentService.get_working_documents_by_dataset_id(dataset.id) + documents = DocumentService.get_working_documents_by_dataset_id(dataset.id, session) document_ids = [] if documents: for document in documents: @@ -246,7 +246,7 @@ def update_documents_metadata( lock_key = f"document_metadata_lock_{operation.document_id}" try: MetadataService.knowledge_base_metadata_lock_check(None, operation.document_id) - document = DocumentService.get_document(dataset.id, operation.document_id) + document = DocumentService.get_document(dataset.id, operation.document_id, session=session) if document is None: raise ValueError("Document not found.") if operation.partial_update: diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index bdbf3e080e9331..4f9cde37a7ba49 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -15,7 +15,7 @@ from flask_login import current_user from pydantic import BaseModel from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, scoped_session from core.file import remote_fetcher from core.helper.name_generator import generate_incremental_name @@ -83,7 +83,7 @@ class RagPipelineDslService: when generated IDs are needed mid-operation; they never commit or rollback. """ - def __init__(self, session: Session): + def __init__(self, session: Session | scoped_session): self._session = session def import_rag_pipeline( diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 1657cfdd2566c6..3e065653bdfc7b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -7,7 +7,7 @@ from typing import TypedDict, cast from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, scoped_session from core.db.session_factory import session_factory from core.model_manager import ModelManager @@ -1407,6 +1407,7 @@ def get_documents_summary_index_status( def get_document_summary_status_detail( document_id: str, dataset_id: str, + session: Session | scoped_session, ) -> DocumentSummaryStatusDetailDict: """ Get detailed summary status for a document. @@ -1414,6 +1415,7 @@ def get_document_summary_status_detail( Args: document_id: Document ID dataset_id: Dataset ID + session: SQLAlchemy session used for segment lookup Returns: Dictionary containing: @@ -1431,6 +1433,7 @@ def get_document_summary_status_detail( segments = SegmentService.get_segments_by_document_and_dataset( document_id=document_id, dataset_id=dataset_id, + session=session, status="completed", enabled=True, ) diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index dafa36cc343d63..81f57d7d6f3f7f 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -4,6 +4,7 @@ import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document @@ -31,9 +32,10 @@ def add_annotation_to_index_task( start_at = time.perf_counter() try: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, "annotation" - ) + with session_factory.create_session() as session: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, session, "annotation" + ) dataset = Dataset( id=app_id, tenant_id=tenant_id, diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py index 89844ef44b4db7..343b8ee2fa3e3a 100644 --- a/api/tasks/annotation/batch_import_annotations_task.py +++ b/api/tasks/annotation/batch_import_annotations_task.py @@ -63,7 +63,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id: if app_annotation_setting: dataset_collection_binding = ( DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - app_annotation_setting.collection_binding_id, "annotation" + app_annotation_setting.collection_binding_id, session, "annotation" ) ) if not dataset_collection_binding: diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index c9aa8fadb78d7a..79a8db2548b367 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -4,6 +4,7 @@ import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.index_type import IndexTechniqueType from models.dataset import Dataset @@ -20,9 +21,10 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str logger.info(click.style(f"Start delete app annotation index: {app_id}", fg="green")) start_at = time.perf_counter() try: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, "annotation" - ) + with session_factory.create_session() as session: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, session, "annotation" + ) dataset = Dataset( id=app_id, diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 4cbca13a92e005..32c010eaef1cf7 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -51,7 +51,7 @@ def enable_annotation_reply_task( try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION + embedding_provider_name, embedding_model_name, session, CollectionBindingType.ANNOTATION ) annotation_setting = session.scalar( select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1) @@ -60,7 +60,7 @@ def enable_annotation_reply_task( if dataset_collection_binding.id != annotation_setting.collection_binding_id: old_dataset_collection_binding = ( DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION + annotation_setting.collection_binding_id, session, CollectionBindingType.ANNOTATION ) ) if old_dataset_collection_binding and annotations: diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index f41da1d373e42c..eecc1f6fc7b4e5 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -4,6 +4,7 @@ import click from celery import shared_task +from core.db.session_factory import session_factory from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import Document @@ -31,9 +32,10 @@ def update_annotation_to_index_task( start_at = time.perf_counter() try: - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - collection_binding_id, "annotation" - ) + with session_factory.create_session() as session: + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + collection_binding_id, session, "annotation" + ) dataset = Dataset( id=app_id, diff --git a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py index c644281190cfee..aab79538f3c235 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py +++ b/api/tests/test_containers_integration_tests/controllers/console/datasets/test_data_source.py @@ -6,9 +6,11 @@ from collections.abc import Iterator from datetime import UTC, datetime from unittest.mock import MagicMock, PropertyMock, patch +from uuid import uuid4 import pytest from flask import Flask +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.console.datasets import data_source @@ -22,6 +24,8 @@ ) from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, DataSourceOauthBinding +from models.dataset import Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus @pytest.fixture @@ -277,9 +281,13 @@ def test_get_success_no_dataset_id(self, app: Flask, current_user: Account, mock assert status == 200 - def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mock_engine: None) -> None: + def test_get_success_with_dataset_id( + self, app: Flask, current_user: Account, mock_engine: None, db_session_with_containers: Session + ) -> None: api = DataSourceNotionListApi() method = inspect.unwrap(api.get) + tenant_id = str(uuid4()) + dataset_id = str(uuid4()) page = MagicMock( page_id="p1", @@ -301,10 +309,24 @@ def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mo ) dataset = MagicMock(data_source_type="notion_import") - document = MagicMock(data_source_info='{"notion_page_id": "p1"}') + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.NOTION_IMPORT, + data_source_info='{"notion_page_id": "p1"}', + batch=f"batch-{uuid4()}", + name="Notion Page", + created_from=DocumentCreatedFrom.WEB, + created_by=str(uuid4()), + indexing_status=IndexingStatus.COMPLETED, + enabled=True, + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() with ( - app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + app.test_request_context(f"/?credential_id=c1&dataset_id={dataset_id}"), patch( "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", return_value={"token": "t"}, @@ -313,7 +335,6 @@ def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mo "controllers.console.datasets.data_source.DatasetService.get_dataset", return_value=dataset, ), - patch("controllers.console.datasets.data_source.sessionmaker") as mock_session_class, patch( "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", return_value=MagicMock( @@ -322,11 +343,7 @@ def test_get_success_with_dataset_id(self, app: Flask, current_user: Account, mo ), ), ): - mock_session = MagicMock() - mock_session_class.return_value.begin.return_value.__enter__.return_value = mock_session - mock_session.scalars.return_value.all.return_value = [document] - - response, status = method(api, "tenant-1", current_user) + response, status = method(api, tenant_id, current_user) assert status == 200 diff --git a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py index 1717fea789f24b..6e8af5ca43e9fd 100644 --- a/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/test_containers_integration_tests/controllers/service_api/dataset/test_dataset.py @@ -729,13 +729,15 @@ def test_patch_dataset_success_shape( assert response["name"] == "Updated Dataset" assert response["partial_member_list"] == ["user-1"] mock_dataset_svc.update_dataset.assert_called_once() - _, update_data, _ = mock_dataset_svc.update_dataset.call_args.args + _, update_data, _, session = mock_dataset_svc.update_dataset.call_args.args + assert isinstance(session, (Session, scoped_session)) assert update_data["name"] == "Updated Dataset" assert update_data["permission"] == "partial_members" mock_perm_svc.update_partial_member_list.assert_called_once_with( mock_dataset.tenant_id, mock_dataset.id, [{"user_id": "user-1", "role": "editor"}], + SessionMatcher(), ) diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 638a61c8151a0f..d180da9555e135 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -88,7 +88,7 @@ def test_get_dataset_collection_binding_existing_binding_success(self, db_sessio # Act result = DatasetCollectionBindingService.get_dataset_collection_binding( - provider_name, model_name, collection_type + provider_name, model_name, session=db_session_with_containers, collection_type=collection_type ) # Assert @@ -109,7 +109,7 @@ def test_get_dataset_collection_binding_create_new_binding_success(self, db_sess # Act result = DatasetCollectionBindingService.get_dataset_collection_binding( - provider_name, model_name, collection_type + provider_name, model_name, session=db_session_with_containers, collection_type=collection_type ) # Assert @@ -128,7 +128,7 @@ def test_get_dataset_collection_binding_different_collection_type(self, db_sessi # Act result = DatasetCollectionBindingService.get_dataset_collection_binding( - provider_name, model_name, collection_type + provider_name, model_name, session=db_session_with_containers, collection_type=collection_type ) # Assert @@ -143,7 +143,9 @@ def test_get_dataset_collection_binding_default_collection_type(self, db_session model_name = "text-embedding-ada-002" # Act - result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name) + result = DatasetCollectionBindingService.get_dataset_collection_binding( + provider_name, model_name, session=db_session_with_containers + ) # Assert assert result.type == CollectionBindingType.DATASET @@ -192,7 +194,7 @@ def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_ # Act result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - binding.id, CollectionBindingType.DATASET + binding.id, session=db_session_with_containers, collection_type=CollectionBindingType.DATASET ) # Assert @@ -210,7 +212,7 @@ def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_ # Act & Assert with pytest.raises(ValueError, match="Dataset collection binding not found"): DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - non_existent_id, CollectionBindingType.DATASET + non_existent_id, session=db_session_with_containers, collection_type=CollectionBindingType.DATASET ) def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( @@ -228,7 +230,7 @@ def test_get_dataset_collection_binding_by_id_and_type_different_collection_type # Act result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - binding.id, "custom_type" + binding.id, session=db_session_with_containers, collection_type="custom_type" ) # Assert @@ -249,7 +251,9 @@ def test_get_dataset_collection_binding_by_id_and_type_default_collection_type( ) # Act - result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id) + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + binding.id, session=db_session_with_containers + ) # Assert assert result.id == binding.id @@ -268,4 +272,6 @@ def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db # Act & Assert with pytest.raises(ValueError, match="Dataset collection binding not found"): - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "wrong_type") + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + binding.id, session=db_session_with_containers, collection_type="wrong_type" + ) diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 1cffc43658a95f..7d4cd63267ab5e 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -141,7 +141,7 @@ def test_delete_dataset_success(self, db_session_with_containers: Session): # Act with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted: - result = DatasetService.delete_dataset(dataset.id, owner) + result = DatasetService.delete_dataset(dataset.id, owner, session=db_session_with_containers) # Assert assert result is True @@ -168,7 +168,7 @@ def test_delete_dataset_not_found(self, db_session_with_containers: Session): dataset_id = str(uuid4()) # Act - result = DatasetService.delete_dataset(dataset_id, owner) + result = DatasetService.delete_dataset(dataset_id, owner, session=db_session_with_containers) # Assert assert result is False @@ -198,7 +198,7 @@ def test_delete_dataset_permission_denied_error(self, db_session_with_containers # Act & Assert with pytest.raises(NoPermissionError): - DatasetService.delete_dataset(dataset.id, normal_user) + DatasetService.delete_dataset(dataset.id, normal_user, session=db_session_with_containers) # Verify no deletion was attempted assert db_session_with_containers.get(Dataset, dataset.id) is not None @@ -230,7 +230,7 @@ def test_dataset_use_check_in_use(self, db_session_with_containers: Session): DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id) # Act - result = DatasetService.dataset_use_check(dataset.id) + result = DatasetService.dataset_use_check(dataset.id, session=db_session_with_containers) # Assert assert result is True @@ -254,7 +254,7 @@ def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session) dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act - result = DatasetService.dataset_use_check(dataset.id) + result = DatasetService.dataset_use_check(dataset.id, session=db_session_with_containers) # Assert assert result is False @@ -292,7 +292,7 @@ def test_update_dataset_api_status_enable_success(self, db_session_with_containe patch("services.dataset_service.current_user", owner), patch("services.dataset_service.naive_utc_now", return_value=current_time), ): - DatasetService.update_dataset_api_status(dataset.id, True) + DatasetService.update_dataset_api_status(dataset.id, True, session=db_session_with_containers) # Assert db_session_with_containers.refresh(dataset) @@ -327,7 +327,7 @@ def test_update_dataset_api_status_disable_success(self, db_session_with_contain patch("services.dataset_service.current_user", owner), patch("services.dataset_service.naive_utc_now", return_value=current_time), ): - DatasetService.update_dataset_api_status(dataset.id, False) + DatasetService.update_dataset_api_status(dataset.id, False, session=db_session_with_containers) # Assert db_session_with_containers.refresh(dataset) @@ -351,7 +351,7 @@ def test_update_dataset_api_status_not_found_error(self, db_session_with_contain # Act & Assert with pytest.raises(NotFound, match="Dataset not found"): - DatasetService.update_dataset_api_status(dataset_id, True) + DatasetService.update_dataset_api_status(dataset_id, True, session=db_session_with_containers) def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session): """ @@ -378,7 +378,7 @@ def test_update_dataset_api_status_missing_current_user_error(self, db_session_w patch("services.dataset_service.current_user", None), pytest.raises(ValueError, match="Current user or current user id not found"), ): - DatasetService.update_dataset_api_status(dataset.id, True) + DatasetService.update_dataset_api_status(dataset.id, True, session=db_session_with_containers) # Verify no commit was attempted db_session_with_containers.rollback() diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index 327f14ddfe7670..7e78cef1db3b08 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -301,7 +301,7 @@ def test_pause_document_waiting_state_success( ) # Act - DocumentService.pause_document(document) + DocumentService.pause_document(document, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -336,7 +336,7 @@ def test_pause_document_indexing_state_success( ) # Act - DocumentService.pause_document(document) + DocumentService.pause_document(document, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -366,7 +366,7 @@ def test_pause_document_parsing_state_success( ) # Act - DocumentService.pause_document(document) + DocumentService.pause_document(document, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -398,7 +398,7 @@ def test_pause_document_completed_state_error( # Act & Assert with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) + DocumentService.pause_document(document, session=db_session_with_containers) db_session_with_containers.refresh(document) assert document.is_paused is False @@ -429,7 +429,7 @@ def test_pause_document_error_state_error( # Act & Assert with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(document) + DocumentService.pause_document(document, session=db_session_with_containers) db_session_with_containers.refresh(document) assert document.is_paused is False @@ -507,7 +507,7 @@ def test_recover_document_paused_success( ) # Act - DocumentService.recover_document(document) + DocumentService.recover_document(document, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -547,7 +547,7 @@ def test_recover_document_not_paused_error( # Act & Assert with pytest.raises(DocumentIndexingError): - DocumentService.recover_document(document) + DocumentService.recover_document(document, session=db_session_with_containers) db_session_with_containers.refresh(document) assert document.is_paused is False @@ -632,7 +632,7 @@ def test_retry_document_single_success( mock_document_service_dependencies["redis_client"].get.return_value = None # Act - DocumentService.retry_document(dataset.id, [document]) + DocumentService.retry_document(dataset.id, [document], session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -679,7 +679,7 @@ def test_retry_document_multiple_success( mock_document_service_dependencies["redis_client"].get.return_value = None # Act - DocumentService.retry_document(dataset.id, [document1, document2]) + DocumentService.retry_document(dataset.id, [document1, document2], session=db_session_with_containers) # Assert db_session_with_containers.refresh(document1) @@ -719,7 +719,7 @@ def test_retry_document_concurrent_retry_error( # Act & Assert with pytest.raises(ValueError, match="Document is being retried, please try again later"): - DocumentService.retry_document(dataset.id, [document]) + DocumentService.retry_document(dataset.id, [document], session=db_session_with_containers) db_session_with_containers.refresh(document) assert document.indexing_status == IndexingStatus.ERROR @@ -753,7 +753,7 @@ def test_retry_document_missing_current_user_error( # Act & Assert with pytest.raises(ValueError, match="Current user or current user id not found"): - DocumentService.retry_document(dataset.id, [document]) + DocumentService.retry_document(dataset.id, [document], session=db_session_with_containers) class TestDocumentServiceBatchUpdateDocumentStatus: @@ -851,7 +851,9 @@ def test_batch_update_document_status_enable_success( mock_document_service_dependencies["redis_client"].get.return_value = None # Act - DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) + DocumentService.batch_update_document_status( + dataset, document_ids, "enable", user, session=db_session_with_containers + ) # Assert db_session_with_containers.refresh(document1) @@ -893,7 +895,9 @@ def test_batch_update_document_status_disable_success( mock_document_service_dependencies["redis_client"].get.return_value = None # Act - DocumentService.batch_update_document_status(dataset, document_ids, "disable", user) + DocumentService.batch_update_document_status( + dataset, document_ids, "disable", user, session=db_session_with_containers + ) # Assert db_session_with_containers.refresh(document) @@ -935,7 +939,9 @@ def test_batch_update_document_status_archive_success( mock_document_service_dependencies["redis_client"].get.return_value = None # Act - DocumentService.batch_update_document_status(dataset, document_ids, "archive", user) + DocumentService.batch_update_document_status( + dataset, document_ids, "archive", user, session=db_session_with_containers + ) # Assert db_session_with_containers.refresh(document) @@ -977,7 +983,9 @@ def test_batch_update_document_status_unarchive_success( mock_document_service_dependencies["redis_client"].get.return_value = None # Act - DocumentService.batch_update_document_status(dataset, document_ids, "un_archive", user) + DocumentService.batch_update_document_status( + dataset, document_ids, "un_archive", user, session=db_session_with_containers + ) # Assert db_session_with_containers.refresh(document) @@ -1006,7 +1014,9 @@ def test_batch_update_document_status_empty_list( document_ids = [] # Act - DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) + DocumentService.batch_update_document_status( + dataset, document_ids, "enable", user, session=db_session_with_containers + ) # Assert mock_document_service_dependencies["add_task"].delay.assert_not_called() @@ -1042,7 +1052,9 @@ def test_batch_update_document_status_document_indexing_error( # Act & Assert with pytest.raises(DocumentIndexingError, match="is being indexed"): - DocumentService.batch_update_document_status(dataset, document_ids, "enable", user) + DocumentService.batch_update_document_status( + dataset, document_ids, "enable", user, session=db_session_with_containers + ) class TestDocumentServiceRenameDocument: @@ -1121,7 +1133,7 @@ def test_rename_document_success(self, db_session_with_containers: Session, mock ) # Act - result = DocumentService.rename_document(dataset.id, document.id, new_name) + result = DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -1164,7 +1176,7 @@ def test_rename_document_with_built_in_fields( ) # Act - DocumentService.rename_document(dataset.id, document.id, new_name) + DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -1214,7 +1226,7 @@ def test_rename_document_with_upload_file( ) # Act - DocumentService.rename_document(dataset.id, document.id, new_name) + DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -1243,7 +1255,7 @@ def test_rename_document_dataset_not_found_error( # Act & Assert with pytest.raises(ValueError, match="Dataset not found"): - DocumentService.rename_document(dataset_id, document_id, new_name) + DocumentService.rename_document(dataset_id, document_id, new_name, session=db_session_with_containers) def test_rename_document_not_found_error( self, db_session_with_containers: Session, mock_document_service_dependencies @@ -1272,7 +1284,7 @@ def test_rename_document_not_found_error( # Act & Assert with pytest.raises(ValueError, match="Document not found"): - DocumentService.rename_document(dataset.id, document_id, new_name) + DocumentService.rename_document(dataset.id, document_id, new_name, session=db_session_with_containers) def test_rename_document_permission_error( self, db_session_with_containers: Session, mock_document_service_dependencies @@ -1309,4 +1321,4 @@ def test_rename_document_permission_error( # Act & Assert with pytest.raises(ValueError, match="No permission"): - DocumentService.rename_document(dataset.id, document.id, new_name) + DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py index b88204b2a6af49..0df4ece5d2d9f8 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -134,7 +134,9 @@ def test_get_dataset_partial_member_list_with_members(self, db_session_with_cont DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, account_id, tenant.id) # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) # Assert assert set(result) == set(expected_account_ids) @@ -156,7 +158,9 @@ def test_get_dataset_partial_member_list_with_single_member(self, db_session_wit DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) # Assert assert set(result) == set(expected_account_ids) @@ -171,7 +175,9 @@ def test_get_dataset_partial_member_list_empty(self, db_session_with_containers: dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) # Assert assert result == [] @@ -199,10 +205,14 @@ def test_update_partial_member_list_add_new_members(self, db_session_with_contai user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) # Act - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, user_list, session=db_session_with_containers + ) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert set(result) == {member_1.id, member_2.id} def test_update_partial_member_list_replace_existing(self, db_session_with_containers: Session): @@ -230,15 +240,21 @@ def test_update_partial_member_list_replace_existing(self, db_session_with_conta dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id]) - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, old_users, session=db_session_with_containers + ) new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id]) # Act - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, new_users, session=db_session_with_containers + ) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert set(result) == {new_member_1.id, new_member_2.id} def test_update_partial_member_list_empty_list(self, db_session_with_containers: Session): @@ -257,13 +273,19 @@ def test_update_partial_member_list_empty_list(self, db_session_with_containers: ) dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, users, session=db_session_with_containers + ) # Act - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, []) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, [], session=db_session_with_containers + ) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert result == [] def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): @@ -285,10 +307,11 @@ def test_update_partial_member_list_database_error_rollback(self, db_session_wit tenant.id, dataset.id, DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]), + session=db_session_with_containers, ) user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id]) rollback_called = {"count": 0} - original_rollback = db.session.rollback + original_rollback = db_session_with_containers.rollback # Act / Assert with pytest.MonkeyPatch.context() as mp: @@ -300,13 +323,17 @@ def _rollback_and_mark(): rollback_called["count"] += 1 original_rollback() - mp.setattr("services.dataset_service.db.session.commit", _raise_commit) - mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + mp.setattr(db_session_with_containers, "commit", _raise_commit) + mp.setattr(db_session_with_containers, "rollback", _rollback_and_mark) with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, user_list, session=db_session_with_containers + ) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert rollback_called["count"] == 1 assert result == [existing_member.id] assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1 @@ -331,13 +358,17 @@ def test_clear_partial_member_list_success(self, db_session_with_containers: Ses ) dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, users, session=db_session_with_containers + ) # Act - DatasetPermissionService.clear_partial_member_list(dataset.id) + DatasetPermissionService.clear_partial_member_list(dataset.id, session=db_session_with_containers) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert result == [] def test_clear_partial_member_list_empty_list(self, db_session_with_containers: Session): @@ -349,10 +380,12 @@ def test_clear_partial_member_list_empty_list(self, db_session_with_containers: dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) # Act - DatasetPermissionService.clear_partial_member_list(dataset.id) + DatasetPermissionService.clear_partial_member_list(dataset.id, session=db_session_with_containers) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert result == [] def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers: Session): @@ -371,9 +404,11 @@ def test_clear_partial_member_list_database_error_rollback(self, db_session_with ) dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) - DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + DatasetPermissionService.update_partial_member_list( + tenant.id, dataset.id, users, session=db_session_with_containers + ) rollback_called = {"count": 0} - original_rollback = db.session.rollback + original_rollback = db_session_with_containers.rollback # Act / Assert with pytest.MonkeyPatch.context() as mp: @@ -385,13 +420,15 @@ def _rollback_and_mark(): rollback_called["count"] += 1 original_rollback() - mp.setattr("services.dataset_service.db.session.commit", _raise_commit) - mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + mp.setattr(db_session_with_containers, "commit", _raise_commit) + mp.setattr(db_session_with_containers, "rollback", _rollback_and_mark) with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.clear_partial_member_list(dataset.id) + DatasetPermissionService.clear_partial_member_list(dataset.id, session=db_session_with_containers) # Assert - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert rollback_called["count"] == 1 assert set(result) == {member_1.id, member_2.id} assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2 @@ -486,7 +523,9 @@ def test_check_dataset_permission_partial_members_with_permission_success( DatasetService.check_dataset_permission(dataset, user, db_session_with_containers) # Assert - permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + permissions = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert user.id in permissions def test_check_dataset_permission_partial_members_without_permission_error( @@ -547,10 +586,12 @@ def test_check_dataset_operator_permission_partial_members_with_permission_succe DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) # Act (should not raise) - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset, session=db_session_with_containers) # Assert - permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + permissions = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert user.id in permissions def test_check_dataset_operator_permission_partial_members_without_permission_error( @@ -574,4 +615,6 @@ def test_check_dataset_operator_permission_partial_members_without_permission_er # Act & Assert with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + DatasetService.check_dataset_operator_permission( + user=user, dataset=dataset, session=db_session_with_containers + ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 201b65b30d04a5..912e00b0b7d8a1 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -137,6 +137,7 @@ def test_create_internal_dataset_basic_success(self, db_session_with_containers: description="Test description", indexing_technique=None, account=account, + session=db_session_with_containers, ) # Assert @@ -159,6 +160,7 @@ def test_create_internal_dataset_with_economy_indexing(self, db_session_with_con description=None, indexing_technique=IndexTechniqueType.ECONOMY, account=account, + session=db_session_with_containers, ) # Assert @@ -183,6 +185,7 @@ def test_create_internal_dataset_with_high_quality_indexing(self, db_session_wit description=None, indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, + session=db_session_with_containers, ) # Assert @@ -215,6 +218,7 @@ def test_create_dataset_duplicate_name_error(self, db_session_with_containers: S description=None, indexing_technique=None, account=account, + session=db_session_with_containers, ) def test_create_external_dataset_success(self, db_session_with_containers: Session): @@ -236,6 +240,7 @@ def test_create_external_dataset_success(self, db_session_with_containers: Sessi provider="external", external_knowledge_api_id=external_knowledge_api_id, external_knowledge_id=external_knowledge_id, + session=db_session_with_containers, ) # Assert @@ -276,6 +281,7 @@ def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with indexing_technique=IndexTechniqueType.HIGH_QUALITY, account=account, retrieval_model=retrieval_model, + session=db_session_with_containers, ) # Assert @@ -310,6 +316,7 @@ def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( account=account, embedding_model_provider=embedding_provider, embedding_model_name=embedding_model_name, + session=db_session_with_containers, ) # Assert @@ -345,6 +352,7 @@ def test_create_internal_dataset_with_retrieval_model(self, db_session_with_cont indexing_technique=None, account=account, retrieval_model=retrieval_model, + session=db_session_with_containers, ) # Assert @@ -364,6 +372,7 @@ def test_create_internal_dataset_with_custom_permission(self, db_session_with_co indexing_technique=None, account=account, permission=DatasetPermissionEnum.ALL_TEAM, + session=db_session_with_containers, ) # Assert @@ -389,6 +398,7 @@ def test_create_external_dataset_missing_api_id_error(self, db_session_with_cont provider="external", external_knowledge_api_id=external_knowledge_api_id, external_knowledge_id="knowledge-123", + session=db_session_with_containers, ) def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): @@ -410,6 +420,7 @@ def test_create_external_dataset_missing_knowledge_id_error(self, db_session_wit provider="external", external_knowledge_api_id=external_knowledge_api_id, external_knowledge_id=None, + session=db_session_with_containers, ) @@ -431,7 +442,9 @@ def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_con # Act with patch("services.dataset_service.current_user", account): result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=entity, + session=db_session_with_containers, ) # Assert @@ -467,7 +480,9 @@ def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_w ): mock_generate_name.return_value = generated_name result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=entity, + session=db_session_with_containers, ) # Assert @@ -505,7 +520,9 @@ def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_ pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {duplicate_name} already exists"), ): DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=entity, + session=db_session_with_containers, ) def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session): @@ -523,7 +540,9 @@ def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_wit # Act with patch("services.dataset_service.current_user", account): result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=entity, + session=db_session_with_containers, ) # Assert @@ -550,7 +569,9 @@ def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_contai # Act with patch("services.dataset_service.current_user", account): result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant.id, + rag_pipeline_dataset_create_entity=entity, + session=db_session_with_containers, ) # Assert @@ -580,7 +601,9 @@ def test_update_dataset_duplicate_name_error(self, db_session_with_containers: S # Act / Assert with pytest.raises(ValueError, match="Dataset name already exists"): - DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) + DatasetService.update_dataset( + source_dataset.id, {"name": "Existing Dataset"}, account, session=db_session_with_containers + ) def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset that already has documents.""" @@ -599,7 +622,7 @@ def test_delete_dataset_with_documents_success(self, db_session_with_containers: # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: - result = DatasetService.delete_dataset(dataset.id, account) + result = DatasetService.delete_dataset(dataset.id, account, session=db_session_with_containers) # Assert assert result is True @@ -620,7 +643,7 @@ def test_delete_empty_dataset_success(self, db_session_with_containers: Session) # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: - result = DatasetService.delete_dataset(dataset.id, account) + result = DatasetService.delete_dataset(dataset.id, account, session=db_session_with_containers) # Assert assert result is True @@ -641,7 +664,7 @@ def test_delete_dataset_with_partial_none_values(self, db_session_with_container # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: - result = DatasetService.delete_dataset(dataset.id, account) + result = DatasetService.delete_dataset(dataset.id, account, session=db_session_with_containers) # Assert assert result is True @@ -670,7 +693,7 @@ def test_get_dataset_retrieval_configuration(self, db_session_with_containers: S ) # Act - result = DatasetService.get_dataset(dataset.id) + result = DatasetService.get_dataset(dataset.id, session=db_session_with_containers) # Assert assert result is not None @@ -702,7 +725,7 @@ def test_update_dataset_retrieval_configuration(self, db_session_with_containers } # Act - result = DatasetService.update_dataset(dataset.id, update_data, account) + result = DatasetService.update_dataset(dataset.id, update_data, account, session=db_session_with_containers) # Assert db_session_with_containers.refresh(dataset) @@ -730,7 +753,7 @@ def test_pause_document_success(self, db_session_with_containers: Session): with patch("services.dataset_service.current_user") as mock_user: mock_user.id = account.id - DocumentService.pause_document(doc) + DocumentService.pause_document(doc, session=db_session_with_containers) db_session_with_containers.refresh(doc) assert doc.is_paused is True @@ -750,7 +773,7 @@ def test_pause_document_invalid_status_error(self, db_session_with_containers: S with patch("services.dataset_service.current_user") as mock_user: mock_user.id = account.id with pytest.raises(DocumentIndexingError): - DocumentService.pause_document(doc) + DocumentService.pause_document(doc, session=db_session_with_containers) def test_recover_document_success(self, db_session_with_containers: Session): from extensions.ext_redis import redis_client @@ -761,11 +784,11 @@ def test_recover_document_success(self, db_session_with_containers: Session): # Pause first with patch("services.dataset_service.current_user") as mock_user: mock_user.id = account.id - DocumentService.pause_document(doc) + DocumentService.pause_document(doc, session=db_session_with_containers) # Recover with patch("services.dataset_service.recover_document_indexing_task") as recover_task: - DocumentService.recover_document(doc) + DocumentService.recover_document(doc, session=db_session_with_containers) db_session_with_containers.refresh(doc) assert doc.is_paused is False @@ -795,7 +818,7 @@ def test_retry_document_indexing_success(self, db_session_with_containers: Sessi patch("services.dataset_service.retry_document_indexing_task") as retry_task, ): mock_user.id = account.id - DocumentService.retry_document(dataset.id, [doc1, doc2]) + DocumentService.retry_document(dataset.id, [doc1, doc2], session=db_session_with_containers) db_session_with_containers.refresh(doc1) db_session_with_containers.refresh(doc2) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py index c1d088755c12f5..2f1022a47ca669 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -196,7 +196,11 @@ def test_batch_update_enable_documents_success(self, db_session_with_containers: # Act DocumentService.batch_update_document_status( - dataset=dataset, document_ids=document_ids, action="enable", user=user + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + session=db_session_with_containers, ) # Assert @@ -228,6 +232,7 @@ def test_batch_update_enable_already_enabled_document_skipped( document_ids=[document.id], action="enable", user=user, + session=db_session_with_containers, ) # Assert @@ -256,6 +261,7 @@ def test_batch_update_disable_documents_success(self, db_session_with_containers document_ids=document_ids, action="disable", user=user, + session=db_session_with_containers, ) # Assert @@ -291,6 +297,7 @@ def test_batch_update_disable_already_disabled_document_skipped( document_ids=[disabled_doc.id], action="disable", user=user, + session=db_session_with_containers, ) # Assert @@ -321,6 +328,7 @@ def test_batch_update_disable_non_completed_document_error( document_ids=[non_completed_doc.id], action="disable", user=user, + session=db_session_with_containers, ) def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies): @@ -338,6 +346,7 @@ def test_batch_update_archive_documents_success(self, db_session_with_containers document_ids=[document.id], action="archive", user=user, + session=db_session_with_containers, ) # Assert @@ -364,6 +373,7 @@ def test_batch_update_archive_already_archived_document_skipped( document_ids=[document.id], action="archive", user=user, + session=db_session_with_containers, ) # Assert @@ -389,6 +399,7 @@ def test_batch_update_archive_disabled_document_no_index_removal( document_ids=[document.id], action="archive", user=user, + session=db_session_with_containers, ) # Assert @@ -412,6 +423,7 @@ def test_batch_update_unarchive_documents_success(self, db_session_with_containe document_ids=[document.id], action="un_archive", user=user, + session=db_session_with_containers, ) # Assert @@ -439,6 +451,7 @@ def test_batch_update_unarchive_already_unarchived_document_skipped( document_ids=[document.id], action="un_archive", user=user, + session=db_session_with_containers, ) # Assert @@ -464,6 +477,7 @@ def test_batch_update_unarchive_disabled_document_no_index_addition( document_ids=[document.id], action="un_archive", user=user, + session=db_session_with_containers, ) # Assert @@ -495,6 +509,7 @@ def test_batch_update_document_indexing_error_redis_cache_hit( document_ids=[document.id], action="enable", user=user, + session=db_session_with_containers, ) assert "test_document.pdf" in str(exc_info.value) @@ -517,6 +532,7 @@ def test_batch_update_async_task_error_handling(self, db_session_with_containers document_ids=[document.id], action="enable", user=user, + session=db_session_with_containers, ) db_session_with_containers.refresh(document) @@ -531,7 +547,11 @@ def test_batch_update_empty_document_list(self, db_session_with_containers: Sess # Act result = DocumentService.batch_update_document_status( - dataset=dataset, document_ids=[], action="enable", user=user + dataset=dataset, + document_ids=[], + action="enable", + user=user, + session=db_session_with_containers, ) # Assert @@ -552,6 +572,7 @@ def test_batch_update_document_not_found_skipped(self, db_session_with_container document_ids=[missing_document_id], action="enable", user=user, + session=db_session_with_containers, ) # Assert @@ -590,6 +611,7 @@ def test_batch_update_mixed_document_states_and_actions( document_ids=document_ids, action="enable", user=user, + session=db_session_with_containers, ) # Assert @@ -628,6 +650,7 @@ def test_batch_update_large_document_list_performance( document_ids=document_ids, action="enable", user=user, + session=db_session_with_containers, ) # Assert @@ -679,6 +702,7 @@ def test_batch_update_mixed_document_states_complex_scenario( document_ids=document_ids, action="enable", user=user, + session=db_session_with_containers, ) # Assert @@ -709,5 +733,9 @@ def test_batch_update_invalid_action_raises_value_error( with pytest.raises(ValueError, match="Invalid action"): DocumentService.batch_update_document_status( - dataset=dataset, document_ids=[doc.id], action="invalid_action", user=user + dataset=dataset, + document_ids=[doc.id], + action="invalid_action", + user=user, + session=db_session_with_containers, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py index 08de79f4b7e89d..292ae46190a76e 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_create_dataset.py @@ -58,4 +58,5 @@ def test_create_rag_pipeline_dataset_raises_when_current_user_id_is_none(self, d DatasetService.create_empty_rag_pipeline_dataset( tenant_id=tenant.id, rag_pipeline_dataset_create_entity=self._build_entity(), + session=db_session_with_containers, ) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py index c43a5d59789abf..1c5ec6835bda6e 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -130,7 +130,7 @@ def test_delete_dataset_with_documents_success(self, db_session_with_containers: "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", autospec=True, ) as clean_dataset_delay: - result = DatasetService.delete_dataset(dataset.id, owner) + result = DatasetService.delete_dataset(dataset.id, owner, session=db_session_with_containers) # Assert db_session_with_containers.expire_all() @@ -166,7 +166,7 @@ def test_delete_empty_dataset_success(self, db_session_with_containers: Session) "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", autospec=True, ) as clean_dataset_delay: - result = DatasetService.delete_dataset(dataset.id, owner) + result = DatasetService.delete_dataset(dataset.id, owner, session=db_session_with_containers) # Assert db_session_with_containers.expire_all() @@ -194,7 +194,7 @@ def test_delete_dataset_with_partial_none_values(self, db_session_with_container "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", autospec=True, ) as clean_dataset_delay: - result = DatasetService.delete_dataset(dataset.id, owner) + result = DatasetService.delete_dataset(dataset.id, owner, session=db_session_with_containers) # Assert db_session_with_containers.expire_all() @@ -222,7 +222,7 @@ def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_se "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", autospec=True, ) as clean_dataset_delay: - result = DatasetService.delete_dataset(dataset.id, owner) + result = DatasetService.delete_dataset(dataset.id, owner, session=db_session_with_containers) # Assert db_session_with_containers.expire_all() @@ -241,7 +241,7 @@ def test_delete_dataset_not_found(self, db_session_with_containers: Session): "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", autospec=True, ) as clean_dataset_delay: - result = DatasetService.delete_dataset(missing_dataset_id, owner) + result = DatasetService.delete_dataset(missing_dataset_id, owner, session=db_session_with_containers) # Assert assert result is False diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py index 946ac6619402dd..ab5ba173792ece 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_document.py @@ -125,14 +125,14 @@ def current_user_mock(): def test_get_document_returns_none_when_document_id_is_missing(db_session_with_containers: Session): dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) - assert DocumentService.get_document(dataset.id, None) is None + assert DocumentService.get_document(dataset.id, None, session=db_session_with_containers) is None def test_get_document_queries_by_dataset_and_document_id(db_session_with_containers: Session): dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) document = DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset) - result = DocumentService.get_document(dataset.id, document.id) + result = DocumentService.get_document(dataset.id, document.id, session=db_session_with_containers) assert result is not None assert result.id == document.id @@ -141,7 +141,7 @@ def test_get_document_queries_by_dataset_and_document_id(db_session_with_contain def test_get_documents_by_ids_returns_empty_for_empty_input(db_session_with_containers: Session): dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) - result = DocumentService.get_documents_by_ids(dataset.id, []) + result = DocumentService.get_documents_by_ids(dataset.id, [], session=db_session_with_containers) assert result == [] @@ -156,7 +156,7 @@ def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers position=2, ) - result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id]) + result = DocumentService.get_documents_by_ids(dataset.id, [doc_a.id, doc_b.id], db_session_with_containers) assert {document.id for document in result} == {doc_a.id, doc_b.id} @@ -164,7 +164,7 @@ def test_get_documents_by_ids_uses_single_batch_query(db_session_with_containers def test_update_documents_need_summary_returns_zero_for_empty_input(db_session_with_containers: Session): dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) - assert DocumentService.update_documents_need_summary(dataset.id, []) == 0 + assert DocumentService.update_documents_need_summary(dataset.id, [], db_session_with_containers) == 0 def test_update_documents_need_summary_updates_matching_non_qa_documents(db_session_with_containers: Session): @@ -185,6 +185,7 @@ def test_update_documents_need_summary_updates_matching_non_qa_documents(db_sess updated_count = DocumentService.update_documents_need_summary( dataset.id, [paragraph_doc.id, qa_doc.id], + db_session_with_containers, need_summary=False, ) @@ -212,7 +213,7 @@ def test_get_document_download_url_uses_signed_url_helper(db_session_with_contai ) with patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url: - result = DocumentService.get_document_download_url(document) + result = DocumentService.get_document_download_url(document, session=db_session_with_containers) assert result == "signed-url" get_url.assert_called_once_with(upload_file_id=upload_file.id, as_attachment=True) @@ -282,7 +283,7 @@ def test_get_upload_file_for_upload_file_document_raises_when_file_service_retur with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): with pytest.raises(NotFound, match="Uploaded file not found"): - DocumentService._get_upload_file_for_upload_file_document(document) + DocumentService._get_upload_file_for_upload_file_document(document, session=db_session_with_containers) def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session_with_containers: Session): @@ -298,7 +299,7 @@ def test_get_upload_file_for_upload_file_document_returns_upload_file(db_session data_source_info={"upload_file_id": upload_file.id}, ) - result = DocumentService._get_upload_file_for_upload_file_document(document) + result = DocumentService._get_upload_file_for_upload_file_document(document, session=db_session_with_containers) assert result.id == upload_file.id @@ -313,6 +314,7 @@ def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_doc dataset_id=dataset.id, document_ids=[str(uuid4())], tenant_id=dataset.tenant_id, + session=db_session_with_containers, ) @@ -337,6 +339,7 @@ def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_a dataset_id=dataset.id, document_ids=[document.id], tenant_id=dataset.tenant_id, + session=db_session_with_containers, ) @@ -355,6 +358,7 @@ def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload dataset_id=dataset.id, document_ids=[document.id], tenant_id=dataset.tenant_id, + session=db_session_with_containers, ) @@ -390,6 +394,7 @@ def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed dataset_id=dataset.id, document_ids=[document_a.id, document_b.id], tenant_id=dataset.tenant_id, + session=db_session_with_containers, ) assert mapping[document_a.id].id == upload_file_a.id @@ -397,7 +402,7 @@ def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset( - current_user_mock, flask_app_with_containers + current_user_mock, flask_app_with_containers, db_session_with_containers: Session ): with flask_app_with_containers.app_context(): with pytest.raises(NotFound, match="Dataset not found"): @@ -406,6 +411,7 @@ def test_prepare_document_batch_download_zip_raises_not_found_for_missing_datase document_ids=[str(uuid4())], tenant_id=current_user_mock.current_tenant_id, current_user=current_user_mock, + session=db_session_with_containers, ) @@ -429,6 +435,7 @@ def test_prepare_document_batch_download_zip_translates_permission_error_to_forb document_ids=[], tenant_id=current_user_mock.current_tenant_id, current_user=current_user_mock, + session=db_session_with_containers, ) @@ -470,6 +477,7 @@ def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_o document_ids=[document_b.id, document_a.id], tenant_id=current_user_mock.current_tenant_id, current_user=current_user_mock, + session=db_session_with_containers, ) assert [upload_file.id for upload_file in upload_files] == [upload_file_b.id, upload_file_a.id] @@ -490,7 +498,7 @@ def test_get_document_by_dataset_id_returns_enabled_documents(db_session_with_co enabled=False, ) - result = DocumentService.get_document_by_dataset_id(dataset.id) + result = DocumentService.get_document_by_dataset_id(dataset.id, session=db_session_with_containers) assert [document.id for document in result] == [enabled_document.id] @@ -513,7 +521,7 @@ def test_get_working_documents_by_dataset_id_returns_completed_enabled_unarchive indexing_status=IndexingStatus.ERROR, ) - result = DocumentService.get_working_documents_by_dataset_id(dataset.id) + result = DocumentService.get_working_documents_by_dataset_id(dataset.id, session=db_session_with_containers) assert [document.id for document in result] == [available_document.id] @@ -538,7 +546,7 @@ def test_get_error_documents_by_dataset_id_returns_error_and_paused_documents(db indexing_status=IndexingStatus.COMPLETED, ) - result = DocumentService.get_error_documents_by_dataset_id(dataset.id) + result = DocumentService.get_error_documents_by_dataset_id(dataset.id, session=db_session_with_containers) assert {document.id for document in result} == {error_document.id, paused_document.id} @@ -561,7 +569,7 @@ def test_get_batch_documents_filters_by_current_user_tenant(db_session_with_cont with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: current_user.current_tenant_id = dataset.tenant_id - result = DocumentService.get_batch_documents(dataset.id, batch) + result = DocumentService.get_batch_documents(dataset.id, batch, session=db_session_with_containers) assert [document.id for document in result] == [matching_document.id] @@ -574,7 +582,7 @@ def test_get_document_file_detail_returns_upload_file(db_session_with_containers created_by=dataset.created_by, ) - result = DocumentService.get_document_file_detail(upload_file.id) + result = DocumentService.get_document_file_detail(upload_file.id, session=db_session_with_containers) assert result is not None assert result.id == upload_file.id @@ -594,7 +602,7 @@ def test_delete_document_emits_signal_and_commits(db_session_with_containers: Se ) with patch("services.dataset_service.document_was_deleted.send") as signal_send: - DocumentService.delete_document(document) + DocumentService.delete_document(document, session=db_session_with_containers) assert db_session_with_containers.get(Document, document.id) is None signal_send.assert_called_once_with( @@ -609,7 +617,7 @@ def test_delete_documents_ignores_empty_input(db_session_with_containers: Sessio dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) with patch("services.dataset_service.batch_clean_document_task.delay") as delay: - DocumentService.delete_documents(dataset, []) + DocumentService.delete_documents(dataset, [], session=db_session_with_containers) delay.assert_not_called() @@ -643,7 +651,7 @@ def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(db_session_wi ) with patch("services.dataset_service.batch_clean_document_task.delay") as delay: - DocumentService.delete_documents(dataset, [document_a.id, document_b.id]) + DocumentService.delete_documents(dataset, [document_a.id, document_b.id], session=db_session_with_containers) assert db_session_with_containers.get(Document, document_a.id) is None assert db_session_with_containers.get(Document, document_b.id) is None @@ -658,10 +666,10 @@ def test_get_documents_position_returns_next_position_when_documents_exist(db_se dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) DocumentServiceIntegrationFactory.create_document(db_session_with_containers, dataset=dataset, position=3) - assert DocumentService.get_documents_position(dataset.id) == 4 + assert DocumentService.get_documents_position(dataset.id, session=db_session_with_containers) == 4 def test_get_documents_position_defaults_to_one_when_dataset_is_empty(db_session_with_containers: Session): dataset = DocumentServiceIntegrationFactory.create_dataset(db_session_with_containers) - assert DocumentService.get_documents_position(dataset.id) == 1 + assert DocumentService.get_documents_position(dataset.id, session=db_session_with_containers) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py index ba5883f408db5e..6b32273624b2b2 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_permissions.py @@ -182,7 +182,7 @@ class TestDatasetServicePermissionsAndLifecycle: def test_delete_dataset_returns_false_when_dataset_is_missing(self, db_session_with_containers: Session): owner, _tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) - result = DatasetService.delete_dataset(str(uuid4()), user=owner) + result = DatasetService.delete_dataset(str(uuid4()), user=owner, session=db_session_with_containers) assert result is False @@ -195,7 +195,7 @@ def test_delete_dataset_checks_permission_and_deletes_dataset(self, db_session_w ) with patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal: - result = DatasetService.delete_dataset(dataset.id, user=owner) + result = DatasetService.delete_dataset(dataset.id, user=owner, session=db_session_with_containers) assert result is True assert db_session_with_containers.get(Dataset, dataset.id) is None @@ -213,7 +213,7 @@ def test_dataset_use_check_returns_true_when_join_exists(self, db_session_with_c dataset_id=dataset.id, ) - assert DatasetService.dataset_use_check(dataset.id) is True + assert DatasetService.dataset_use_check(dataset.id, session=db_session_with_containers) is True def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with_containers: Session): owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) @@ -223,7 +223,7 @@ def test_dataset_use_check_returns_false_when_join_missing(self, db_session_with created_by=owner.id, ) - assert DatasetService.dataset_use_check(dataset.id) is False + assert DatasetService.dataset_use_check(dataset.id, session=db_session_with_containers) is False def test_check_dataset_permission_rejects_cross_tenant_access(self, db_session_with_containers: Session): owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) @@ -320,7 +320,9 @@ def test_check_dataset_operator_permission_rejects_only_me_for_non_creator( ) with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + DatasetService.check_dataset_operator_permission( + user=operator, dataset=dataset, session=db_session_with_containers + ) def test_check_dataset_operator_permission_rejects_partial_team_without_binding( self, db_session_with_containers: Session @@ -339,7 +341,9 @@ def test_check_dataset_operator_permission_rejects_partial_team_without_binding( ) with pytest.raises(NoPermissionError, match="do not have permission"): - DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + DatasetService.check_dataset_operator_permission( + user=operator, dataset=dataset, session=db_session_with_containers + ) def test_check_dataset_operator_permission_allows_partial_team_with_binding( self, db_session_with_containers: Session @@ -363,12 +367,16 @@ def test_check_dataset_operator_permission_allows_partial_team_with_binding( account_id=operator.id, ) - DatasetService.check_dataset_operator_permission(user=operator, dataset=dataset) + DatasetService.check_dataset_operator_permission( + user=operator, dataset=dataset, session=db_session_with_containers + ) - def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self, flask_app_with_containers: Flask): + def test_update_dataset_api_status_raises_not_found_for_missing_dataset( + self, flask_app_with_containers: Flask, db_session_with_containers: Session + ): with flask_app_with_containers.app_context(): with pytest.raises(NotFound, match="Dataset not found"): - DatasetService.update_dataset_api_status(str(uuid4()), True) + DatasetService.update_dataset_api_status(str(uuid4()), True, session=db_session_with_containers) def test_update_dataset_api_status_requires_current_user_id(self, db_session_with_containers: Session): owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) @@ -381,7 +389,7 @@ def test_update_dataset_api_status_requires_current_user_id(self, db_session_wit with patch("services.dataset_service.current_user", SimpleNamespace(id=None)): with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.update_dataset_api_status(dataset.id, True) + DatasetService.update_dataset_api_status(dataset.id, True, session=db_session_with_containers) def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_with_containers: Session): owner, tenant = DatasetPermissionIntegrationFactory.create_account_with_tenant(db_session_with_containers) @@ -397,7 +405,7 @@ def test_update_dataset_api_status_updates_fields_and_commits(self, db_session_w patch("services.dataset_service.current_user", owner), patch("services.dataset_service.naive_utc_now", return_value=now), ): - DatasetService.update_dataset_api_status(dataset.id, True) + DatasetService.update_dataset_api_status(dataset.id, True, session=db_session_with_containers) db_session_with_containers.refresh(dataset) assert dataset.enable_api is True @@ -416,7 +424,7 @@ def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled( patch("services.dataset_service.current_user", owner), patch("services.dataset_service.FeatureService.get_features", return_value=features), ): - result = DatasetService.get_dataset_auto_disable_logs(str(uuid4())) + result = DatasetService.get_dataset_auto_disable_logs(str(uuid4()), session=db_session_with_containers) assert result == {"document_ids": [], "count": 0} @@ -447,7 +455,7 @@ def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self, db_sess patch("services.dataset_service.current_user", owner), patch("services.dataset_service.FeatureService.get_features", return_value=features), ): - result = DatasetService.get_dataset_auto_disable_logs(dataset.id) + result = DatasetService.get_dataset_auto_disable_logs(dataset.id, session=db_session_with_containers) assert result["count"] == 2 assert len(result["document_ids"]) == 2 @@ -461,12 +469,16 @@ def test_get_dataset_collection_binding_returns_existing_binding(self, db_sessio model_name="model", ) - result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + result = DatasetCollectionBindingService.get_dataset_collection_binding( + "provider", "model", session=db_session_with_containers + ) assert result.id == binding.id def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_session_with_containers: Session): - result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "missing-model") + result = DatasetCollectionBindingService.get_dataset_collection_binding( + "provider", "missing-model", session=db_session_with_containers + ) persisted = db_session_with_containers.get(DatasetCollectionBinding, result.id) assert persisted is not None @@ -475,10 +487,14 @@ def test_get_dataset_collection_binding_creates_binding_when_missing(self, db_se assert persisted.type == "dataset" assert persisted.collection_name - def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self, flask_app_with_containers: Flask): + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing( + self, flask_app_with_containers: Flask, db_session_with_containers: Session + ): with flask_app_with_containers.app_context(): with pytest.raises(ValueError, match="Dataset collection binding not found"): - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(str(uuid4())) + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + str(uuid4()), session=db_session_with_containers + ) def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_session_with_containers: Session): binding = DatasetPermissionIntegrationFactory.create_collection_binding( @@ -487,7 +503,9 @@ def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self, db_ model_name="model", ) - result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id) + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + binding.id, session=db_session_with_containers + ) assert result.id == binding.id @@ -516,7 +534,9 @@ def test_get_dataset_partial_member_list_returns_scalar_results(self, db_session account_id=member_b.id, ) - result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + result = DatasetPermissionService.get_dataset_partial_member_list( + dataset.id, session=db_session_with_containers + ) assert set(result) == {member_a.id, member_b.id} @@ -542,33 +562,44 @@ def test_update_partial_member_list_replaces_permissions_and_commits(self, db_se tenant.id, dataset.id, [{"user_id": member_a.id}, {"user_id": member_b.id}], + session=db_session_with_containers, ) permissions = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() assert {permission.account_id for permission in permissions} == {member_a.id, member_b.id} - def test_check_permission_requires_dataset_editor(self): + def test_check_permission_requires_dataset_editor(self, db_session_with_containers: Session): user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) with pytest.raises(NoPermissionError, match="does not have permission"): - DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ALL_TEAM, []) + DatasetPermissionService.check_permission( + user, dataset, DatasetPermissionEnum.ALL_TEAM, [], session=db_session_with_containers + ) - def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode( + self, db_session_with_containers: Session + ): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.ALL_TEAM) with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): - DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.ONLY_ME, []) + DatasetPermissionService.check_permission( + user, dataset, DatasetPermissionEnum.ONLY_ME, [], session=db_session_with_containers + ) - def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + def test_check_permission_requires_partial_member_list_for_partial_members_mode( + self, db_session_with_containers: Session + ): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) with pytest.raises(ValueError, match="Partial member list is required"): - DatasetPermissionService.check_permission(user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, []) + DatasetPermissionService.check_permission( + user, dataset, DatasetPermissionEnum.PARTIAL_TEAM, [], session=db_session_with_containers + ) - def test_check_permission_rejects_dataset_operator_member_list_changes(self): + def test_check_permission_rejects_dataset_operator_member_list_changes(self, db_session_with_containers: Session): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) @@ -579,9 +610,12 @@ def test_check_permission_rejects_dataset_operator_member_list_changes(self): dataset, DatasetPermissionEnum.PARTIAL_TEAM, [{"user_id": "user-2"}], + session=db_session_with_containers, ) - def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged( + self, db_session_with_containers: Session + ): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = SimpleNamespace(id="dataset-1", permission=DatasetPermissionEnum.PARTIAL_TEAM) @@ -591,6 +625,7 @@ def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged( dataset, DatasetPermissionEnum.PARTIAL_TEAM, [{"user_id": "user-1"}], + session=db_session_with_containers, ) def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_session_with_containers: Session): @@ -609,7 +644,7 @@ def test_clear_partial_member_list_deletes_permissions_and_commits(self, db_sess account_id=member.id, ) - DatasetPermissionService.clear_partial_member_list(dataset.id) + DatasetPermissionService.clear_partial_member_list(dataset.id, session=db_session_with_containers) remaining = db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).all() assert remaining == [] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index 05632b1ec2a1af..b6768d2ca262af 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -548,7 +548,7 @@ def test_get_dataset_success(self, db_session_with_containers: Session): ) # Act - result = DatasetService.get_dataset(dataset.id) + result = DatasetService.get_dataset(dataset.id, session=db_session_with_containers) # Assert assert result is not None @@ -560,7 +560,7 @@ def test_get_dataset_not_found(self, db_session_with_containers: Session): dataset_id = str(uuid4()) # Act - result = DatasetService.get_dataset(dataset_id) + result = DatasetService.get_dataset(dataset_id, session=db_session_with_containers) # Assert assert result is None @@ -639,7 +639,7 @@ def test_get_process_rules_with_existing_rule(self, db_session_with_containers: ) # Act - result = DatasetService.get_process_rules(dataset.id) + result = DatasetService.get_process_rules(dataset.id, session=db_session_with_containers) # Assert assert result["mode"] == "custom" @@ -654,7 +654,7 @@ def test_get_process_rules_without_existing_rule(self, db_session_with_container ) # Act - result = DatasetService.get_process_rules(dataset.id) + result = DatasetService.get_process_rules(dataset.id, session=db_session_with_containers) # Assert assert result["mode"] == DocumentService.DEFAULT_RULES["mode"] @@ -724,7 +724,7 @@ def test_get_related_apps_success(self, db_session_with_containers: Session): DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id) # Act - result = DatasetService.get_related_apps(dataset.id) + result = DatasetService.get_related_apps(dataset.id, session=db_session_with_containers) # Assert assert len(result) == 2 @@ -739,7 +739,7 @@ def test_get_related_apps_empty_result(self, db_session_with_containers: Session ) # Act - result = DatasetService.get_related_apps(dataset.id) + result = DatasetService.get_related_apps(dataset.id, session=db_session_with_containers) # Assert assert result == [] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index ac0483a45d7f7c..d9fb23e8e33e75 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -189,7 +189,7 @@ def test_update_external_dataset_success(self, db_session_with_containers: Sessi "external_knowledge_api_id": external_api.id, } - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) db_session_with_containers.refresh(dataset) updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() @@ -221,7 +221,7 @@ def test_update_external_dataset_missing_knowledge_id_error(self, db_session_wit update_data = {"name": "new_name", "external_knowledge_api_id": str(uuid4())} with pytest.raises(ValueError) as context: - DatasetService.update_dataset(dataset.id, update_data, user) + DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) assert "External knowledge id is required" in str(context.value) db_session_with_containers.rollback() @@ -245,7 +245,7 @@ def test_update_external_dataset_missing_api_id_error(self, db_session_with_cont update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} with pytest.raises(ValueError) as context: - DatasetService.update_dataset(dataset.id, update_data, user) + DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) assert "External knowledge api id is required" in str(context.value) db_session_with_containers.rollback() @@ -272,7 +272,7 @@ def test_update_external_dataset_binding_not_found_error(self, db_session_with_c } with pytest.raises(ValueError) as context: - DatasetService.update_dataset(dataset.id, update_data, user) + DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) assert "External knowledge binding not found" in str(context.value) db_session_with_containers.rollback() @@ -303,7 +303,7 @@ def test_update_internal_dataset_basic_success(self, db_session_with_containers: "embedding_model": "text-embedding-ada-002", } - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" @@ -338,7 +338,7 @@ def test_update_internal_dataset_filter_none_values(self, db_session_with_contai "embedding_model": None, } - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" @@ -371,7 +371,7 @@ def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_ } with patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task: - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) mock_task.delay.assert_called_once_with(dataset.id, "remove") db_session_with_containers.refresh(dataset) @@ -418,7 +418,7 @@ def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_ses mock_model_manager.return_value.get_model_instance.return_value = embedding_model mock_get_binding.return_value = binding - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) mock_model_manager.return_value.get_model_instance.assert_called_once_with( tenant_id=tenant.id, @@ -426,7 +426,7 @@ def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_ses model_type=ModelType.TEXT_EMBEDDING, model="text-embedding-ada-002", ) - mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") + mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002", db_session_with_containers) mock_task.delay.assert_called_once_with(dataset.id, "add") db_session_with_containers.refresh(dataset) @@ -462,7 +462,7 @@ def test_update_internal_dataset_keep_existing_embedding_model_when_indexing_tec "retrieval_model": "new_model", } - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" @@ -514,7 +514,7 @@ def test_update_internal_dataset_embedding_model_update(self, db_session_with_co mock_model_manager.return_value.get_model_instance.return_value = embedding_model mock_get_binding.return_value = binding - result = DatasetService.update_dataset(dataset.id, update_data, user) + result = DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) mock_model_manager.return_value.get_model_instance.assert_called_once_with( tenant_id=tenant.id, @@ -522,7 +522,7 @@ def test_update_internal_dataset_embedding_model_update(self, db_session_with_co model_type=ModelType.TEXT_EMBEDDING, model="text-embedding-3-small", ) - mock_get_binding.assert_called_once_with("openai", "text-embedding-3-small") + mock_get_binding.assert_called_once_with("openai", "text-embedding-3-small", db_session_with_containers) mock_task.delay.assert_called_once_with(dataset.id, "update") mock_regenerate_task.delay.assert_called_once_with( dataset.id, @@ -545,7 +545,7 @@ def test_update_dataset_not_found_error(self, db_session_with_containers: Sessio update_data = {"name": "new_name"} with pytest.raises(ValueError) as context: - DatasetService.update_dataset(str(uuid4()), update_data, user) + DatasetService.update_dataset(str(uuid4()), update_data, user, session=db_session_with_containers) assert "Dataset not found" in str(context.value) @@ -568,7 +568,7 @@ def test_update_dataset_permission_error(self, db_session_with_containers: Sessi update_data = {"name": "new_name"} with pytest.raises(NoPermissionError): - DatasetService.update_dataset(dataset.id, update_data, outsider) + DatasetService.update_dataset(dataset.id, update_data, outsider, session=db_session_with_containers) def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session): """Test error when embedding model is not available.""" @@ -595,6 +595,6 @@ def test_update_internal_dataset_embedding_model_error(self, db_session_with_con mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") with pytest.raises(Exception) as context: - DatasetService.update_dataset(dataset.id, update_data, user) + DatasetService.update_dataset(dataset.id, update_data, user, session=db_session_with_containers) assert "No Embedding Model available".lower() in str(context.value).lower() diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py index 34532ed7f81f17..6a0ea17560a953 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -118,7 +118,7 @@ def test_rename_document_success(db_session_with_containers, mock_env): ) # Act - result = DocumentService.rename_document(dataset.id, document_id, new_name) + result = DocumentService.rename_document(dataset.id, document_id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -147,7 +147,7 @@ def test_rename_document_with_built_in_fields(db_session_with_containers, mock_e ) # Act - DocumentService.rename_document(dataset.id, document.id, new_name) + DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -179,7 +179,7 @@ def test_rename_document_updates_upload_file_when_present(db_session_with_contai ) # Act - DocumentService.rename_document(dataset.id, document.id, new_name) + DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -210,7 +210,7 @@ def test_rename_document_does_not_update_upload_file_when_missing_id(db_session_ ) # Act - DocumentService.rename_document(dataset.id, document.id, new_name) + DocumentService.rename_document(dataset.id, document.id, new_name, session=db_session_with_containers) # Assert db_session_with_containers.refresh(document) @@ -226,7 +226,7 @@ def test_rename_document_dataset_not_found(db_session_with_containers, mock_env) # Act / Assert with pytest.raises(ValueError, match="Dataset not found"): - DocumentService.rename_document(missing_dataset_id, str(uuid4()), "x") + DocumentService.rename_document(missing_dataset_id, str(uuid4()), "x", session=db_session_with_containers) def test_rename_document_not_found(db_session_with_containers, mock_env): @@ -236,7 +236,7 @@ def test_rename_document_not_found(db_session_with_containers, mock_env): # Act / Assert with pytest.raises(ValueError, match="Document not found"): - DocumentService.rename_document(dataset.id, str(uuid4()), "x") + DocumentService.rename_document(dataset.id, str(uuid4()), "x", session=db_session_with_containers) def test_rename_document_permission_denied_when_tenant_mismatch(db_session_with_containers, mock_env): @@ -251,4 +251,4 @@ def test_rename_document_permission_denied_when_tenant_mismatch(db_session_with_ # Act / Assert with pytest.raises(ValueError, match="No permission"): - DocumentService.rename_document(dataset.id, document.id, "x") + DocumentService.rename_document(dataset.id, document.id, "x", session=db_session_with_containers) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py index 0b4ce39bafbef1..e8bb6f88674c92 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -1,5 +1,5 @@ import inspect -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest from flask import Flask @@ -742,7 +742,7 @@ def test_retry_archived_document_skipped(self, app: Flask, patch_tenant, patch_d resp, status = method(api, "ds-1") assert status == 204 - retry_mock.assert_called_once_with("ds-1", []) + retry_mock.assert_called_once_with("ds-1", [], ANY) def test_retry_success(self, app: Flask, patch_tenant, patch_dataset): api = DocumentRetryApi() @@ -771,7 +771,7 @@ def test_retry_success(self, app: Flask, patch_tenant, patch_dataset): response, status = method(api, "ds-1") assert status == 204 - retry_mock.assert_called_once_with("ds-1", [document]) + retry_mock.assert_called_once_with("ds-1", [document], ANY) def test_retry_skips_completed_document(self, app: Flask, patch_tenant, patch_dataset): api = DocumentRetryApi() @@ -796,7 +796,7 @@ def test_retry_skips_completed_document(self, app: Flask, patch_tenant, patch_da response, status = method(api, "ds-1") assert status == 204 - retry_mock.assert_called_once_with("ds-1", []) + retry_mock.assert_called_once_with("ds-1", [], ANY) class TestDocumentPipelineExecutionLogApi: diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py index 3b9a1dcc5ea0ac..6288fe363f559c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py @@ -107,7 +107,7 @@ def _wire_common_success_mocks( import services.dataset_service as dataset_service_module # Return a dataset object and allow permission checks to pass. - monkeypatch.setattr(module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1")) + monkeypatch.setattr(module.DatasetService, "get_dataset", lambda *_args, **_kwargs: SimpleNamespace(id="ds-1")) monkeypatch.setattr(module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None) # Return a document that will be validated inside DocumentResource.get_document. @@ -150,7 +150,7 @@ def test_batch_download_zip_returns_send_file( """Ensure batch ZIP download returns a zip attachment via `send_file`.""" monkeypatch.setattr( - datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1") + datasets_document_module.DatasetService, "get_dataset", lambda *_args, **_kwargs: SimpleNamespace(id="ds-1") ) monkeypatch.setattr( datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None @@ -218,7 +218,7 @@ def test_batch_download_zip_response_is_openable_zip( # Arrange: same controller mocks as the lightweight send_file test, but we keep the real `send_file`. monkeypatch.setattr( - datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1") + datasets_document_module.DatasetService, "get_dataset", lambda *_args, **_kwargs: SimpleNamespace(id="ds-1") ) monkeypatch.setattr( datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None @@ -284,7 +284,7 @@ def test_batch_download_zip_rejects_non_upload_file_document( """Ensure batch ZIP download rejects non upload-file documents.""" monkeypatch.setattr( - datasets_document_module.DatasetService, "get_dataset", lambda _dataset_id: SimpleNamespace(id="ds-1") + datasets_document_module.DatasetService, "get_dataset", lambda *_args, **_kwargs: SimpleNamespace(id="ds-1") ) monkeypatch.setattr( datasets_document_module.DatasetService, "check_dataset_permission", lambda *_args, **_kwargs: None diff --git a/api/tests/unit_tests/controllers/console/test_spec.py b/api/tests/unit_tests/controllers/console/test_spec.py index ed02923caf6f63..58d7027751b978 100644 --- a/api/tests/unit_tests/controllers/console/test_spec.py +++ b/api/tests/unit_tests/controllers/console/test_spec.py @@ -1,6 +1,8 @@ from inspect import unwrap from unittest.mock import patch +import pytest + import controllers.console.spec as spec_module @@ -22,7 +24,7 @@ def test_get_success(self): assert status == 200 assert resp == schema_definitions - def test_get_exception_returns_empty_list(self, caplog): + def test_get_exception_returns_empty_list(self, caplog: pytest.LogCaptureFixture): api = spec_module.SpecSchemaDefinitionsApi() method = unwrap(api.get) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 5eb76e309c1664..9170f38df2af87 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -356,9 +356,13 @@ def test_create_segments_returns_list(self, mock_create, mock_dataset, mock_docu """Test segment creation returns list of segments.""" mock_segments = [Mock(spec=DocumentSegment), Mock(spec=DocumentSegment)] mock_create.return_value = mock_segments + session = Mock() result = SegmentService.multi_create_segment( - segments=[{"content": "Test"}, {"content": "Test 2"}], document=mock_document, dataset=mock_dataset + segments=[{"content": "Test"}, {"content": "Test 2"}], + document=mock_document, + dataset=mock_dataset, + session=session, ) assert result is not None @@ -385,8 +389,13 @@ def test_get_segments_returns_tuple(self, mock_get, mock_document): def test_get_segment_by_id_returns_segment(self, mock_get, mock_segment): """Test get_segment_by_id returns segment.""" mock_get.return_value = mock_segment + session = Mock() - result = SegmentService.get_segment_by_id(segment_id=mock_segment.id, tenant_id=mock_segment.tenant_id) + result = SegmentService.get_segment_by_id( + segment_id=mock_segment.id, + tenant_id=mock_segment.tenant_id, + session=session, + ) assert result == mock_segment @@ -394,16 +403,22 @@ def test_get_segment_by_id_returns_segment(self, mock_get, mock_segment): def test_get_segment_by_id_returns_none_when_not_found(self, mock_get): """Test get_segment_by_id returns None when not found.""" mock_get.return_value = None + session = Mock() - result = SegmentService.get_segment_by_id(segment_id=str(uuid.uuid4()), tenant_id=str(uuid.uuid4())) + result = SegmentService.get_segment_by_id( + segment_id=str(uuid.uuid4()), + tenant_id=str(uuid.uuid4()), + session=session, + ) assert result is None @patch.object(SegmentService, "delete_segment") def test_delete_segment_called(self, mock_delete, mock_segment, mock_document, mock_dataset): """Test segment deletion is called.""" - SegmentService.delete_segment(mock_segment, mock_document, mock_dataset) - mock_delete.assert_called_once_with(mock_segment, mock_document, mock_dataset) + session = Mock() + SegmentService.delete_segment(mock_segment, mock_document, mock_dataset, session) + mock_delete.assert_called_once_with(mock_segment, mock_document, mock_dataset, session) class TestChildChunkServiceMockedBehavior: @@ -431,7 +446,11 @@ def test_create_child_chunk_returns_chunk(self, mock_create, mock_segment, mock_ mock_create.return_value = mock_child_chunk result = SegmentService.create_child_chunk( - content="New chunk content", segment=mock_segment, document=Mock(spec=Document), dataset=Mock(spec=Dataset) + content="New chunk content", + segment=mock_segment, + document=Mock(spec=Document), + dataset=Mock(spec=Dataset), + session=Mock(), ) assert result == mock_child_chunk @@ -462,7 +481,9 @@ def test_get_child_chunk_by_id_returns_chunk(self, mock_get, mock_child_chunk): mock_get.return_value = mock_child_chunk result = SegmentService.get_child_chunk_by_id( - child_chunk_id=mock_child_chunk.id, tenant_id=mock_child_chunk.tenant_id + child_chunk_id=mock_child_chunk.id, + tenant_id=mock_child_chunk.tenant_id, + session=Mock(), ) assert result == mock_child_chunk @@ -480,6 +501,7 @@ def test_update_child_chunk_returns_updated_chunk(self, mock_update, mock_child_ segment=Mock(spec=DocumentSegment), document=Mock(spec=Document), dataset=Mock(spec=Dataset), + session=Mock(), ) assert result.content == "Updated content" @@ -1156,7 +1178,7 @@ def test_delete_segment_success( # Assert assert response == ("", 204) - mock_seg_svc.delete_segment.assert_called_once_with(mock_segment, mock_doc, mock_dataset) + mock_seg_svc.delete_segment.assert_called_once_with(mock_segment, mock_doc, mock_dataset, mock_db.session) @patch("controllers.service_api.dataset.segment.SegmentService") @patch("controllers.service_api.dataset.segment.DocumentService") diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 16b54acd8c652c..e68eb6470632a9 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -269,7 +269,7 @@ def test_get_document_returns_document(self, mock_get: Mock) -> None: mock_doc.indexing_status = "completed" mock_get.return_value = mock_doc - result = DocumentService.get_document(dataset_id="dataset_id", document_id="doc_id") + result = DocumentService.get_document(dataset_id="dataset_id", document_id="doc_id", session=Mock()) assert result is not None assert result.name == "Test Document" assert result.indexing_status == "completed" @@ -278,8 +278,9 @@ def test_get_document_returns_document(self, mock_get: Mock) -> None: def test_delete_document_called(self, mock_delete): """Test delete_document is called with document.""" mock_doc = Mock() - DocumentService.delete_document(document=mock_doc) - mock_delete.assert_called_once_with(document=mock_doc) + session = Mock() + DocumentService.delete_document(document=mock_doc, session=session) + mock_delete.assert_called_once_with(document=mock_doc, session=session) class TestDocumentIndexingStatus: @@ -454,24 +455,24 @@ def test_build_display_status_filters(self): class TestDocumentServiceBatchMethods: """Test DocumentService batch operations.""" - @patch("services.dataset_service.db.session.scalars") - def test_get_documents_by_ids(self, mock_scalars): + def test_get_documents_by_ids(self): """Test batch retrieval of documents by IDs.""" dataset_id = str(uuid.uuid4()) doc_ids = [str(uuid.uuid4()), str(uuid.uuid4())] mock_result = Mock() mock_result.all.return_value = [Mock(id=doc_ids[0]), Mock(id=doc_ids[1])] - mock_scalars.return_value = mock_result + session = Mock() + session.scalars.return_value = mock_result - documents = DocumentService.get_documents_by_ids(dataset_id, doc_ids) + documents = DocumentService.get_documents_by_ids(dataset_id, doc_ids, session) assert len(documents) == 2 - mock_scalars.assert_called_once() + session.scalars.assert_called_once() def test_get_documents_by_ids_empty(self): """Test batch retrieval with empty list returns empty.""" - assert DocumentService.get_documents_by_ids("ds_id", []) == [] + assert DocumentService.get_documents_by_ids("ds_id", [], Mock()) == [] class TestDocumentServiceFileOperations: @@ -487,7 +488,7 @@ def test_get_document_download_url(self, mock_get_file, mock_signed_url): mock_get_file.return_value = mock_file mock_signed_url.return_value = "https://example.com/download" - url = DocumentService.get_document_download_url(mock_doc) + url = DocumentService.get_document_download_url(mock_doc, Mock()) assert url == "https://example.com/download" mock_signed_url.assert_called_with(upload_file_id="file_id", as_attachment=True) @@ -516,7 +517,7 @@ class TestStopError(Exception): # Skip actual logic by mocking dependent calls or raising error to stop early with pytest.raises(TestStopError): # We just want to check check_doc_form is called early - DocumentService.save_document_with_dataset_id(dataset, config, Mock()) + DocumentService.save_document_with_dataset_id(dataset, config, Mock(), session=Mock()) # This will fail if we raise exception before check_doc_form, # but check_doc_form is the first thing called. @@ -782,7 +783,7 @@ def test_delete_document_success(self, mock_db, mock_doc_svc, app: Flask, mock_t # Assert assert response == ("", 204) - mock_doc_svc.delete_document.assert_called_once_with(mock_document) + mock_doc_svc.delete_document.assert_called_once_with(mock_document, mock_db.session) @patch("controllers.service_api.dataset.document.DocumentService") @patch("controllers.service_api.dataset.document.db") diff --git a/api/tests/unit_tests/extensions/test_ext_request_logging.py b/api/tests/unit_tests/extensions/test_ext_request_logging.py index 70e807078824fc..664de8cbd8b42d 100644 --- a/api/tests/unit_tests/extensions/test_ext_request_logging.py +++ b/api/tests/unit_tests/extensions/test_ext_request_logging.py @@ -1,6 +1,7 @@ import json import logging from unittest import mock +from unittest.mock import MagicMock import pytest from flask import Flask, Response @@ -73,8 +74,8 @@ class TestRequestLoggingExtension: def test_receiver_should_not_be_invoked_if_configuration_is_disabled( self, monkeypatch: pytest.MonkeyPatch, - mock_request_receiver, - mock_response_receiver, + mock_request_receiver: MagicMock, + mock_response_receiver: MagicMock, ): monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", False) @@ -90,8 +91,8 @@ def test_receiver_should_not_be_invoked_if_configuration_is_disabled( def test_receiver_should_be_called_if_enabled( self, enable_request_logging, - mock_request_receiver, - mock_response_receiver, + mock_request_receiver: MagicMock, + mock_response_receiver: MagicMock, ): """ Test the request logging extension with JSON data. diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index 044e0e5ab4016d..6560e5a1c0bf3d 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -18,7 +18,6 @@ TenantAccountRole, _make_knowledge_configuration, _make_retrieval_model, - _make_session_context, json, patch, pytest, @@ -345,7 +344,9 @@ def test_create_empty_dataset_raises_when_name_already_exists(self): mock_db.session.scalar.return_value = object() with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"): - DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account) + DatasetService.create_empty_dataset( + "tenant-1", "Dataset", None, "economy", account, session=mock_db.session + ) def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self): account = SimpleNamespace(id="user-1") @@ -370,6 +371,7 @@ def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_data description="Description", indexing_technique="high_quality", account=account, + session=mock_db.session, ) assert dataset.embedding_model_provider == "provider" @@ -421,6 +423,7 @@ def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset( embedding_model_name="embedding-model", retrieval_model=retrieval_model, summary_index_setting={"enable": True}, + session=mock_db.session, ) assert dataset.embedding_model_provider == "provider" @@ -451,7 +454,7 @@ def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self): mock_db.session.scalar.return_value = object() with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"): - DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity, mock_db.session) def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self): entity = RagPipelineDatasetCreateEntity( @@ -482,7 +485,7 @@ def dataset_factory(**kwargs): SimpleNamespace(name="Untitled 1"), ] - dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity, mock_db.session) assert entity.name == "Untitled 2" assert dataset.pipeline_id == "pipeline-1" @@ -505,12 +508,13 @@ def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self): mock_db.session.scalar.return_value = None with pytest.raises(ValueError, match="Current user or current user id not found"): - DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity, mock_db.session) def test_update_dataset_raises_when_dataset_is_missing(self): + session = MagicMock() with patch.object(DatasetService, "get_dataset", return_value=None): with pytest.raises(ValueError, match="Dataset not found"): - DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1")) + DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1"), session) def test_update_dataset_raises_when_new_name_conflicts(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") @@ -521,7 +525,12 @@ def test_update_dataset_raises_when_new_name_conflicts(self): patch.object(DatasetService, "_has_dataset_same_name", return_value=True), ): with pytest.raises(ValueError, match="Dataset name already exists"): - DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, SimpleNamespace(id="user-1")) + DatasetService.update_dataset( + "dataset-1", + {"name": "New Dataset"}, + SimpleNamespace(id="user-1"), + MagicMock(), + ) def test_update_dataset_routes_external_datasets_to_external_helper(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") @@ -533,13 +542,14 @@ def test_update_dataset_routes_external_datasets_to_external_helper(self): patch.object(DatasetService, "check_dataset_permission") as check_permission, patch.object(DatasetService, "_update_external_dataset", return_value="updated") as update_external, ): - result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + session = MagicMock() + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user, session) assert result == "updated" check_permission.assert_called_once() assert check_permission.call_args.args[:2] == (dataset, user) assert len(check_permission.call_args.args) == 3 - update_external.assert_called_once_with(dataset, {"name": dataset.name}, user) + update_external.assert_called_once_with(dataset, {"name": dataset.name}, user, session) def test_update_dataset_routes_internal_datasets_to_internal_helper(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") @@ -551,19 +561,20 @@ def test_update_dataset_routes_internal_datasets_to_internal_helper(self): patch.object(DatasetService, "check_dataset_permission") as check_permission, patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as update_internal, ): - result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + session = MagicMock() + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user, session) assert result == "updated" check_permission.assert_called_once() assert check_permission.call_args.args[:2] == (dataset, user) assert len(check_permission.call_args.args) == 3 - update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user) + update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user, session) def test_has_dataset_same_name_returns_true_when_query_matches(self): with patch("services.dataset_service.db") as mock_db: mock_db.session.scalar.return_value = object() - result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset") + result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset", mock_db.session) assert result is True @@ -592,6 +603,7 @@ def test_update_external_dataset_updates_dataset_and_binding(self): "external_knowledge_api_id": "api-1", }, user, + mock_db.session, ) assert result is dataset @@ -603,7 +615,7 @@ def test_update_external_dataset_updates_dataset_and_binding(self): assert dataset.updated_by == "user-1" assert dataset.updated_at is now get_external_knowledge_api.assert_called_once_with("api-1", dataset.tenant_id) - update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1") + update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1", mock_db.session) mock_db.session.add.assert_called_once_with(dataset) mock_db.session.commit.assert_called_once() @@ -618,7 +630,7 @@ def test_update_external_dataset_requires_external_binding_fields(self, payload, dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") with pytest.raises(ValueError, match=message): - DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1")) + DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1"), MagicMock()) def test_update_external_dataset_rejects_cross_tenant_external_api_id(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") @@ -639,6 +651,7 @@ def test_update_external_dataset_rejects_cross_tenant_external_api_id(self): "external_knowledge_api_id": "foreign-api", }, SimpleNamespace(id="user-1"), + mock_db.session, ) get_external_knowledge_api.assert_called_once_with("foreign-api", dataset.tenant_id) @@ -650,16 +663,7 @@ def test_update_external_knowledge_binding_updates_changed_binding_values(self): session = MagicMock() session.scalar.return_value = binding session.add = MagicMock() - session_context = _make_session_context(session) - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value = session_context - - with ( - patch("services.dataset_service.db") as mock_db, - patch("services.dataset_service.sessionmaker", mock_sessionmaker), - ): - DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api") + DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api", session) assert binding.external_knowledge_id == "new-knowledge" assert binding.external_knowledge_api_id == "new-api" @@ -668,17 +672,8 @@ def test_update_external_knowledge_binding_updates_changed_binding_values(self): def test_update_external_knowledge_binding_raises_for_missing_binding(self): session = MagicMock() session.scalar.return_value = None - session_context = _make_session_context(session) - - mock_sessionmaker = MagicMock() - mock_sessionmaker.return_value.begin.return_value = session_context - - with ( - patch("services.dataset_service.db"), - patch("services.dataset_service.sessionmaker", mock_sessionmaker), - ): - with pytest.raises(ValueError, match="External knowledge binding not found"): - DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1") + with pytest.raises(ValueError, match="External knowledge binding not found"): + DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1", session) def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") @@ -704,7 +699,7 @@ def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_task patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task, patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task, ): - result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user) + result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user, mock_db.session) assert result is dataset updated_values = mock_db.session.execute.call_args.args[0].compile().params @@ -721,7 +716,7 @@ def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_task assert "external_retrieval_model" not in updated_values mock_db.session.commit.assert_called_once() mock_db.session.refresh.assert_called_once_with(dataset) - update_pipeline.assert_called_once_with(dataset, "user-1") + update_pipeline.assert_called_once_with(dataset, "user-1", mock_db.session) vector_task.delay.assert_called_once_with("dataset-1", "update") regenerate_task.delay.assert_called_once_with( "dataset-1", @@ -733,7 +728,7 @@ def test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1") with patch("services.dataset_service.db") as mock_db: - DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1", mock_db.session) mock_db.session.get.assert_not_called() @@ -743,7 +738,7 @@ def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missi with patch("services.dataset_service.db") as mock_db: mock_db.session.get.return_value = None - DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1", mock_db.session) mock_db.session.commit.assert_not_called() @@ -782,7 +777,7 @@ def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_wo ): mock_db.session.get.return_value = pipeline - DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1", mock_db.session) published_graph = json.loads(workflow_new.call_args.kwargs["graph"]) assert published_graph["nodes"][0]["data"]["embedding_model"] == "embedding-model" @@ -805,15 +800,16 @@ def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(s mock_db.session.get.return_value = pipeline with pytest.raises(RuntimeError, match="boom"): - DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1", mock_db.session) mock_db.session.rollback.assert_called_once() def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self): filtered_data: dict[str, object] = {} dataset = SimpleNamespace(indexing_technique="economy") + session = MagicMock() - result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data) + result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data, session) assert result is None assert filtered_data == {} @@ -821,11 +817,13 @@ def test_handle_indexing_technique_change_returns_none_without_indexing_techniqu def test_handle_indexing_technique_change_switches_to_economy(self): filtered_data: dict[str, object] = {} dataset = SimpleNamespace(indexing_technique="high_quality") + session = MagicMock() result = DatasetService._handle_indexing_technique_change( dataset, {"indexing_technique": "economy"}, filtered_data, + session, ) assert result == "remove" @@ -838,20 +836,23 @@ def test_handle_indexing_technique_change_switches_to_economy(self): def test_handle_indexing_technique_change_switches_to_high_quality(self): filtered_data: dict[str, object] = {} dataset = SimpleNamespace(indexing_technique="economy") + session = MagicMock() with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding: result = DatasetService._handle_indexing_technique_change( dataset, {"indexing_technique": "high_quality"}, filtered_data, + session, ) assert result == "add" - configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data) + configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data, session) def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self): filtered_data: dict[str, object] = {} dataset = SimpleNamespace(indexing_technique="high_quality") + session = MagicMock() with patch.object( DatasetService, @@ -862,10 +863,16 @@ def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged( dataset, {"indexing_technique": "high_quality"}, filtered_data, + session, ) assert result == "update" - update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data) + update_embedding.assert_called_once_with( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + session, + ) def test_configure_embedding_model_for_high_quality_updates_filtered_data(self): class FakeAccount: @@ -875,6 +882,7 @@ class FakeAccount: current_user.current_tenant_id = "tenant-1" embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") filtered_data: dict[str, object] = {} + session = MagicMock() with ( patch("services.dataset_service.Account", FakeAccount), @@ -890,6 +898,7 @@ class FakeAccount: DatasetService._configure_embedding_model_for_high_quality( {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, filtered_data, + session, ) assert filtered_data == { @@ -911,6 +920,7 @@ class FakeAccount: current_user = FakeAccount() current_user.current_tenant_id = "tenant-1" + session = MagicMock() with ( patch("services.dataset_service.Account", FakeAccount), @@ -923,6 +933,7 @@ class FakeAccount: DatasetService._configure_embedding_model_for_high_quality( {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, {}, + session, ) def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self): @@ -931,12 +942,14 @@ def test_handle_embedding_model_update_when_technique_unchanged_preserves_existi embedding_model="embedding-model", ) filtered_data: dict[str, object] = {} + session = MagicMock() with patch.object(DatasetService, "_preserve_existing_embedding_settings") as preserve_settings: result = DatasetService._handle_embedding_model_update_when_technique_unchanged( dataset, {}, filtered_data, + session, ) assert result is None @@ -947,16 +960,23 @@ def test_handle_embedding_model_update_when_technique_unchanged_updates_when_mod embedding_model_provider="provider", embedding_model="embedding-model", ) + session = MagicMock() with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings: result = DatasetService._handle_embedding_model_update_when_technique_unchanged( dataset, {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, {}, + session, ) assert result == "update" - update_settings.assert_called_once() + update_settings.assert_called_once_with( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + session, + ) def test_preserve_existing_embedding_settings_keeps_current_binding(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock( @@ -991,27 +1011,36 @@ def test_update_embedding_model_settings_returns_update_for_changed_values(self) embedding_model_provider="provider", embedding_model="embedding-model", ) + session = MagicMock() with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings: result = DatasetService._update_embedding_model_settings( dataset, {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, {}, + session, ) assert result == "update" - apply_settings.assert_called_once() + apply_settings.assert_called_once_with( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + session, + ) def test_update_embedding_model_settings_returns_none_for_unchanged_values(self): dataset = DatasetServiceUnitDataFactory.create_dataset_mock( embedding_model_provider="provider", embedding_model="embedding-model", ) + session = MagicMock() result = DatasetService._update_embedding_model_settings( dataset, {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, {}, + session, ) assert result is None @@ -1021,6 +1050,7 @@ def test_update_embedding_model_settings_wraps_bad_request_errors(self): embedding_model_provider="provider", embedding_model="embedding-model", ) + session = MagicMock() with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()): with pytest.raises(ValueError, match="No Embedding Model available"): @@ -1028,6 +1058,7 @@ def test_update_embedding_model_settings_wraps_bad_request_errors(self): dataset, {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, {}, + session, ) def test_apply_new_embedding_settings_updates_binding_for_new_model(self): @@ -1038,6 +1069,7 @@ class FakeAccount: current_user.current_tenant_id = "tenant-1" dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1") filtered_data: dict[str, object] = {} + session = MagicMock() with ( patch("services.dataset_service.Account", FakeAccount), @@ -1057,6 +1089,7 @@ class FakeAccount: dataset, {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, filtered_data, + session, ) assert filtered_data == { @@ -1077,6 +1110,7 @@ class FakeAccount: collection_binding_id="binding-1", ) filtered_data: dict[str, object] = {} + session = MagicMock() with ( patch("services.dataset_service.Account", FakeAccount), @@ -1091,6 +1125,7 @@ class FakeAccount: dataset, {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, filtered_data, + session, ) assert filtered_data == { @@ -1380,11 +1415,21 @@ class TestDatasetServicePermissionsAndLifecycle: """Unit tests for dataset permissions, deletion, and metadata helpers.""" def test_check_dataset_operator_permission_validates_required_arguments(self): + session = MagicMock() + with pytest.raises(ValueError, match="Dataset not found"): - DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None) + DatasetService.check_dataset_operator_permission( + user=SimpleNamespace(id="user-1"), + dataset=None, + session=session, + ) with pytest.raises(ValueError, match="User not found"): - DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1")) + DatasetService.check_dataset_operator_permission( + user=None, + dataset=SimpleNamespace(id="dataset-1"), + session=session, + ) class TestDatasetCollectionBindingService: @@ -1395,44 +1440,49 @@ class TestDatasetPermissionService: """Unit tests for dataset partial-member management helpers.""" def test_update_partial_member_list_rolls_back_on_exception(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.session.add_all.side_effect = RuntimeError("boom") + session = MagicMock() + session.add_all.side_effect = RuntimeError("boom") - with pytest.raises(RuntimeError, match="boom"): - DatasetPermissionService.update_partial_member_list( - "tenant-1", - "dataset-1", - [{"user_id": "user-1"}], - ) + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}], + session, + ) - mock_db.session.rollback.assert_called_once() + session.rollback.assert_called_once() def test_check_permission_requires_dataset_editor(self): user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + session = MagicMock() with pytest.raises(NoPermissionError, match="does not have permission"): - DatasetPermissionService.check_permission(user, dataset, "all_team", []) + DatasetPermissionService.check_permission(user, dataset, "all_team", [], session) def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team") + session = MagicMock() with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): - DatasetPermissionService.check_permission(user, dataset, "only_me", []) + DatasetPermissionService.check_permission(user, dataset, "only_me", [], session) def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members") + session = MagicMock() with pytest.raises(ValueError, match="Partial member list is required"): - DatasetPermissionService.check_permission(user, dataset, "partial_members", []) + DatasetPermissionService.check_permission(user, dataset, "partial_members", [], session) def test_check_permission_rejects_dataset_operator_member_list_changes(self): user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) dataset = DatasetServiceUnitDataFactory.create_dataset_mock( dataset_id="dataset-1", permission="partial_members" ) + session = MagicMock() with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): with pytest.raises(ValueError, match="cannot change the dataset permissions"): @@ -1441,6 +1491,7 @@ def test_check_permission_rejects_dataset_operator_member_list_changes(self): dataset, "partial_members", [{"user_id": "user-2"}], + session, ) def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): @@ -1448,6 +1499,7 @@ def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged( dataset = DatasetServiceUnitDataFactory.create_dataset_mock( dataset_id="dataset-1", permission="partial_members" ) + session = MagicMock() with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): DatasetPermissionService.check_permission( @@ -1455,13 +1507,14 @@ def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged( dataset, "partial_members", [{"user_id": "user-1"}], + session, ) def test_clear_partial_member_list_rolls_back_on_exception(self): - with patch("services.dataset_service.db") as mock_db: - mock_db.session.execute.side_effect = RuntimeError("boom") + session = MagicMock() + session.execute.side_effect = RuntimeError("boom") - with pytest.raises(RuntimeError, match="boom"): - DatasetPermissionService.clear_partial_member_list("dataset-1") + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.clear_partial_member_list("dataset-1", session) - mock_db.session.rollback.assert_called_once() + session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index 9a8243936b32e8..c108b06ac6bb4c 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -104,30 +104,34 @@ def test_check_archived_returns_boolean_status(self, archived, expected): assert DocumentService.check_archived(document) is expected def test_rename_document_raises_when_dataset_is_missing(self, rename_account_context): + session = MagicMock() + with patch.object(DatasetService, "get_dataset", return_value=None): with pytest.raises(ValueError, match="Dataset not found"): - DocumentService.rename_document("dataset-1", "doc-1", "New Name") + DocumentService.rename_document("dataset-1", "doc-1", "New Name", session) def test_rename_document_raises_when_document_is_missing(self, rename_account_context): dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + session = MagicMock() with ( patch.object(DatasetService, "get_dataset", return_value=dataset), patch.object(DocumentService, "get_document", return_value=None), ): with pytest.raises(ValueError, match="Document not found"): - DocumentService.rename_document(dataset.id, "doc-1", "New Name") + DocumentService.rename_document(dataset.id, "doc-1", "New Name", session) def test_rename_document_rejects_cross_tenant_access(self, rename_account_context): dataset = DatasetServiceUnitDataFactory.create_dataset_mock() document = DatasetServiceUnitDataFactory.create_document_mock(tenant_id="tenant-other") + session = MagicMock() with ( patch.object(DatasetService, "get_dataset", return_value=dataset), patch.object(DocumentService, "get_document", return_value=document), ): with pytest.raises(ValueError, match="No permission"): - DocumentService.rename_document(dataset.id, document.id, "New Name") + DocumentService.rename_document(dataset.id, document.id, "New Name", session) def test_rename_document_updates_document_metadata_and_upload_file_name(self, rename_account_context): dataset = DatasetServiceUnitDataFactory.create_dataset_mock( @@ -146,7 +150,7 @@ def test_rename_document_updates_document_metadata_and_upload_file_name(self, re patch.object(DocumentService, "get_document", return_value=document), patch("services.dataset_service.db") as mock_db, ): - result = DocumentService.rename_document(dataset.id, document.id, "New Name") + result = DocumentService.rename_document(dataset.id, document.id, "New Name", mock_db.session) assert result is document assert document.name == "New Name" @@ -157,27 +161,30 @@ def test_rename_document_updates_document_metadata_and_upload_file_name(self, re def test_recover_document_raises_when_document_is_not_paused(self): document = DatasetServiceUnitDataFactory.create_document_mock(is_paused=False) + session = MagicMock() with pytest.raises(DocumentIndexingError): - DocumentService.recover_document(document) + DocumentService.recover_document(document, session) def test_retry_document_raises_when_retry_flag_is_already_set(self): document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + session = MagicMock() with patch("services.dataset_service.redis_client") as mock_redis: mock_redis.get.return_value = "1" with pytest.raises(ValueError, match="being retried"): - DocumentService.retry_document("dataset-1", [document]) + DocumentService.retry_document("dataset-1", [document], session) def test_sync_website_document_raises_when_sync_flag_exists(self): document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + session = MagicMock() with patch("services.dataset_service.redis_client") as mock_redis: mock_redis.get.return_value = "1" with pytest.raises(ValueError, match="being synced"): - DocumentService.sync_website_document("dataset-1", document) + DocumentService.sync_website_document("dataset-1", document, session) def test_sync_website_document_updates_status_sets_cache_and_dispatches_task(self): document = DatasetServiceUnitDataFactory.create_document_mock( @@ -193,7 +200,7 @@ def test_sync_website_document_updates_status_sets_cache_and_dispatches_task(sel ): mock_redis.get.return_value = None - DocumentService.sync_website_document("dataset-1", document) + DocumentService.sync_website_document("dataset-1", document, mock_db.session) assert document.indexing_status == "waiting" assert '"mode": "scrape"' in document.data_source_info @@ -258,6 +265,7 @@ def test_save_document_without_dataset_id_creates_high_quality_dataset_with_defa tenant_id="tenant-1", knowledge_config=knowledge_config, account=account_context, + session=mock_db.session, ) assert dataset is created_dataset @@ -274,7 +282,12 @@ def test_save_document_without_dataset_id_creates_high_quality_dataset_with_defa == "useful for when you want to answer queries about the VeryLongDocumentNameForDataset.txt" ) dataset_cls.assert_called_once() - save_document.assert_called_once_with(created_dataset, knowledge_config, account_context) + save_document.assert_called_once_with( + created_dataset, + knowledge_config, + account_context, + session=mock_db.session, + ) assert mock_db.session.commit.call_count == 1 def test_save_document_without_dataset_id_uses_provided_retrieval_model(self, account_context): @@ -312,9 +325,14 @@ def test_save_document_without_dataset_id_uses_provided_retrieval_model(self, ac "save_document_with_dataset_id", return_value=([SimpleNamespace(name="Doc")], "batch-1"), ), - patch("services.dataset_service.db"), + patch("services.dataset_service.db") as mock_db, ): - DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + DocumentService.save_document_without_dataset_id( + "tenant-1", + knowledge_config, + account_context, + mock_db.session, + ) assert created_dataset.retrieval_model == retrieval_model.model_dump() assert created_dataset.collection_binding_id is None @@ -337,8 +355,9 @@ def test_save_document_without_dataset_id_rejects_sandbox_batch_upload(self, acc ), patch.object(DocumentService, "check_documents_upload_quota") as check_quota, ): + session = MagicMock() with pytest.raises(ValueError, match="does not support batch upload"): - DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context, session) check_quota.assert_not_called() @@ -367,13 +386,19 @@ def test_update_document_with_dataset_id_raises_when_document_is_missing(self, a ) ), ) + session = MagicMock() with ( patch.object(DocumentService, "get_document", return_value=None), patch.object(DatasetService, "check_dataset_model_setting") as check_model_setting, ): with pytest.raises(NotFound, match="Document not found"): - DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=session, + ) check_model_setting.assert_called_once_with(dataset) @@ -390,13 +415,19 @@ def test_update_document_with_dataset_id_rejects_non_available_documents(self, a ) ), ) + session = MagicMock() with ( patch.object(DocumentService, "get_document", return_value=document), patch.object(DatasetService, "check_dataset_model_setting"), ): with pytest.raises(ValueError, match="Document is not available"): - DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=session, + ) def test_update_document_with_dataset_id_upload_file_process_rule_and_name_override(self, account_context): dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") @@ -433,7 +464,12 @@ def test_update_document_with_dataset_id_upload_file_process_rule_and_name_overr ): mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt") - result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + result = DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=mock_db.session, + ) assert result is document assert document.dataset_process_rule_id == "rule-2" @@ -481,7 +517,12 @@ def test_update_document_with_dataset_id_notion_import_requires_binding(self, ac mock_db.session.scalar.return_value = None with pytest.raises(ValueError, match="Data source binding not found"): - DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=mock_db.session, + ) def test_update_document_with_dataset_id_website_crawl_updates_segments_and_dispatches_task(self, account_context): dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") @@ -510,7 +551,12 @@ def test_update_document_with_dataset_id_website_crawl_updates_segments_and_disp patch("services.dataset_service.naive_utc_now", return_value="now"), patch("services.dataset_service.document_indexing_update_task") as update_task, ): - result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + result = DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=mock_db.session, + ) assert result is document assert document.data_source_type == "website_crawl" @@ -681,8 +727,14 @@ def test_save_document_with_dataset_id_requires_file_info_for_upload_source(self knowledge_config = _make_upload_knowledge_config(file_ids=None) with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)): + session = MagicMock() with pytest.raises(ValueError, match="File source info is required"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) def test_save_document_with_dataset_id_blocks_batch_upload_for_sandbox_plan(self, account_context): dataset = _make_dataset() @@ -695,8 +747,14 @@ def test_save_document_with_dataset_id_blocks_batch_upload_for_sandbox_plan(self ), patch.object(DocumentService, "check_documents_upload_quota") as check_quota, ): + session = MagicMock() with pytest.raises(ValueError, match="does not support batch upload"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) check_quota.assert_not_called() @@ -709,8 +767,14 @@ def test_save_document_with_dataset_id_enforces_batch_upload_limit(self, account patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 1), patch.object(DocumentService, "check_documents_upload_quota") as check_quota, ): + session = MagicMock() with pytest.raises(ValueError, match="batch upload limit of 1"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) check_quota.assert_not_called() @@ -725,20 +789,32 @@ def test_save_document_with_dataset_id_updates_existing_document_and_data_source DocumentService, "update_document_with_dataset_id", return_value=updated_document ) as update_document, ): - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + session = MagicMock() + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) assert dataset.data_source_type == "upload_file" assert documents == [updated_document] assert batch == "batch-existing" - update_document.assert_called_once_with(dataset, knowledge_config, account_context) + update_document.assert_called_once_with(dataset, knowledge_config, account_context, session=session) def test_save_document_with_dataset_id_requires_data_source_for_new_documents(self, account_context): dataset = _make_dataset() knowledge_config = _make_upload_knowledge_config(data_source=None) with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + session = MagicMock() with pytest.raises(ValueError, match="Data source is required when creating new documents"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) def test_save_document_with_dataset_id_requires_existing_process_rule_for_custom_mode(self, account_context): dataset = _make_dataset(latest_process_rule=None) @@ -748,8 +824,14 @@ def test_save_document_with_dataset_id_requires_existing_process_rule_for_custom ) with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + session = MagicMock() with pytest.raises(ValueError, match="No process rule found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) def test_save_document_with_dataset_id_rejects_invalid_indexing_technique(self, account_context): dataset = _make_dataset(indexing_technique=None) @@ -761,8 +843,14 @@ def test_save_document_with_dataset_id_rejects_invalid_indexing_technique(self, ) with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + session = MagicMock() with pytest.raises(ValueError, match="Indexing technique is invalid"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) def test_save_document_with_dataset_id_returns_empty_for_invalid_process_rule_mode(self, account_context): dataset = _make_dataset() @@ -770,7 +858,12 @@ def test_save_document_with_dataset_id_returns_empty_for_invalid_process_rule_mo knowledge_config.process_rule = SimpleNamespace(mode="unsupported-mode", rules=None) with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=MagicMock(), + ) assert documents == [] assert batch == "" @@ -807,6 +900,7 @@ def test_save_document_with_dataset_id_upload_file_creates_and_reindexes_documen knowledge_config, account_context, dataset_process_rule=dataset_process_rule, + session=mock_db.session, ) assert documents == [duplicate_document, created_document] @@ -887,6 +981,7 @@ def test_save_document_with_dataset_id_notion_import_truncates_names_and_cleans_ knowledge_config, account_context, dataset_process_rule=dataset_process_rule, + session=mock_db.session, ) assert created_document in documents @@ -938,6 +1033,7 @@ def test_save_document_with_dataset_id_website_crawl_truncates_long_urls(self, a knowledge_config, account_context, dataset_process_rule=dataset_process_rule, + session=mock_db.session, ) assert documents == [first_document, second_document] @@ -990,7 +1086,7 @@ def test_batch_update_document_status_rejects_indexing_documents(self): with pytest.raises(DocumentIndexingError, match="Busy document is being indexed"): DocumentService.batch_update_document_status( - dataset, [document.id], "archive", SimpleNamespace(id="user-1") + dataset, [document.id], "archive", SimpleNamespace(id="user-1"), mock_db.session ) mock_db.session.commit.assert_not_called() @@ -1009,7 +1105,7 @@ def test_batch_update_document_status_rolls_back_when_commit_fails(self): with pytest.raises(RuntimeError, match="commit failed"): DocumentService.batch_update_document_status( - dataset, [document.id], "enable", SimpleNamespace(id="user-1") + dataset, [document.id], "enable", SimpleNamespace(id="user-1"), mock_db.session ) mock_db.session.rollback.assert_called_once() @@ -1029,7 +1125,7 @@ def test_batch_update_document_status_raises_async_task_error_after_commit(self) with pytest.raises(RuntimeError, match="task failed"): DocumentService.batch_update_document_status( - dataset, [document.id], "enable", SimpleNamespace(id="user-1") + dataset, [document.id], "enable", SimpleNamespace(id="user-1"), mock_db.session ) mock_db.session.commit.assert_called_once() @@ -1052,7 +1148,7 @@ def test_get_tenant_documents_count_returns_query_count(self, account_context): with patch("services.dataset_service.db") as mock_db: mock_db.session.scalar.return_value = 12 - result = DocumentService.get_tenant_documents_count() + result = DocumentService.get_tenant_documents_count(mock_db.session) assert result == 12 @@ -1091,7 +1187,12 @@ def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(sel process_rule_cls.return_value = created_process_rule mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt") - result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + result = DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=mock_db.session, + ) assert result is document assert document.dataset_process_rule_id == "rule-2" @@ -1117,8 +1218,14 @@ def test_update_document_with_dataset_id_requires_upload_file_info(self, account patch.object(DocumentService, "get_document", return_value=_make_document()), patch.object(DatasetService, "check_dataset_model_setting"), ): + session = MagicMock() with pytest.raises(ValueError, match="No file info list found"): - DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=session, + ) def test_update_document_with_dataset_id_raises_when_upload_file_is_missing(self, account_context): dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") @@ -1141,7 +1248,12 @@ def test_update_document_with_dataset_id_raises_when_upload_file_is_missing(self mock_db.session.scalar.return_value = None with pytest.raises(FileNotExistsError): - DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=mock_db.session, + ) def test_update_document_with_dataset_id_requires_notion_info_list(self, account_context): dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") @@ -1155,8 +1267,14 @@ def test_update_document_with_dataset_id_requires_notion_info_list(self, account patch.object(DocumentService, "get_document", return_value=_make_document()), patch.object(DatasetService, "check_dataset_model_setting"), ): + session = MagicMock() with pytest.raises(ValueError, match="No notion info list found"): - DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=session, + ) def test_update_document_with_dataset_id_notion_import_updates_page_info(self, account_context): dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") @@ -1191,7 +1309,12 @@ def test_update_document_with_dataset_id_notion_import_updates_page_info(self, a ): mock_db.session.scalar.return_value = SimpleNamespace(id="binding-1") - result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + result = DocumentService.update_document_with_dataset_id( + dataset, + document_data, + account_context, + session=mock_db.session, + ) assert result is document assert document.data_source_type == "notion_import" @@ -1260,9 +1383,14 @@ def test_save_document_without_dataset_id_counts_notion_pages_for_quota(self, ac "save_document_with_dataset_id", return_value=([SimpleNamespace(name="Doc")], "batch-1"), ), - patch("services.dataset_service.db"), + patch("services.dataset_service.db") as mock_db, ): - DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + DocumentService.save_document_without_dataset_id( + "tenant-1", + knowledge_config, + account_context, + mock_db.session, + ) check_quota.assert_called_once_with(3, features) @@ -1287,8 +1415,9 @@ def test_save_document_without_dataset_id_enforces_batch_limit_for_website_urls( patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "1"), patch.object(DocumentService, "check_documents_upload_quota") as check_quota, ): + session = MagicMock() with pytest.raises(ValueError, match="batch upload limit of 1"): - DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context, session) check_quota.assert_not_called() @@ -1458,7 +1587,13 @@ def test_save_document_with_dataset_id_initializes_high_quality_dataset_from_def provider="default-provider", ) - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + session = MagicMock() + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=session, + ) assert documents == [updated_document] assert batch == "batch-existing" @@ -1474,7 +1609,7 @@ def test_save_document_with_dataset_id_initializes_high_quality_dataset_from_def "top_k": 4, "score_threshold_enabled": False, } - get_binding.assert_called_once_with("default-provider", "default-embedding") + get_binding.assert_called_once_with("default-provider", "default-embedding", session) def test_save_document_with_dataset_id_uses_explicit_embedding_and_retrieval_model(self, account_context): dataset = _make_dataset(indexing_technique=None) @@ -1503,10 +1638,11 @@ def test_save_document_with_dataset_id_uses_explicit_embedding_and_retrieval_mod ) as get_binding, patch.object(DocumentService, "update_document_with_dataset_id", return_value=_make_document()), ): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + session = MagicMock() + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context, session=session) model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called() - get_binding.assert_called_once_with("explicit-provider", "explicit-model") + get_binding.assert_called_once_with("explicit-provider", "explicit-model", session) assert dataset.embedding_model == "explicit-model" assert dataset.embedding_model_provider == "explicit-provider" assert dataset.retrieval_model == knowledge_config.retrieval_model.model_dump() @@ -1541,7 +1677,12 @@ def test_save_document_with_dataset_id_creates_custom_process_rule_for_new_uploa process_rule_cls.return_value = created_process_rule mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []] - documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=mock_db.session, + ) assert documents == [created_document] assert batch == "20260101010101100023" @@ -1581,7 +1722,12 @@ def test_save_document_with_dataset_id_creates_automatic_process_rule_for_new_up process_rule_cls.return_value = created_process_rule mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []] - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=mock_db.session, + ) assert process_rule_cls.call_args.kwargs == { "dataset_id": "dataset-1", @@ -1615,7 +1761,12 @@ def test_save_document_with_dataset_id_creates_fallback_automatic_process_rule_w process_rule_cls.return_value = created_process_rule mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []] - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=mock_db.session, + ) assert process_rule_cls.call_args.kwargs == { "dataset_id": "dataset-1", @@ -1640,7 +1791,12 @@ def test_save_document_with_dataset_id_raises_when_upload_file_lookup_is_incompl mock_db.session.scalars.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] with pytest.raises(FileNotExistsError, match="One or more files not found"): - DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + session=mock_db.session, + ) def test_save_document_with_dataset_id_requires_notion_info_list_for_notion_import(self, account_context): dataset = _make_dataset() @@ -1663,6 +1819,7 @@ def test_save_document_with_dataset_id_requires_notion_info_list_for_notion_impo knowledge_config, account_context, dataset_process_rule=SimpleNamespace(id="rule-1"), + session=MagicMock(), ) def test_save_document_with_dataset_id_requires_website_info_list_for_website_crawl(self, account_context): @@ -1686,4 +1843,5 @@ def test_save_document_with_dataset_id_requires_website_info_list_for_website_cr knowledge_config, account_context, dataset_process_rule=SimpleNamespace(id="rule-1"), + session=MagicMock(), ) diff --git a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py index 352a765de28825..5e5d406edb8201 100644 --- a/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py +++ b/api/tests/unit_tests/services/test_dataset_service_lock_not_owned.py @@ -94,7 +94,9 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( # Avoid touching real doc_form logic monkeypatch.setattr("services.dataset_service.DatasetService.check_doc_form", lambda *a, **k: None) # Avoid real DB interactions - monkeypatch.setattr("services.dataset_service.db", Mock()) + db_mock = Mock() + db_mock.session = Mock() + monkeypatch.setattr("services.dataset_service.db", db_mock) # Act: this would hit the redis lock, whose __enter__ raises LockNotOwnedError. # Our implementation should catch it and still return (documents, batch). @@ -102,6 +104,7 @@ def test_save_document_with_dataset_id_ignores_lock_not_owned( dataset=dataset, knowledge_config=knowledge_config, account=account, + session=db_mock.session, ) # Assert @@ -148,7 +151,7 @@ def test_add_segment_ignores_lock_not_owned( monkeypatch.setattr("services.dataset_service.VectorService", Mock()) # Act - result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + result = SegmentService.create_segment(args=args, document=document, dataset=dataset, session=db_mock.session) # Assert # Under LockNotOwnedError except, add_segment should swallow the error and return None. diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py index 1f8586e32f37e0..a625f17ef37b92 100644 --- a/api/tests/unit_tests/services/test_dataset_service_segment.py +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -51,7 +51,13 @@ def test_create_child_chunk_assigns_next_position_and_commits(self, account_cont mock_redis.lock.return_value = _make_lock_context() mock_db.session.scalar.return_value = 2 - child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset) + child_chunk = SegmentService.create_child_chunk( + "child content", + segment, + document, + dataset, + mock_db.session, + ) assert isinstance(child_chunk, ChildChunk) assert child_chunk.position == 3 @@ -79,7 +85,7 @@ def test_create_child_chunk_rolls_back_and_raises_on_vector_failure(self, accoun vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed") with pytest.raises(ChildChunkIndexingError, match="vector failed"): - SegmentService.create_child_chunk("child content", segment, document, dataset) + SegmentService.create_child_chunk("child content", segment, document, dataset, mock_db.session) mock_db.session.rollback.assert_called_once() mock_db.session.commit.assert_not_called() @@ -127,6 +133,7 @@ def test_update_child_chunks_updates_deletes_and_creates_records(self, account_c segment, document, dataset, + mock_db.session, ) assert [chunk.position for chunk in result] == [1, 3] @@ -164,6 +171,7 @@ def test_update_child_chunks_rolls_back_on_vector_failure(self, account_context) segment, document, dataset, + mock_db.session, ) mock_db.session.rollback.assert_called_once() @@ -179,7 +187,7 @@ def test_update_child_chunk_updates_vector_and_commits(self, account_context): patch("services.dataset_service.VectorService") as vector_service, ): result = SegmentService.update_child_chunk( - "new content", child_chunk, _make_segment(), _make_document(), dataset + "new content", child_chunk, _make_segment(), _make_document(), dataset, mock_db.session ) assert result is child_chunk @@ -202,7 +210,7 @@ def test_delete_child_chunk_raises_delete_index_error_on_vector_failure(self): vector_service.delete_child_chunk_vector.side_effect = RuntimeError("delete failed") with pytest.raises(ChildChunkDeleteIndexError, match="delete failed"): - SegmentService.delete_child_chunk(child_chunk, dataset) + SegmentService.delete_child_chunk(child_chunk, dataset, mock_db.session) mock_db.session.delete.assert_called_once_with(child_chunk) mock_db.session.rollback.assert_called_once() @@ -247,13 +255,13 @@ def test_get_child_chunk_by_id_returns_only_child_chunk_instances(self): with patch("services.dataset_service.db") as mock_db: mock_db.session.scalar.return_value = child_chunk - result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1", mock_db.session) assert result is child_chunk with patch("services.dataset_service.db") as mock_db: mock_db.session.scalar.return_value = SimpleNamespace() - result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1", mock_db.session) assert result is None @@ -294,13 +302,13 @@ def test_get_segment_by_id_returns_only_document_segment_instances(self): segment.id = "segment-1" with patch("services.dataset_service.db") as mock_db: mock_db.session.scalar.return_value = segment - result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + result = SegmentService.get_segment_by_id("segment-1", "tenant-1", mock_db.session) assert result is segment with patch("services.dataset_service.db") as mock_db: mock_db.session.scalar.return_value = SimpleNamespace() - result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + result = SegmentService.get_segment_by_id("segment-1", "tenant-1", mock_db.session) assert result is None @@ -323,6 +331,7 @@ def test_get_segments_by_document_and_dataset_returns_scalars_result(self): result = SegmentService.get_segments_by_document_and_dataset( document_id="doc-1", dataset_id="dataset-1", + session=mock_db.session, status="completed", enabled=True, ) @@ -409,7 +418,12 @@ def add_side_effect(obj): mock_db.session.add.side_effect = add_side_effect vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") - result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + result = SegmentService.create_segment( + args=args, + document=document, + dataset=dataset, + session=mock_db.session, + ) created_segment = vector_service.create_segments_vector.call_args.args[1][0] attachment_bindings = [ @@ -459,7 +473,7 @@ def test_multi_create_segment_high_quality_marks_segments_error_when_vector_crea mock_db.session.scalar.return_value = 1 vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") - result = SegmentService.multi_create_segment(segments, document, dataset) + result = SegmentService.multi_create_segment(segments, document, dataset, mock_db.session) assert result assert len(result) == 2 @@ -488,7 +502,7 @@ def test_update_segment_disables_enabled_segment_and_dispatches_index_cleanup(se ): mock_redis.get.return_value = None - result = SegmentService.update_segment(args, segment, document, dataset) + result = SegmentService.update_segment(args, segment, document, dataset, mock_db.session) assert result is segment assert segment.enabled is False @@ -508,7 +522,9 @@ def test_update_segment_rejects_updates_for_disabled_segment(self, account_conte mock_redis.get.return_value = None with pytest.raises(ValueError, match="Can't update disabled segment"): - SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + SegmentService.update_segment( + SegmentUpdateArgs(content="new content"), segment, document, dataset, MagicMock() + ) def test_update_segment_rejects_when_indexing_cache_exists(self, account_context): segment = _make_segment(enabled=True) @@ -519,7 +535,9 @@ def test_update_segment_rejects_when_indexing_cache_exists(self, account_context mock_redis.get.return_value = "1" with pytest.raises(ValueError, match="Segment is indexing"): - SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + SegmentService.update_segment( + SegmentUpdateArgs(content="new content"), segment, document, dataset, MagicMock() + ) def test_update_segment_updates_keywords_for_same_content_segment(self, account_context): segment = _make_segment(content="same content", keywords=["old"]) @@ -536,7 +554,7 @@ def test_update_segment_updates_keywords_for_same_content_segment(self, account_ mock_redis.get.return_value = None mock_db.session.get.return_value = refreshed_segment - result = SegmentService.update_segment(args, segment, document, dataset) + result = SegmentService.update_segment(args, segment, document, dataset, mock_db.session) assert result is refreshed_segment assert segment.keywords == ["new"] @@ -575,7 +593,7 @@ def test_update_segment_regenerates_child_chunks_and_updates_manual_summary(self # scalar call: existing_summary mock_db.session.scalar.return_value = existing_summary - result = SegmentService.update_segment(args, segment, document, dataset) + result = SegmentService.update_segment(args, segment, document, dataset, mock_db.session) assert result is refreshed_segment vector_service.generate_child_chunks.assert_called_once_with( @@ -617,7 +635,7 @@ def test_update_segment_auto_regenerates_summary_after_content_change(self, acco mock_db.session.scalar.return_value = existing_summary mock_db.session.get.return_value = refreshed_segment - result = SegmentService.update_segment(args, segment, document, dataset) + result = SegmentService.update_segment(args, segment, document, dataset, mock_db.session) assert result is refreshed_segment assert segment.content == "new content" @@ -657,7 +675,7 @@ def test_update_segment_regenerates_summary_when_manual_summary_is_unchanged(sel mock_db.session.scalar.return_value = existing_summary mock_db.session.get.return_value = refreshed_segment - result = SegmentService.update_segment(args, segment, document, dataset) + result = SegmentService.update_segment(args, segment, document, dataset, mock_db.session) assert result is refreshed_segment generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) @@ -677,7 +695,7 @@ def test_delete_segment_removes_index_and_updates_document_word_count(self): mock_redis.get.return_value = None mock_db.session.scalars.return_value.all.return_value = ["child-1", "child-2"] - SegmentService.delete_segment(segment, document, dataset) + SegmentService.delete_segment(segment, document, dataset, mock_db.session) assert document.word_count == 6 mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_delete_indexing", 600, 1) @@ -701,7 +719,7 @@ def test_delete_segment_rejects_when_delete_is_already_in_progress(self): mock_redis.get.return_value = "1" with pytest.raises(ValueError, match="Segment is deleting"): - SegmentService.delete_segment(segment, document, dataset) + SegmentService.delete_segment(segment, document, dataset, MagicMock()) def test_delete_segments_removes_records_and_clamps_document_word_count(self): dataset = _make_dataset() @@ -723,7 +741,7 @@ def test_delete_segments_removes_records_and_clamps_document_word_count(self): # scalars() for child_node_ids mock_db.session.scalars.return_value.all.return_value = ["child-1"] - SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset) + SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset, mock_db.session) assert document.word_count == 0 mock_db.session.add.assert_called_once_with(document) @@ -753,7 +771,9 @@ def test_update_segments_status_enables_only_segments_without_indexing_cache(sel mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] mock_redis.get.side_effect = [None, "1"] - SegmentService.update_segments_status(["segment-a", "segment-b"], "enable", dataset, document) + SegmentService.update_segments_status( + ["segment-a", "segment-b"], "enable", dataset, document, mock_db.session + ) assert segment_a.enabled is True assert segment_a.disabled_at is None @@ -780,7 +800,9 @@ def test_update_segments_status_disables_only_segments_without_indexing_cache(se mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] mock_redis.get.side_effect = [None, "1"] - SegmentService.update_segments_status(["segment-a", "segment-b"], "disable", dataset, document) + SegmentService.update_segments_status( + ["segment-a", "segment-b"], "disable", dataset, document, mock_db.session + ) assert segment_a.enabled is False assert segment_a.disabled_at == "now" @@ -808,7 +830,7 @@ def test_update_child_chunk_rolls_back_on_vector_failure(self): with pytest.raises(ChildChunkIndexingError, match="vector failed"): SegmentService.update_child_chunk( - "new content", child_chunk, SimpleNamespace(), SimpleNamespace(), dataset + "new content", child_chunk, SimpleNamespace(), SimpleNamespace(), dataset, mock_db.session ) mock_db.session.rollback.assert_called_once() @@ -822,7 +844,7 @@ def test_delete_child_chunk_commits_after_successful_vector_delete(self): patch("services.dataset_service.db") as mock_db, patch("services.dataset_service.VectorService") as vector_service, ): - SegmentService.delete_child_chunk(child_chunk, dataset) + SegmentService.delete_child_chunk(child_chunk, dataset, mock_db.session) mock_db.session.delete.assert_called_once_with(child_chunk) vector_service.delete_child_chunk_vector.assert_called_once_with(child_chunk, dataset) @@ -860,6 +882,7 @@ def test_update_segment_same_content_updates_answer_and_document_word_count_for_ segment, document, dataset, + mock_db.session, ) assert result is refreshed_segment @@ -895,6 +918,7 @@ def test_update_segment_content_change_uses_answer_when_counting_tokens_for_qa_s segment, document, dataset, + mock_db.session, ) assert result is refreshed_segment @@ -943,6 +967,7 @@ def test_update_segment_content_change_parent_child_uses_default_embedding_and_i segment, document, dataset, + mock_db.session, ) assert result is refreshed_segment @@ -986,6 +1011,7 @@ def test_update_segment_same_content_parent_child_marks_segment_error_for_non_hi segment, document, dataset, + mock_db.session, ) assert result is refreshed_segment diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index cef11c0038d7e2..19418c43926cd7 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -1169,7 +1169,7 @@ def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pyt monkeypatch.setattr(SummaryIndexService, "get_document_summaries", MagicMock(return_value=[summary1])) - detail = SummaryIndexService.get_document_summary_status_detail("doc-1", "dataset-1") + detail = SummaryIndexService.get_document_summary_status_detail("doc-1", "dataset-1", MagicMock()) assert detail["total_segments"] == 2 assert detail["summary_status"]["completed"] == 1 assert detail["summary_status"]["not_started"] == 1