Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 66 additions & 67 deletions api/controllers/console/datasets/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,72 +243,71 @@ def get(self, current_tenant_id: str, current_user: Account) -> tuple[dict[str,
if not credential:
raise NotFound("Credential not found.")
exist_page_ids = []
with sessionmaker(db.engine).begin() as session:
# import notion in the exist dataset
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")

documents = session.scalars(
select(Document).where(
Document.dataset_id == query.dataset_id,
Document.tenant_id == current_tenant_id,
Document.data_source_type == "notion_import",
Document.enabled.is_(True),
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
from core.datasource.datasource_manager import DatasourceManager

datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
if credential:
datasource_runtime.runtime.credentials = credential
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
# import notion in the exist dataset
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id, db.session)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
raise ValueError("Dataset is not notion type.")

documents = db.session.scalars(
select(Document).where(
Document.dataset_id == query.dataset_id,
Document.tenant_id == current_tenant_id,
Document.data_source_type == "notion_import",
Document.enabled.is_(True),
)
).all()
if documents:
for document in documents:
data_source_info = json.loads(document.data_source_info)
exist_page_ids.append(data_source_info["notion_page_id"])
# get all authorized pages
from core.datasource.datasource_manager import DatasourceManager

datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id="langgenius/notion_datasource/notion_datasource",
datasource_name="notion_datasource",
tenant_id=current_tenant_id,
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
)
datasource_provider_service = DatasourceProviderService()
if credential:
datasource_runtime.runtime.credentials = credential
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
datasource_runtime.get_online_document_pages(
user_id=current_user.id,
datasource_parameters={},
provider_type=datasource_runtime.datasource_provider_type(),
)
try:
pages = []
workspace_info = {}
for message in online_document_result:
result = message.result
for info in result:
workspace_info = {
"workspace_id": info.workspace_id,
"workspace_name": info.workspace_name,
"workspace_icon": info.workspace_icon,
)
try:
pages = []
workspace_info = {}
for message in online_document_result:
result = message.result
for info in result:
workspace_info = {
"workspace_id": info.workspace_id,
"workspace_name": info.workspace_name,
"workspace_icon": info.workspace_icon,
}
for page in info.pages:
page_info = {
"page_id": page.page_id,
"page_name": page.page_name,
"type": page.type,
"parent_id": page.parent_id,
"is_bound": page.page_id in exist_page_ids,
"page_icon": page.page_icon,
}
for page in info.pages:
page_info = {
"page_id": page.page_id,
"page_name": page.page_name,
"type": page.type,
"parent_id": page.parent_id,
"is_bound": page.page_id in exist_page_ids,
"page_icon": page.page_icon,
}
pages.append(page_info)
except Exception as e:
raise e
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200
pages.append(page_info)
except Exception as e:
raise e
notion_info = [{**workspace_info, "pages": pages}] if workspace_info else []
return dump_response(NotionIntegrateInfoListResponse, {"notion_info": notion_info}), 200


@console_ns.route("/notion/pages/<uuid:page_id>/<string:page_type>/preview")
Expand Down Expand Up @@ -401,11 +400,11 @@ class DataSourceNotionDatasetSyncApi(Resource):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_CREATE_AND_MANAGEMENT)
def get(self, dataset_id: UUID) -> tuple[dict[str, str], int]:
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")

documents = DocumentService.get_document_by_dataset_id(dataset_id_str)
documents = DocumentService.get_document_by_dataset_id(dataset_id_str, db.session)
for document in documents:
document_indexing_sync_task.delay(dataset_id_str, document.id)
return {"result": "success"}, 200
Expand All @@ -421,11 +420,11 @@ class DataSourceNotionDocumentSyncApi(Resource):
def get(self, dataset_id: UUID, document_id: UUID) -> tuple[dict[str, str], int]:
dataset_id_str = str(dataset_id)
document_id_str = str(document_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")

document = DocumentService.get_document(dataset_id_str, document_id_str)
document = DocumentService.get_document(dataset_id_str, document_id_str, session=db.session)
if document is None:
raise NotFound("Document not found.")
document_indexing_sync_task.delay(dataset_id_str, document_id_str)
Expand Down
46 changes: 25 additions & 21 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ def post(self, current_tenant_id: str, current_user: Account):
provider=payload.provider,
external_knowledge_api_id=payload.external_knowledge_api_id,
external_knowledge_id=payload.external_knowledge_id,
session=db.session,
)
except services.errors.dataset.DatasetNameDuplicateError:
raise DatasetNameDuplicateError()
Expand Down Expand Up @@ -598,7 +599,7 @@ class DatasetApi(Resource):
@with_current_tenant_id
def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")
try:
Expand All @@ -618,7 +619,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
provider_id = ModelProviderID(dataset.embedding_model_provider)
data["embedding_model_provider"] = str(provider_id)
if data.get("permission") == "partial_members":
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session)
data.update({"partial_member_list": part_users_list})

# check embedding setting
Expand Down Expand Up @@ -661,7 +662,7 @@ def get(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_EDIT)
def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")

Expand All @@ -680,10 +681,10 @@ def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID)
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
if not dify_config.RBAC_ENABLED:
DatasetPermissionService.check_permission(
current_user, dataset, payload.permission, payload.partial_member_list
current_user, dataset, payload.permission, payload.partial_member_list, db.session
)

dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user)
dataset = DatasetService.update_dataset(dataset_id_str, payload_data, current_user, db.session)

if dataset is None:
raise NotFound("Dataset not found.")
Expand All @@ -698,12 +699,14 @@ def patch(self, current_tenant_id: str, current_user: Account, dataset_id: UUID)
tenant_id = current_tenant_id

if payload.partial_member_list is not None and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
DatasetPermissionService.update_partial_member_list(
tenant_id, dataset_id_str, payload.partial_member_list, db.session
)
# clear partial member list when permission is only_me or all_team_members
elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
DatasetPermissionService.clear_partial_member_list(dataset_id_str, db.session)

partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session)
result_data.update({"partial_member_list": partial_member_list})

return result_data, 200
Expand All @@ -722,8 +725,8 @@ def delete(self, current_user: Account, dataset_id: UUID):
raise Forbidden()

try:
if DatasetService.delete_dataset(dataset_id_str, current_user):
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
if DatasetService.delete_dataset(dataset_id_str, current_user, db.session):
DatasetPermissionService.clear_partial_member_list(dataset_id_str, db.session)
return "", 204
else:
raise NotFound("Dataset not found.")
Expand All @@ -748,7 +751,7 @@ class DatasetUseCheckApi(Resource):
def get(self, dataset_id: UUID):
dataset_id_str = str(dataset_id)

dataset_is_using = DatasetService.dataset_use_check(dataset_id_str)
dataset_is_using = DatasetService.dataset_use_check(dataset_id_str, db.session)
return {"is_using": dataset_is_using}, 200


Expand All @@ -769,7 +772,7 @@ class DatasetQueryApi(Resource):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY)
def get(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")

Expand Down Expand Up @@ -910,7 +913,7 @@ class DatasetRelatedAppListApi(Resource):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY)
def get(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")

Expand All @@ -919,7 +922,7 @@ def get(self, current_user: Account, dataset_id: UUID):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

app_dataset_joins = DatasetService.get_related_apps(dataset.id)
app_dataset_joins = DatasetService.get_related_apps(dataset.id, db.session)

related_apps = []
for app_dataset_join in app_dataset_joins:
Expand Down Expand Up @@ -1094,7 +1097,7 @@ class DatasetEnableApiApi(Resource):
def post(self, dataset_id: UUID, status: str):
dataset_id_str = str(dataset_id)

DatasetService.update_dataset_api_status(dataset_id_str, status == "enable")
DatasetService.update_dataset_api_status(dataset_id_str, status == "enable", db.session)

return {"result": "success"}, 200

Expand Down Expand Up @@ -1163,10 +1166,10 @@ class DatasetErrorDocs(Resource):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY)
def get(self, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str, db.session)

return dump_response(ErrorDocsResponse, {"data": results, "total": len(results)}), 200

Expand All @@ -1190,15 +1193,15 @@ class DatasetPermissionUserListApi(Resource):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY)
def get(self, current_user: Account, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user, db.session)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))

partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str, db.session)

return dump_response(PartialMemberListResponse, {"data": partial_members_list}), 200

Expand All @@ -1220,7 +1223,8 @@ class DatasetAutoDisableLogApi(Resource):
@rbac_permission_required(RBACResourceScope.DATASET, RBACPermission.DATASET_READONLY)
def get(self, dataset_id: UUID):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
dataset = DatasetService.get_dataset(dataset_id_str, db.session)
if dataset is None:
raise NotFound("Dataset not found.")
return dump_response(AutoDisableLogsResponse, DatasetService.get_dataset_auto_disable_logs(dataset_id_str)), 200
auto_disable_logs = DatasetService.get_dataset_auto_disable_logs(dataset_id_str, db.session)
return dump_response(AutoDisableLogsResponse, auto_disable_logs), 200
Loading
Loading