From af0574bb942c93f0a6c479d55a738302c9207375 Mon Sep 17 00:00:00 2001
From: Shiwani Mishra
Date: Tue, 31 Mar 2026 15:24:05 +0530
Subject: [PATCH] fix: add pagination to /rest/v1/root_cres and cap per_page
 on list endpoints

/rest/v1/root_cres had no pagination, returning all root CREs in a
single unbounded query. Add get_root_cres_with_pagination() in db.py and
update the endpoint to accept page/per_page params, returning page and
total_pages metadata alongside data.

/rest/v1/all_cres already supported pagination but accepted any positive
integer for per_page, allowing a single request to fetch the entire
dataset. Introduce MAX_PER_PAGE = 100 and cap per_page on both list
endpoints.

Format-based responses (Markdown, CSV, OSCAL) on root_cres are
intentional full-export flows and are not paginated.
---
 application/database/db.py         | 26 +++++++++++++++++++++
 application/tests/db_test.py       | 37 ++++++++++++++++++++++++++++++
 application/tests/web_main_test.py | 36 ++++++++++++++++++++++++++++-
 application/web/web_main.py        | 31 ++++++++++++++++++++++++-------
 4 files changed, 121 insertions(+), 9 deletions(-)

diff --git a/application/database/db.py b/application/database/db.py
index d4ac9b7e8..8c692a6ff 100644
--- a/application/database/db.py
+++ b/application/database/db.py
@@ -1915,6 +1915,32 @@ def get_root_cres(self):
             result.extend(self.get_CREs(external_id=c.external_id))
         return result
 
+    def get_root_cres_with_pagination(self, page: int = 1, per_page: int = 20) -> tuple:
+        """Returns paginated root CREs (those that only have "Contains" links)"""
+        cres_page = (
+            self.session.query(CRE)
+            .filter(
+                ~CRE.id.in_(
+                    self.session.query(InternalLinks.cre).filter(
+                        InternalLinks.type == cre_defs.LinkTypes.Contains,
+                    )
+                )
+            )
+            .filter(
+                ~CRE.id.in_(
+                    self.session.query(InternalLinks.group).filter(
+                        InternalLinks.type == cre_defs.LinkTypes.PartOf,
+                    )
+                )
+            )
+            .paginate(page=page, per_page=per_page, error_out=False)
+        )
+        total_pages = cres_page.pages
+        result = []
+        for c in cres_page.items:
+            result.extend(self.get_CREs(external_id=c.external_id))
+        return result, page, total_pages
+
     def get_embeddings_by_doc_type(self, doc_type: str) -> Dict[str, List[float]]:
         res = {}
         embeddings = (
diff --git a/application/tests/db_test.py b/application/tests/db_test.py
index 1d13bd0be..4eb219d16 100644
--- a/application/tests/db_test.py
+++ b/application/tests/db_test.py
@@ -1252,6 +1252,43 @@ def test_get_root_cres(self):
         self.maxDiff = None
         self.assertCountEqual(root_cres, [cres[0], cres[1], cres[7]])
 
+    def test_get_root_cres_with_pagination(self):
+        """get_root_cres_with_pagination should return paginated root CREs
+        with correct page and total_pages metadata"""
+        sqla.session.remove()
+        sqla.drop_all()
+        sqla.create_all()
+
+        collection = db.Node_collection()
+        cres = []
+        dbcres = []
+
+        # Create 4 root CREs (no internal links between them)
+        for i in range(0, 4):
+            cres.append(defs.CRE(name=f"Root C{i}", id=f"{i}{i}{i}-{i}{i}{i}"))
+            dbcres.append(collection.add_cre(cres[i]))
+
+        collection.session.commit()
+
+        # Page 1 with per_page=2 should return first 2 root CREs, total_pages=2
+        result, page, total_pages = collection.get_root_cres_with_pagination(
+            page=1, per_page=2
+        )
+        self.maxDiff = None
+        self.assertEqual(page, 1)
+        self.assertEqual(total_pages, 2)
+        self.assertEqual(len(result), 2)
+        self.assertCountEqual(result, [cres[0], cres[1]])
+
+        # Page 2 with per_page=2 should return remaining 2 root CREs
+        result, page, total_pages = collection.get_root_cres_with_pagination(
+            page=2, per_page=2
+        )
+        self.assertEqual(page, 2)
+        self.assertEqual(total_pages, 2)
+        self.assertEqual(len(result), 2)
+        self.assertCountEqual(result, [cres[2], cres[3]])
+
     @patch.object(db.NEO_DB, "gap_analysis")
     def test_gap_analysis_disconnected(self, gap_mock):
         collection = db.Node_collection()
diff --git a/application/tests/web_main_test.py b/application/tests/web_main_test.py
index 9e219b4ce..9f16fe944 100644
--- a/application/tests/web_main_test.py
+++ b/application/tests/web_main_test.py
@@ -510,7 +510,11 @@ def test_find_root_cres(self) -> None:
                 higher=dcb, lower=dcd, ltype=defs.LinkTypes.Contains
             )
 
-            expected = {"data": [cres["ca"].todict(), cres["cb"].todict()]}
+            expected = {
+                "data": [cres["ca"].todict(), cres["cb"].todict()],
+                "page": 1,
+                "total_pages": 1,
+            }
             response = client.get(
                 "/rest/v1/root_cres",
                 headers={"Content-Type": "application/json"},
@@ -518,6 +522,36 @@ def test_find_root_cres(self) -> None:
             self.assertEqual(json.loads(response.data.decode()), expected)
             self.assertEqual(200, response.status_code)
 
+    @patch.object(db, "Node_collection")
+    def test_root_cres_per_page_cap(self, db_mock) -> None:
+        """per_page above MAX_PER_PAGE should be silently capped"""
+        cres = [defs.CRE(name=f"cre{i}", id=f"{i}{i}{i}-{i}{i}{i}") for i in range(3)]
+        db_mock.return_value.get_root_cres_with_pagination.return_value = (cres, 1, 1)
+
+        with self.app.test_client() as client:
+            client.get(
+                "/rest/v1/root_cres?per_page=99999",
+                headers={"Content-Type": "application/json"},
+            )
+        call_args = db_mock.return_value.get_root_cres_with_pagination.call_args
+        _, called_per_page = call_args[0]
+        self.assertLessEqual(called_per_page, web_main.MAX_PER_PAGE)
+
+    @patch.object(db, "Node_collection")
+    def test_all_cres_per_page_cap(self, db_mock) -> None:
+        """per_page above MAX_PER_PAGE should be silently capped"""
+        cres = [defs.CRE(name=f"cre{i}", id=f"{i}{i}{i}-{i}{i}{i}") for i in range(3)]
+        db_mock.return_value.all_cres_with_pagination.return_value = (cres, 1, 1)
+
+        with self.app.test_client() as client:
+            client.get(
+                "/rest/v1/all_cres?per_page=99999",
+                headers={"Content-Type": "application/json"},
+            )
+        call_args = db_mock.return_value.all_cres_with_pagination.call_args
+        _, called_per_page = call_args[0]
+        self.assertLessEqual(called_per_page, web_main.MAX_PER_PAGE)
+
     def test_smartlink(self) -> None:
         self.maxDiff = None
         collection = db.Node_collection().with_graph()
diff --git a/application/web/web_main.py b/application/web/web_main.py
index 29567470a..94d4ea722 100644
--- a/application/web/web_main.py
+++ b/application/web/web_main.py
@@ -48,6 +48,7 @@
 
 ITEMS_PER_PAGE = 20
+MAX_PER_PAGE = 100
 
 app = Blueprint(
     "web",
     __name__,
@@ -506,12 +507,11 @@ def find_root_cres() -> Any:
     database = db.Node_collection()
     # opt_osib = request.args.get("osib")
     opt_format = request.args.get("format")
-    documents = database.get_root_cres()
-    if documents:
-        res = [doc.todict() for doc in documents]
-        result = {"data": res}
-        # if opt_osib:
-        #     result["osib"] = odefs.cre2osib(documents).todict()
+
+    if opt_format:
+        documents = database.get_root_cres()
+        if not documents:
+            abort(404, "No root CREs")
     if opt_format == SupportedFormats.Markdown.value:
         return f"{mdutils.cre_to_md(documents)}"
     elif opt_format == SupportedFormats.CSV.value:
@@ -522,7 +522,22 @@ def find_root_cres() -> Any:
     elif opt_format == SupportedFormats.OSCAL.value:
         return jsonify(json.loads(oscal_utils.list_to_oscal(documents)))
 
-    return jsonify(result)
+    page = 1
+    per_page = ITEMS_PER_PAGE
+    if request.args.get("page") is not None and int(request.args.get("page")) > 0:
+        page = int(request.args.get("page"))
+    if (
+        request.args.get("per_page") is not None
+        and int(request.args.get("per_page")) > 0
+    ):
+        per_page = min(int(request.args.get("per_page")), MAX_PER_PAGE)
+
+    documents, page, total_pages = database.get_root_cres_with_pagination(
+        page, per_page
+    )
+    if documents:
+        res = [doc.todict() for doc in documents]
+        return jsonify({"data": res, "page": page, "total_pages": total_pages})
 
     abort(404, "No root CREs")
 
@@ -814,7 +829,7 @@ def all_cres() -> Any:
         request.args.get("per_page") is not None
         and int(request.args.get("per_page")) > 0
     ):
-        per_page = int(request.args.get("per_page"))
+        per_page = min(int(request.args.get("per_page")), MAX_PER_PAGE)
 
     documents, page, total_pages = database.all_cres_with_pagination(page, per_page)
     if documents: