From 501945682ef8759e2f112c4ccfc73d431a53daec Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 9 Apr 2026 11:33:51 -0400 Subject: [PATCH 01/60] Move pyro_tools --- examples/aristoff_bangerth.py | 2 +- examples/gaussian_mvp.py | 2 +- src/nak_torch/tools/__init__.py | 6 ++++++ {examples => src/nak_torch/tools}/pyro_tools.py | 0 4 files changed, 8 insertions(+), 2 deletions(-) rename {examples => src/nak_torch/tools}/pyro_tools.py (100%) diff --git a/examples/aristoff_bangerth.py b/examples/aristoff_bangerth.py index ef8f8ab..cc22201 100644 --- a/examples/aristoff_bangerth.py +++ b/examples/aristoff_bangerth.py @@ -9,7 +9,7 @@ from nak_torch.tools.kernel import sqexp_kernel_matrix from tqdm import tqdm import pandas as pd -import pyro_tools +from nak_torch.tools import pyro_tools from pyro.infer import mcmc if torch.cuda.is_available(): diff --git a/examples/gaussian_mvp.py b/examples/gaussian_mvp.py index 79df5de..921fc6c 100644 --- a/examples/gaussian_mvp.py +++ b/examples/gaussian_mvp.py @@ -16,7 +16,7 @@ from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre from pyro.infer import mcmc -import pyro_tools +from nak_torch.tools import pyro_tools if torch.cuda.is_available(): torch.set_default_device("cuda") diff --git a/src/nak_torch/tools/__init__.py b/src/nak_torch/tools/__init__.py index e414608..e5cdabd 100644 --- a/src/nak_torch/tools/__init__.py +++ b/src/nak_torch/tools/__init__.py @@ -3,6 +3,8 @@ # 05/12/2025 +import importlib.util + from . import kernel, types, quadrature, adaptive_step from .average import recursive_weighted_average_alpha_v from .torchify import differentiable_density_factory @@ -17,3 +19,7 @@ "quadrature", "adaptive_step", ] +if importlib.util.find_spec("pyro") is not None: + from . import pyro_tools # noqa: F401 + + __all__.append("pyro_tools") diff --git a/examples/pyro_tools.py b/src/nak_torch/tools/pyro_tools.py similarity index 100% rename from examples/pyro_tools.py rename to src/nak_torch/tools/pyro_tools.py From 4cf254a99afbebd73c1e79d210ef2c71ac2b4a0d Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 9 Apr 2026 11:34:01 -0400 Subject: [PATCH 02/60] add pyro as an optional dependency --- pyproject.toml | 5 +++++ uv.lock | 29 ++++++++++++++++++----------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a767969..15917b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,11 @@ dependencies = [ [project.scripts] nak-torch = "nak_torch:main" +[project.optional-dependencies] +pyro = [ + "pyro-ppl>=1.9.1", +] + [build-system] requires = ["uv_build>=0.8.3,<0.9.0"] build-backend = "uv_build" diff --git a/uv.lock b/uv.lock index 10e770e..4048224 100644 --- a/uv.lock +++ b/uv.lock @@ -516,11 +516,16 @@ dependencies = [ { name = "tqdm" }, ] +[package.optional-dependencies] +pyro = [ + { name = "pyro-ppl" }, +] + [package.dev-dependencies] dev = [ + { name = "matplotlib" }, { name = "pandas" }, { name = "pyro-ppl" }, - { name = "matplotlib" }, { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, @@ -530,15 +535,17 @@ dev = [ requires-dist = [ { name = "jaxtyping", specifier = ">=0.3.5" }, { name = "numpy", specifier = ">=2.4.1" }, + { name = "pyro-ppl", marker = "extra == 'pyro'", specifier = ">=1.9.1" }, { name = "torch", specifier = ">=2.10" }, { name = "tqdm", specifier = ">=4.67.1" }, ] +provides-extras = ["pyro"] [package.metadata.requires-dev] dev = [ + { name = "matplotlib", specifier = ">=3.10.8" }, { name = "pandas", specifier = ">=3.0.2" }, { name = "pyro-ppl", specifier = ">=1.9.1" }, - { name = "matplotlib", specifier = ">=3.10.8" }, { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-cov", specifier = ">=7.1.0" }, { name = "ruff", specifier = ">=0.15.8" }, @@ -905,6 +912,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pyparsing" +version = "3.3.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/91/9c6ee907786a473bf81c5f53cf703ba0957b23ab84c264080fb5a450416f/pyparsing-3.3.2.tar.gz", hash = "sha256:c777f4d763f140633dcb6d8a3eda953bf7a214dc4eff598413c070bcdc117cbc", size = 6851574, upload-time = "2026-01-21T03:57:59.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, +] + [[package]] name = "pyro-api" version = "0.1.2" @@ -930,15 +946,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/37/def183a2a2c8619d92649d62fe0622c4c6c62f60e4151e8fbaa409e7d5ab/pyro_ppl-1.9.1-py3-none-any.whl", hash = "sha256:91fb2c8740d9d3bd548180ac5ecfa04552ed8c471a1ab66870180663b8f09852", size = 755956, upload-time = "2024-06-02T00:37:37.486Z" }, ] -[[package]] -name = "pyparsing" -version = "3.3.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f3/91/9c6ee907786a473bf81c5f53cf703ba0957b23ab84c264080fb5a450416f/pyparsing-3.3.2.tar.gz", hash = "sha256:c777f4d763f140633dcb6d8a3eda953bf7a214dc4eff598413c070bcdc117cbc", size = 6851574, upload-time = "2026-01-21T03:57:59.36Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, -] - [[package]] name = "pytest" version = "9.0.2" From 605f670ef8c567a9d9767018c8017ee1dc1b7831 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 12 Apr 2026 18:21:06 -0400 Subject: [PATCH 03/60] update examples option --- pyproject.toml | 4 +- uv.lock | 441 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 441 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 15917b3..4153269 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,9 @@ dependencies = [ nak-torch = "nak_torch:main" [project.optional-dependencies] -pyro = [ +examples = [ + "ipykernel>=7.2.0", + "matplotlib>=3.10.8", "pyro-ppl>=1.9.1", ] diff --git a/uv.lock b/uv.lock index 4048224..b0b133f 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,81 @@ resolution-markers = [ "python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] +[[package]] +name = "appnope" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170, upload-time = "2024-02-06T09:43:11.258Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321, upload-time = "2024-02-06T09:43:09.663Z" }, +] + +[[package]] +name = "asttokens" +version = "3.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/a5/8e3f9b6771b0b408517c82d97aed8f2036509bc247d46114925e32fe33f0/asttokens-3.0.1.tar.gz", hash = "sha256:71a4ee5de0bde6a31d64f6b13f2293ac190344478f081c3d1bccfcf5eacb0cb7", size = 62308, upload-time = "2025-11-15T16:43:48.578Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, +] + +[[package]] +name = "cffi" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser", marker = "implementation_name != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/47/4f61023ea636104d4f16ab488e268b93008c3d0bb76893b1b31db1f96802/cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d", size = 185271, upload-time = "2025-09-08T23:22:44.795Z" }, + { url = "https://files.pythonhosted.org/packages/df/a2/781b623f57358e360d62cdd7a8c681f074a71d445418a776eef0aadb4ab4/cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c", size = 181048, upload-time = "2025-09-08T23:22:45.938Z" }, + { url = "https://files.pythonhosted.org/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe", size = 212529, upload-time = "2025-09-08T23:22:47.349Z" }, + { url = "https://files.pythonhosted.org/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062", size = 220097, upload-time = "2025-09-08T23:22:48.677Z" }, + { url = "https://files.pythonhosted.org/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e", size = 207983, upload-time = "2025-09-08T23:22:50.06Z" }, + { url = "https://files.pythonhosted.org/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037", size = 206519, upload-time = "2025-09-08T23:22:51.364Z" }, + { url = "https://files.pythonhosted.org/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba", size = 219572, upload-time = "2025-09-08T23:22:52.902Z" }, + { url = "https://files.pythonhosted.org/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94", size = 222963, upload-time = "2025-09-08T23:22:54.518Z" }, + { url = "https://files.pythonhosted.org/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187", size = 221361, upload-time = "2025-09-08T23:22:55.867Z" }, + { url = "https://files.pythonhosted.org/packages/7b/2b/2b6435f76bfeb6bbf055596976da087377ede68df465419d192acf00c437/cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18", size = 172932, upload-time = "2025-09-08T23:22:57.188Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ed/13bd4418627013bec4ed6e54283b1959cf6db888048c7cf4b4c3b5b36002/cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5", size = 183557, upload-time = "2025-09-08T23:22:58.351Z" }, + { url = "https://files.pythonhosted.org/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6", size = 177762, upload-time = "2025-09-08T23:22:59.668Z" }, + { url = "https://files.pythonhosted.org/packages/4b/8d/a0a47a0c9e413a658623d014e91e74a50cdd2c423f7ccfd44086ef767f90/cffi-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb", size = 185230, upload-time = "2025-09-08T23:23:00.879Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d2/a6c0296814556c68ee32009d9c2ad4f85f2707cdecfd7727951ec228005d/cffi-2.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca", size = 181043, upload-time = "2025-09-08T23:23:02.231Z" }, + { url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" }, + { url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" }, + { url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" }, + { url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" }, + { url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" }, + { url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" }, + { url = "https://files.pythonhosted.org/packages/eb/6d/bf9bda840d5f1dfdbf0feca87fbdb64a918a69bca42cfa0ba7b137c48cb8/cffi-2.0.0-cp313-cp313-win32.whl", hash = "sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27", size = 172909, upload-time = "2025-09-08T23:23:14.32Z" }, + { url = "https://files.pythonhosted.org/packages/37/18/6519e1ee6f5a1e579e04b9ddb6f1676c17368a7aba48299c3759bbc3c8b3/cffi-2.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75", size = 183402, upload-time = "2025-09-08T23:23:15.535Z" }, + { url = "https://files.pythonhosted.org/packages/cb/0e/02ceeec9a7d6ee63bb596121c2c8e9b3a9e150936f4fbef6ca1943e6137c/cffi-2.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91", size = 177780, upload-time = "2025-09-08T23:23:16.761Z" }, + { url = "https://files.pythonhosted.org/packages/92/c4/3ce07396253a83250ee98564f8d7e9789fab8e58858f35d07a9a2c78de9f/cffi-2.0.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:fc33c5141b55ed366cfaad382df24fe7dcbc686de5be719b207bb248e3053dc5", size = 185320, upload-time = "2025-09-08T23:23:18.087Z" }, + { url = "https://files.pythonhosted.org/packages/59/dd/27e9fa567a23931c838c6b02d0764611c62290062a6d4e8ff7863daf9730/cffi-2.0.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:c654de545946e0db659b3400168c9ad31b5d29593291482c43e3564effbcee13", size = 181487, upload-time = "2025-09-08T23:23:19.622Z" }, + { url = "https://files.pythonhosted.org/packages/d6/43/0e822876f87ea8a4ef95442c3d766a06a51fc5298823f884ef87aaad168c/cffi-2.0.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:24b6f81f1983e6df8db3adc38562c83f7d4a0c36162885ec7f7b77c7dcbec97b", size = 220049, upload-time = "2025-09-08T23:23:20.853Z" }, + { url = "https://files.pythonhosted.org/packages/b4/89/76799151d9c2d2d1ead63c2429da9ea9d7aac304603de0c6e8764e6e8e70/cffi-2.0.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:12873ca6cb9b0f0d3a0da705d6086fe911591737a59f28b7936bdfed27c0d47c", size = 207793, upload-time = "2025-09-08T23:23:22.08Z" }, + { url = "https://files.pythonhosted.org/packages/bb/dd/3465b14bb9e24ee24cb88c9e3730f6de63111fffe513492bf8c808a3547e/cffi-2.0.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:d9b97165e8aed9272a6bb17c01e3cc5871a594a446ebedc996e2397a1c1ea8ef", size = 206300, upload-time = "2025-09-08T23:23:23.314Z" }, + { url = "https://files.pythonhosted.org/packages/47/d9/d83e293854571c877a92da46fdec39158f8d7e68da75bf73581225d28e90/cffi-2.0.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:afb8db5439b81cf9c9d0c80404b60c3cc9c3add93e114dcae767f1477cb53775", size = 219244, upload-time = "2025-09-08T23:23:24.541Z" }, + { url = "https://files.pythonhosted.org/packages/2b/0f/1f177e3683aead2bb00f7679a16451d302c436b5cbf2505f0ea8146ef59e/cffi-2.0.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:737fe7d37e1a1bffe70bd5754ea763a62a066dc5913ca57e957824b72a85e205", size = 222828, upload-time = "2025-09-08T23:23:26.143Z" }, + { url = "https://files.pythonhosted.org/packages/c6/0f/cafacebd4b040e3119dcb32fed8bdef8dfe94da653155f9d0b9dc660166e/cffi-2.0.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:38100abb9d1b1435bc4cc340bb4489635dc2f0da7456590877030c9b3d40b0c1", size = 220926, upload-time = "2025-09-08T23:23:27.873Z" }, + { url = "https://files.pythonhosted.org/packages/3e/aa/df335faa45b395396fcbc03de2dfcab242cd61a9900e914fe682a59170b1/cffi-2.0.0-cp314-cp314-win32.whl", hash = "sha256:087067fa8953339c723661eda6b54bc98c5625757ea62e95eb4898ad5e776e9f", size = 175328, upload-time = "2025-09-08T23:23:44.61Z" }, + { url = "https://files.pythonhosted.org/packages/bb/92/882c2d30831744296ce713f0feb4c1cd30f346ef747b530b5318715cc367/cffi-2.0.0-cp314-cp314-win_amd64.whl", hash = "sha256:203a48d1fb583fc7d78a4c6655692963b860a417c0528492a6bc21f1aaefab25", size = 185650, upload-time = "2025-09-08T23:23:45.848Z" }, + { url = "https://files.pythonhosted.org/packages/9f/2c/98ece204b9d35a7366b5b2c6539c350313ca13932143e79dc133ba757104/cffi-2.0.0-cp314-cp314-win_arm64.whl", hash = "sha256:dbd5c7a25a7cb98f5ca55d258b103a2054f859a46ae11aaf23134f9cc0d356ad", size = 180687, upload-time = "2025-09-08T23:23:47.105Z" }, + { url = "https://files.pythonhosted.org/packages/3e/61/c768e4d548bfa607abcda77423448df8c471f25dbe64fb2ef6d555eae006/cffi-2.0.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:9a67fc9e8eb39039280526379fb3a70023d77caec1852002b4da7e8b270c4dd9", size = 188773, upload-time = "2025-09-08T23:23:29.347Z" }, + { url = "https://files.pythonhosted.org/packages/2c/ea/5f76bce7cf6fcd0ab1a1058b5af899bfbef198bea4d5686da88471ea0336/cffi-2.0.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7a66c7204d8869299919db4d5069a82f1561581af12b11b3c9f48c584eb8743d", size = 185013, upload-time = "2025-09-08T23:23:30.63Z" }, + { url = "https://files.pythonhosted.org/packages/be/b4/c56878d0d1755cf9caa54ba71e5d049479c52f9e4afc230f06822162ab2f/cffi-2.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cc09976e8b56f8cebd752f7113ad07752461f48a58cbba644139015ac24954c", size = 221593, upload-time = "2025-09-08T23:23:31.91Z" }, + { url = "https://files.pythonhosted.org/packages/e0/0d/eb704606dfe8033e7128df5e90fee946bbcb64a04fcdaa97321309004000/cffi-2.0.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:92b68146a71df78564e4ef48af17551a5ddd142e5190cdf2c5624d0c3ff5b2e8", size = 209354, upload-time = "2025-09-08T23:23:33.214Z" }, + { url = "https://files.pythonhosted.org/packages/d8/19/3c435d727b368ca475fb8742ab97c9cb13a0de600ce86f62eab7fa3eea60/cffi-2.0.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:b1e74d11748e7e98e2f426ab176d4ed720a64412b6a15054378afdb71e0f37dc", size = 208480, upload-time = "2025-09-08T23:23:34.495Z" }, + { url = "https://files.pythonhosted.org/packages/d0/44/681604464ed9541673e486521497406fadcc15b5217c3e326b061696899a/cffi-2.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:28a3a209b96630bca57cce802da70c266eb08c6e97e5afd61a75611ee6c64592", size = 221584, upload-time = "2025-09-08T23:23:36.096Z" }, + { url = "https://files.pythonhosted.org/packages/25/8e/342a504ff018a2825d395d44d63a767dd8ebc927ebda557fecdaca3ac33a/cffi-2.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7553fb2090d71822f02c629afe6042c299edf91ba1bf94951165613553984512", size = 224443, upload-time = "2025-09-08T23:23:37.328Z" }, + { url = "https://files.pythonhosted.org/packages/e1/5e/b666bacbbc60fbf415ba9988324a132c9a7a0448a9a8f125074671c0f2c3/cffi-2.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c6c373cfc5c83a975506110d17457138c8c63016b563cc9ed6e056a82f13ce4", size = 223437, upload-time = "2025-09-08T23:23:38.945Z" }, + { url = "https://files.pythonhosted.org/packages/a0/1d/ec1a60bd1a10daa292d3cd6bb0b359a81607154fb8165f3ec95fe003b85c/cffi-2.0.0-cp314-cp314t-win32.whl", hash = "sha256:1fc9ea04857caf665289b7a75923f2c6ed559b8298a1b8c49e59f7dd95c8481e", size = 180487, upload-time = "2025-09-08T23:23:40.423Z" }, + { url = "https://files.pythonhosted.org/packages/bf/41/4c1168c74fac325c0c8156f04b6749c8b6a8f405bbf91413ba088359f60d/cffi-2.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:d68b6cef7827e8641e8ef16f4494edda8b36104d79773a334beaa1e3521430f6", size = 191726, upload-time = "2025-09-08T23:23:41.742Z" }, + { url = "https://files.pythonhosted.org/packages/ae/3a/dbeec9d1ee0844c679f6bb5d6ad4e9f198b1224f4e7a32825f47f6192b0c/cffi-2.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9", size = 184195, upload-time = "2025-09-08T23:23:43.004Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -19,6 +94,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "comm" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/13/7d740c5849255756bc17888787313b61fd38a0a8304fc4f073dfc46122aa/comm-0.2.3.tar.gz", hash = "sha256:2dc8048c10962d55d7ad693be1e7045d891b7ce8d999c97963a5e3e99c055971", size = 6319, upload-time = "2025-07-25T14:02:04.452Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" }, +] + [[package]] name = "contourpy" version = "1.3.3" @@ -201,6 +285,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, ] +[[package]] +name = "debugpy" +version = "1.8.20" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/b7/cd8080344452e4874aae67c40d8940e2b4d47b01601a8fd9f44786c757c7/debugpy-1.8.20.tar.gz", hash = "sha256:55bc8701714969f1ab89a6d5f2f3d40c36f91b2cbe2f65d98bf8196f6a6a2c33", size = 1645207, upload-time = "2026-01-29T23:03:28.199Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/57/7f34f4736bfb6e00f2e4c96351b07805d83c9a7b33d28580ae01374430f7/debugpy-1.8.20-cp312-cp312-macosx_15_0_universal2.whl", hash = "sha256:4ae3135e2089905a916909ef31922b2d733d756f66d87345b3e5e52b7a55f13d", size = 2550686, upload-time = "2026-01-29T23:03:42.023Z" }, + { url = "https://files.pythonhosted.org/packages/ab/78/b193a3975ca34458f6f0e24aaf5c3e3da72f5401f6054c0dfd004b41726f/debugpy-1.8.20-cp312-cp312-manylinux_2_34_x86_64.whl", hash = "sha256:88f47850a4284b88bd2bfee1f26132147d5d504e4e86c22485dfa44b97e19b4b", size = 4310588, upload-time = "2026-01-29T23:03:43.314Z" }, + { url = "https://files.pythonhosted.org/packages/c1/55/f14deb95eaf4f30f07ef4b90a8590fc05d9e04df85ee379712f6fb6736d7/debugpy-1.8.20-cp312-cp312-win32.whl", hash = "sha256:4057ac68f892064e5f98209ab582abfee3b543fb55d2e87610ddc133a954d390", size = 5331372, upload-time = "2026-01-29T23:03:45.526Z" }, + { url = "https://files.pythonhosted.org/packages/a1/39/2bef246368bd42f9bd7cba99844542b74b84dacbdbea0833e610f384fee8/debugpy-1.8.20-cp312-cp312-win_amd64.whl", hash = "sha256:a1a8f851e7cf171330679ef6997e9c579ef6dd33c9098458bd9986a0f4ca52e3", size = 5372835, upload-time = "2026-01-29T23:03:47.245Z" }, + { url = "https://files.pythonhosted.org/packages/15/e2/fc500524cc6f104a9d049abc85a0a8b3f0d14c0a39b9c140511c61e5b40b/debugpy-1.8.20-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:5dff4bb27027821fdfcc9e8f87309a28988231165147c31730128b1c983e282a", size = 2539560, upload-time = "2026-01-29T23:03:48.738Z" }, + { url = "https://files.pythonhosted.org/packages/90/83/fb33dcea789ed6018f8da20c5a9bc9d82adc65c0c990faed43f7c955da46/debugpy-1.8.20-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:84562982dd7cf5ebebfdea667ca20a064e096099997b175fe204e86817f64eaf", size = 4293272, upload-time = "2026-01-29T23:03:50.169Z" }, + { url = "https://files.pythonhosted.org/packages/a6/25/b1e4a01bfb824d79a6af24b99ef291e24189080c93576dfd9b1a2815cd0f/debugpy-1.8.20-cp313-cp313-win32.whl", hash = "sha256:da11dea6447b2cadbf8ce2bec59ecea87cc18d2c574980f643f2d2dfe4862393", size = 5331208, upload-time = "2026-01-29T23:03:51.547Z" }, + { url = "https://files.pythonhosted.org/packages/13/f7/a0b368ce54ffff9e9028c098bd2d28cfc5b54f9f6c186929083d4c60ba58/debugpy-1.8.20-cp313-cp313-win_amd64.whl", hash = "sha256:eb506e45943cab2efb7c6eafdd65b842f3ae779f020c82221f55aca9de135ed7", size = 5372930, upload-time = "2026-01-29T23:03:53.585Z" }, + { url = "https://files.pythonhosted.org/packages/33/2e/f6cb9a8a13f5058f0a20fe09711a7b726232cd5a78c6a7c05b2ec726cff9/debugpy-1.8.20-cp314-cp314-macosx_15_0_universal2.whl", hash = "sha256:9c74df62fc064cd5e5eaca1353a3ef5a5d50da5eb8058fcef63106f7bebe6173", size = 2538066, upload-time = "2026-01-29T23:03:54.999Z" }, + { url = "https://files.pythonhosted.org/packages/c5/56/6ddca50b53624e1ca3ce1d1e49ff22db46c47ea5fb4c0cc5c9b90a616364/debugpy-1.8.20-cp314-cp314-manylinux_2_34_x86_64.whl", hash = "sha256:077a7447589ee9bc1ff0cdf443566d0ecf540ac8aa7333b775ebcb8ce9f4ecad", size = 4269425, upload-time = "2026-01-29T23:03:56.518Z" }, + { url = "https://files.pythonhosted.org/packages/c5/d9/d64199c14a0d4c476df46c82470a3ce45c8d183a6796cfb5e66533b3663c/debugpy-1.8.20-cp314-cp314-win32.whl", hash = "sha256:352036a99dd35053b37b7803f748efc456076f929c6a895556932eaf2d23b07f", size = 5331407, upload-time = "2026-01-29T23:03:58.481Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d9/1f07395b54413432624d61524dfd98c1a7c7827d2abfdb8829ac92638205/debugpy-1.8.20-cp314-cp314-win_amd64.whl", hash = "sha256:a98eec61135465b062846112e5ecf2eebb855305acc1dfbae43b72903b8ab5be", size = 5372521, upload-time = "2026-01-29T23:03:59.864Z" }, + { url = "https://files.pythonhosted.org/packages/e0/c3/7f67dea8ccf8fdcb9c99033bbe3e90b9e7395415843accb81428c441be2d/debugpy-1.8.20-py2.py3-none-any.whl", hash = "sha256:5be9bed9ae3be00665a06acaa48f8329d2b9632f15fd09f6a9a8c8d9907e54d7", size = 5337658, upload-time = "2026-01-29T23:04:17.404Z" }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + +[[package]] +name = "executing" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, +] + [[package]] name = "filelock" version = "3.25.2" @@ -269,6 +392,63 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "ipykernel" +version = "7.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "comm" }, + { name = "debugpy" }, + { name = "ipython" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "matplotlib-inline" }, + { name = "nest-asyncio" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/8d/b68b728e2d06b9e0051019640a40a9eb7a88fcd82c2e1b5ce70bef5ff044/ipykernel-7.2.0.tar.gz", hash = "sha256:18ed160b6dee2cbb16e5f3575858bc19d8f1fe6046a9a680c708494ce31d909e", size = 176046, upload-time = "2026-02-06T16:43:27.403Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/b9/e73d5d9f405cba7706c539aa8b311b49d4c2f3d698d9c12f815231169c71/ipykernel-7.2.0-py3-none-any.whl", hash = "sha256:3bbd4420d2b3cc105cbdf3756bfc04500b1e52f090a90716851f3916c62e1661", size = 118788, upload-time = "2026-02-06T16:43:25.149Z" }, +] + +[[package]] +name = "ipython" +version = "9.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "decorator" }, + { name = "ipython-pygments-lexers" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "stack-data" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/73/7114f80a8f9cabdb13c27732dce24af945b2923dcab80723602f7c8bc2d8/ipython-9.12.0.tar.gz", hash = "sha256:01daa83f504b693ba523b5a407246cabde4eb4513285a3c6acaff11a66735ee4", size = 4428879, upload-time = "2026-03-27T09:42:45.312Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/22/906c8108974c673ebef6356c506cebb6870d48cedea3c41e949e2dd556bb/ipython-9.12.0-py3-none-any.whl", hash = "sha256:0f2701e8ee86e117e37f50563205d36feaa259d2e08d4a6bc6b6d74b18ce128d", size = 625661, upload-time = "2026-03-27T09:42:42.831Z" }, +] + +[[package]] +name = "ipython-pygments-lexers" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, +] + [[package]] name = "jaxtyping" version = "0.3.7" @@ -281,6 +461,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/42/caf65e9a0576a3abadc537e2f831701ba9081f21317fb3be87d64451587a/jaxtyping-0.3.7-py3-none-any.whl", hash = "sha256:303ab8599edf412eeb40bf06c863e3168fa186cf0e7334703fa741ddd7046e66", size = 56101, upload-time = "2026-01-30T14:18:45.954Z" }, ] +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -293,6 +485,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "jupyter-client" +version = "8.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-core" }, + { name = "python-dateutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/e4/ba649102a3bc3fbca54e7239fb924fd434c766f855693d86de0b1f2bec81/jupyter_client-8.8.0.tar.gz", hash = "sha256:d556811419a4f2d96c869af34e854e3f059b7cc2d6d01a9cd9c85c267691be3e", size = 348020, upload-time = "2026-01-08T13:55:47.938Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/0b/ceb7694d864abc0a047649aec263878acb9f792e1fec3e676f22dc9015e3/jupyter_client-8.8.0-py3-none-any.whl", hash = "sha256:f93a5b99c5e23a507b773d3a1136bd6e16c67883ccdbd9a829b0bbdb98cd7d7a", size = 107371, upload-time = "2026-01-08T13:55:45.562Z" }, +] + +[[package]] +name = "jupyter-core" +version = "5.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/02/49/9d1284d0dc65e2c757b74c6687b6d319b02f822ad039e5c512df9194d9dd/jupyter_core-5.9.1.tar.gz", hash = "sha256:4d09aaff303b9566c3ce657f580bd089ff5c91f5f89cf7d8846c3cdf465b5508", size = 89814, upload-time = "2025-10-16T19:19:18.444Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, +] + [[package]] name = "kiwisolver" version = "1.5.0" @@ -496,6 +717,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/49/d651878698a0b67f23aa28e17f45a6d6dd3d3f933fa29087fa4ce5947b5a/matplotlib-3.10.8-cp314-cp314t-win_arm64.whl", hash = "sha256:113bb52413ea508ce954a02c10ffd0d565f9c3bc7f2eddc27dfe1731e71c7b5f", size = 8192560, upload-time = "2025-12-10T22:56:38.008Z" }, ] +[[package]] +name = "matplotlib-inline" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/74/97e72a36efd4ae2bccb3463284300f8953f199b5ffbc04cbbb0ec78f74b1/matplotlib_inline-0.2.1.tar.gz", hash = "sha256:e1ee949c340d771fc39e241ea75683deb94762c8fa5f2927ec57c83c4dffa9fe", size = 8110, upload-time = "2025-10-23T09:00:22.126Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -517,7 +750,9 @@ dependencies = [ ] [package.optional-dependencies] -pyro = [ +examples = [ + { name = "ipykernel" }, + { name = "matplotlib" }, { name = "pyro-ppl" }, ] @@ -533,13 +768,15 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ipykernel", marker = "extra == 'examples'", specifier = ">=7.2.0" }, { name = "jaxtyping", specifier = ">=0.3.5" }, + { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.10.8" }, { name = "numpy", specifier = ">=2.4.1" }, - { name = "pyro-ppl", marker = "extra == 'pyro'", specifier = ">=1.9.1" }, + { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, { name = "torch", specifier = ">=2.10" }, { name = "tqdm", specifier = ">=4.67.1" }, ] -provides-extras = ["pyro"] +provides-extras = ["examples"] [package.metadata.requires-dev] dev = [ @@ -551,6 +788,15 @@ dev = [ { name = "ruff", specifier = ">=0.15.8" }, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, +] + [[package]] name = "networkx" version = "3.6.1" @@ -825,6 +1071,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/2b/f8434233fab2bd66a02ec014febe4e5adced20e2693e0e90a07d118ed30e/pandas-3.0.2-cp314-cp314t-win_arm64.whl", hash = "sha256:5371b72c2d4d415d08765f32d689217a43227484e81b2305b52076e328f6f482", size = 9455341, upload-time = "2026-03-31T06:48:28.418Z" }, ] +[[package]] +name = "parso" +version = "0.8.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/81/76/a1e769043c0c0c9fe391b702539d594731a4362334cdf4dc25d0c09761e7/parso-0.8.6.tar.gz", hash = "sha256:2b9a0332696df97d454fa67b81618fd69c35a7b90327cbe6ba5c92d2c68a7bfd", size = 401621, upload-time = "2026-02-09T15:45:24.425Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff", size = 106894, upload-time = "2026-02-09T15:45:21.391Z" }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, +] + [[package]] name = "pillow" version = "12.2.0" @@ -894,6 +1161,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" }, ] +[[package]] +name = "platformdirs" +version = "4.9.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9f/4a/0883b8e3802965322523f0b200ecf33d31f10991d0401162f4b23c698b42/platformdirs-4.9.6.tar.gz", hash = "sha256:3bfa75b0ad0db84096ae777218481852c0ebc6c727b3168c1b9e0118e458cf0a", size = 29400, upload-time = "2026-04-09T00:04:10.812Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/75/a6/a0a304dc33b49145b21f4808d763822111e67d1c3a32b524a1baf947b6e1/platformdirs-4.9.6-py3-none-any.whl", hash = "sha256:e61adb1d5e5cb3441b4b7710bea7e4c12250ca49439228cc1021c00dcfac0917", size = 21348, upload-time = "2026-04-09T00:04:09.463Z" }, +] + [[package]] name = "pluggy" version = "1.6.0" @@ -903,6 +1179,73 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + +[[package]] +name = "psutil" +version = "7.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" }, + { url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" }, + { url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" }, + { url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" }, + { url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" }, + { url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" }, + { url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" }, + { url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" }, + { url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" }, + { url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" }, + { url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" }, + { url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" }, + { url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" }, + { url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" }, + { url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" }, + { url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" }, + { url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" }, + { url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" }, + { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, +] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, +] + +[[package]] +name = "pycparser" +version = "3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/7d/92392ff7815c21062bea51aa7b87d45576f649f16458d78b7cf94b9ab2e6/pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29", size = 103492, upload-time = "2026-01-21T14:26:51.89Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -988,6 +1331,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "pyzmq" +version = "27.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/e7/038aab64a946d535901103da16b953c8c9cc9c961dadcbf3609ed6428d23/pyzmq-27.1.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:452631b640340c928fa343801b0d07eb0c3789a5ffa843f6e1a9cee0ba4eb4fc", size = 1306279, upload-time = "2025-09-08T23:08:03.807Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5e/c3c49fdd0f535ef45eefcc16934648e9e59dace4a37ee88fc53f6cd8e641/pyzmq-27.1.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1c179799b118e554b66da67d88ed66cd37a169f1f23b5d9f0a231b4e8d44a113", size = 895645, upload-time = "2025-09-08T23:08:05.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/b0b2504cb4e903a74dcf1ebae157f9e20ebb6ea76095f6cfffea28c42ecd/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3837439b7f99e60312f0c926a6ad437b067356dc2bc2ec96eb395fd0fe804233", size = 652574, upload-time = "2025-09-08T23:08:06.828Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9b/c108cdb55560eaf253f0cbdb61b29971e9fb34d9c3499b0e96e4e60ed8a5/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43ad9a73e3da1fab5b0e7e13402f0b2fb934ae1c876c51d0afff0e7c052eca31", size = 840995, upload-time = "2025-09-08T23:08:08.396Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bb/b79798ca177b9eb0825b4c9998c6af8cd2a7f15a6a1a4272c1d1a21d382f/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0de3028d69d4cdc475bfe47a6128eb38d8bc0e8f4d69646adfbcd840facbac28", size = 1642070, upload-time = "2025-09-08T23:08:09.989Z" }, + { url = "https://files.pythonhosted.org/packages/9c/80/2df2e7977c4ede24c79ae39dcef3899bfc5f34d1ca7a5b24f182c9b7a9ca/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:cf44a7763aea9298c0aa7dbf859f87ed7012de8bda0f3977b6fb1d96745df856", size = 2021121, upload-time = "2025-09-08T23:08:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/46/bd/2d45ad24f5f5ae7e8d01525eb76786fa7557136555cac7d929880519e33a/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f30f395a9e6fbca195400ce833c731e7b64c3919aa481af4d88c3759e0cb7496", size = 1878550, upload-time = "2025-09-08T23:08:13.513Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2f/104c0a3c778d7c2ab8190e9db4f62f0b6957b53c9d87db77c284b69f33ea/pyzmq-27.1.0-cp312-abi3-win32.whl", hash = "sha256:250e5436a4ba13885494412b3da5d518cd0d3a278a1ae640e113c073a5f88edd", size = 559184, upload-time = "2025-09-08T23:08:15.163Z" }, + { url = "https://files.pythonhosted.org/packages/fc/7f/a21b20d577e4100c6a41795842028235998a643b1ad406a6d4163ea8f53e/pyzmq-27.1.0-cp312-abi3-win_amd64.whl", hash = "sha256:9ce490cf1d2ca2ad84733aa1d69ce6855372cb5ce9223802450c9b2a7cba0ccf", size = 619480, upload-time = "2025-09-08T23:08:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/78/c2/c012beae5f76b72f007a9e91ee9401cb88c51d0f83c6257a03e785c81cc2/pyzmq-27.1.0-cp312-abi3-win_arm64.whl", hash = "sha256:75a2f36223f0d535a0c919e23615fc85a1e23b71f40c7eb43d7b1dedb4d8f15f", size = 552993, upload-time = "2025-09-08T23:08:18.926Z" }, + { url = "https://files.pythonhosted.org/packages/60/cb/84a13459c51da6cec1b7b1dc1a47e6db6da50b77ad7fd9c145842750a011/pyzmq-27.1.0-cp313-cp313-android_24_arm64_v8a.whl", hash = "sha256:93ad4b0855a664229559e45c8d23797ceac03183c7b6f5b4428152a6b06684a5", size = 1122436, upload-time = "2025-09-08T23:08:20.801Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b6/94414759a69a26c3dd674570a81813c46a078767d931a6c70ad29fc585cb/pyzmq-27.1.0-cp313-cp313-android_24_x86_64.whl", hash = "sha256:fbb4f2400bfda24f12f009cba62ad5734148569ff4949b1b6ec3b519444342e6", size = 1156301, upload-time = "2025-09-08T23:08:22.47Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ad/15906493fd40c316377fd8a8f6b1f93104f97a752667763c9b9c1b71d42d/pyzmq-27.1.0-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:e343d067f7b151cfe4eb3bb796a7752c9d369eed007b91231e817071d2c2fec7", size = 1341197, upload-time = "2025-09-08T23:08:24.286Z" }, + { url = "https://files.pythonhosted.org/packages/14/1d/d343f3ce13db53a54cb8946594e567410b2125394dafcc0268d8dda027e0/pyzmq-27.1.0-cp313-cp313t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:08363b2011dec81c354d694bdecaef4770e0ae96b9afea70b3f47b973655cc05", size = 897275, upload-time = "2025-09-08T23:08:26.063Z" }, + { url = "https://files.pythonhosted.org/packages/69/2d/d83dd6d7ca929a2fc67d2c3005415cdf322af7751d773524809f9e585129/pyzmq-27.1.0-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d54530c8c8b5b8ddb3318f481297441af102517602b569146185fa10b63f4fa9", size = 660469, upload-time = "2025-09-08T23:08:27.623Z" }, + { url = "https://files.pythonhosted.org/packages/3e/cd/9822a7af117f4bc0f1952dbe9ef8358eb50a24928efd5edf54210b850259/pyzmq-27.1.0-cp313-cp313t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f3afa12c392f0a44a2414056d730eebc33ec0926aae92b5ad5cf26ebb6cc128", size = 847961, upload-time = "2025-09-08T23:08:29.672Z" }, + { url = "https://files.pythonhosted.org/packages/9a/12/f003e824a19ed73be15542f172fd0ec4ad0b60cf37436652c93b9df7c585/pyzmq-27.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c65047adafe573ff023b3187bb93faa583151627bc9c51fc4fb2c561ed689d39", size = 1650282, upload-time = "2025-09-08T23:08:31.349Z" }, + { url = "https://files.pythonhosted.org/packages/d5/4a/e82d788ed58e9a23995cee70dbc20c9aded3d13a92d30d57ec2291f1e8a3/pyzmq-27.1.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:90e6e9441c946a8b0a667356f7078d96411391a3b8f80980315455574177ec97", size = 2024468, upload-time = "2025-09-08T23:08:33.543Z" }, + { url = "https://files.pythonhosted.org/packages/d9/94/2da0a60841f757481e402b34bf4c8bf57fa54a5466b965de791b1e6f747d/pyzmq-27.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:add071b2d25f84e8189aaf0882d39a285b42fa3853016ebab234a5e78c7a43db", size = 1885394, upload-time = "2025-09-08T23:08:35.51Z" }, + { url = "https://files.pythonhosted.org/packages/4f/6f/55c10e2e49ad52d080dc24e37adb215e5b0d64990b57598abc2e3f01725b/pyzmq-27.1.0-cp313-cp313t-win32.whl", hash = "sha256:7ccc0700cfdf7bd487bea8d850ec38f204478681ea02a582a8da8171b7f90a1c", size = 574964, upload-time = "2025-09-08T23:08:37.178Z" }, + { url = "https://files.pythonhosted.org/packages/87/4d/2534970ba63dd7c522d8ca80fb92777f362c0f321900667c615e2067cb29/pyzmq-27.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:8085a9fba668216b9b4323be338ee5437a235fe275b9d1610e422ccc279733e2", size = 641029, upload-time = "2025-09-08T23:08:40.595Z" }, + { url = "https://files.pythonhosted.org/packages/f6/fa/f8aea7a28b0641f31d40dea42d7ef003fded31e184ef47db696bc74cd610/pyzmq-27.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:6bb54ca21bcfe361e445256c15eedf083f153811c37be87e0514934d6913061e", size = 561541, upload-time = "2025-09-08T23:08:42.668Z" }, + { url = "https://files.pythonhosted.org/packages/87/45/19efbb3000956e82d0331bafca5d9ac19ea2857722fa2caacefb6042f39d/pyzmq-27.1.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:ce980af330231615756acd5154f29813d553ea555485ae712c491cd483df6b7a", size = 1341197, upload-time = "2025-09-08T23:08:44.973Z" }, + { url = "https://files.pythonhosted.org/packages/48/43/d72ccdbf0d73d1343936296665826350cb1e825f92f2db9db3e61c2162a2/pyzmq-27.1.0-cp314-cp314t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1779be8c549e54a1c38f805e56d2a2e5c009d26de10921d7d51cfd1c8d4632ea", size = 897175, upload-time = "2025-09-08T23:08:46.601Z" }, + { url = "https://files.pythonhosted.org/packages/2f/2e/a483f73a10b65a9ef0161e817321d39a770b2acf8bcf3004a28d90d14a94/pyzmq-27.1.0-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7200bb0f03345515df50d99d3db206a0a6bee1955fbb8c453c76f5bf0e08fb96", size = 660427, upload-time = "2025-09-08T23:08:48.187Z" }, + { url = "https://files.pythonhosted.org/packages/f5/d2/5f36552c2d3e5685abe60dfa56f91169f7a2d99bbaf67c5271022ab40863/pyzmq-27.1.0-cp314-cp314t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:01c0e07d558b06a60773744ea6251f769cd79a41a97d11b8bf4ab8f034b0424d", size = 847929, upload-time = "2025-09-08T23:08:49.76Z" }, + { url = "https://files.pythonhosted.org/packages/c4/2a/404b331f2b7bf3198e9945f75c4c521f0c6a3a23b51f7a4a401b94a13833/pyzmq-27.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:80d834abee71f65253c91540445d37c4c561e293ba6e741b992f20a105d69146", size = 1650193, upload-time = "2025-09-08T23:08:51.7Z" }, + { url = "https://files.pythonhosted.org/packages/1c/0b/f4107e33f62a5acf60e3ded67ed33d79b4ce18de432625ce2fc5093d6388/pyzmq-27.1.0-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:544b4e3b7198dde4a62b8ff6685e9802a9a1ebf47e77478a5eb88eca2a82f2fd", size = 2024388, upload-time = "2025-09-08T23:08:53.393Z" }, + { url = "https://files.pythonhosted.org/packages/0d/01/add31fe76512642fd6e40e3a3bd21f4b47e242c8ba33efb6809e37076d9b/pyzmq-27.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cedc4c68178e59a4046f97eca31b148ddcf51e88677de1ef4e78cf06c5376c9a", size = 1885316, upload-time = "2025-09-08T23:08:55.702Z" }, + { url = "https://files.pythonhosted.org/packages/c4/59/a5f38970f9bf07cee96128de79590bb354917914a9be11272cfc7ff26af0/pyzmq-27.1.0-cp314-cp314t-win32.whl", hash = "sha256:1f0b2a577fd770aa6f053211a55d1c47901f4d537389a034c690291485e5fe92", size = 587472, upload-time = "2025-09-08T23:08:58.18Z" }, + { url = "https://files.pythonhosted.org/packages/70/d8/78b1bad170f93fcf5e3536e70e8fadac55030002275c9a29e8f5719185de/pyzmq-27.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:19c9468ae0437f8074af379e986c5d3d7d7bfe033506af442e8c879732bedbe0", size = 661401, upload-time = "2025-09-08T23:08:59.802Z" }, + { url = "https://files.pythonhosted.org/packages/81/d6/4bfbb40c9a0b42fc53c7cf442f6385db70b40f74a783130c5d0a5aa62228/pyzmq-27.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dc5dbf68a7857b59473f7df42650c621d7e8923fb03fa74a526890f4d33cc4d7", size = 575170, upload-time = "2025-09-08T23:09:01.418Z" }, +] + [[package]] name = "ruff" version = "0.15.8" @@ -1031,6 +1417,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -1103,6 +1503,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/66/4d/35352043ee0eaffdeff154fad67cd4a31dbed7ff8e3be1cc4549717d6d51/torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185", size = 113995816, upload-time = "2026-01-21T16:22:05.312Z" }, ] +[[package]] +name = "tornado" +version = "6.5.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/f1/3173dfa4a18db4a9b03e5d55325559dab51ee653763bb8745a75af491286/tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9", size = 516006, upload-time = "2026-03-10T21:31:02.067Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/8c/77f5097695f4dd8255ecbd08b2a1ed8ba8b953d337804dd7080f199e12bf/tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa", size = 445983, upload-time = "2026-03-10T21:30:44.28Z" }, + { url = "https://files.pythonhosted.org/packages/ab/5e/7625b76cd10f98f1516c36ce0346de62061156352353ef2da44e5c21523c/tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521", size = 444246, upload-time = "2026-03-10T21:30:46.571Z" }, + { url = "https://files.pythonhosted.org/packages/b2/04/7b5705d5b3c0fab088f434f9c83edac1573830ca49ccf29fb83bf7178eec/tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5", size = 447229, upload-time = "2026-03-10T21:30:48.273Z" }, + { url = "https://files.pythonhosted.org/packages/34/01/74e034a30ef59afb4097ef8659515e96a39d910b712a89af76f5e4e1f93c/tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07", size = 448192, upload-time = "2026-03-10T21:30:51.22Z" }, + { url = "https://files.pythonhosted.org/packages/be/00/fe9e02c5a96429fce1a1d15a517f5d8444f9c412e0bb9eadfbe3b0fc55bf/tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e", size = 448039, upload-time = "2026-03-10T21:30:53.52Z" }, + { url = "https://files.pythonhosted.org/packages/82/9e/656ee4cec0398b1d18d0f1eb6372c41c6b889722641d84948351ae19556d/tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca", size = 447445, upload-time = "2026-03-10T21:30:55.541Z" }, + { url = "https://files.pythonhosted.org/packages/5a/76/4921c00511f88af86a33de770d64141170f1cfd9c00311aea689949e274e/tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7", size = 448582, upload-time = "2026-03-10T21:30:57.142Z" }, + { url = "https://files.pythonhosted.org/packages/2c/23/f6c6112a04d28eed765e374435fb1a9198f73e1ec4b4024184f21faeb1ad/tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b", size = 448990, upload-time = "2026-03-10T21:30:58.857Z" }, + { url = "https://files.pythonhosted.org/packages/b7/c8/876602cbc96469911f0939f703453c1157b0c826ecb05bdd32e023397d4e/tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6", size = 448016, upload-time = "2026-03-10T21:31:00.43Z" }, +] + [[package]] name = "tqdm" version = "4.67.3" @@ -1115,6 +1532,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/16/e1/3079a9ff9b8e11b846c6ac5c8b5bfb7ff225eee721825310c91b3b50304f/tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf", size = 78374, upload-time = "2026-02-03T17:35:50.982Z" }, ] +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, +] + [[package]] name = "triton" version = "3.6.0" @@ -1153,3 +1579,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/1e/67/cbae4bf7683a64755 wheels = [ { url = "https://files.pythonhosted.org/packages/8d/96/04e7b441807b26b794da5b11e59ed7f83b2cf8af202bd7eba8ad2fa6046e/wadler_lindig-0.1.7-py3-none-any.whl", hash = "sha256:e3ec83835570fd0a9509f969162aeb9c65618f998b1f42918cfc8d45122fe953", size = 20516, upload-time = "2025-06-18T07:00:41.684Z" }, ] + +[[package]] +name = "wcwidth" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, +] From acb22530455d11de954e1d34e5f72c2b3c1cf6a1 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 12 Apr 2026 18:22:20 -0400 Subject: [PATCH 04/60] Working logistic regression example --- .../data/create_simple_linear.py | 23 ++++ .../data/simple_linear.npy | Bin 0 -> 24128 bytes examples/logistic_regression/simple_linear.py | 127 ++++++++++++++++++ src/nak_torch/__init__.py | 4 +- src/nak_torch/tools/__init__.py | 3 +- src/nak_torch/tools/pyro_tools.py | 96 ++++++++++++- src/nak_torch/tools/types.py | 109 ++++++++++++++- 7 files changed, 354 insertions(+), 8 deletions(-) create mode 100644 examples/logistic_regression/data/create_simple_linear.py create mode 100644 examples/logistic_regression/data/simple_linear.npy create mode 100644 examples/logistic_regression/simple_linear.py diff --git a/examples/logistic_regression/data/create_simple_linear.py b/examples/logistic_regression/data/create_simple_linear.py new file mode 100644 index 0000000..8c28ccd --- /dev/null +++ b/examples/logistic_regression/data/create_simple_linear.py @@ -0,0 +1,23 @@ +import numpy as np +import os + +def sigmoid(x): + out = np.empty_like(x) + x_neg, x_pos = x[x < 0], x[x >= 0] + out[x < 0] = np.exp(x_neg) / (1 + np.exp(x_neg)) + out[x >= 0] = 1 / (1 + np.exp(-x_pos)) + return out + +if __name__ == '__main__': + np.random.seed(0) + N_SAMPLES = 1000 + BETA_X, BETA_Y, BETA_0 = -3., -3., 2. + x = np.random.randn(N_SAMPLES) + y = np.random.randn(N_SAMPLES) + logits = BETA_X * x + BETA_Y * y + BETA_0 + probs = sigmoid(logits) + aux_z = np.random.rand(N_SAMPLES) + label_pts = aux_z < probs + data = np.column_stack((x, y, label_pts)) + fname = os.path.join(os.path.dirname(__file__), "simple_linear.npy") + np.save(fname, data) \ No newline at end of file diff --git a/examples/logistic_regression/data/simple_linear.npy b/examples/logistic_regression/data/simple_linear.npy new file mode 100644 index 0000000000000000000000000000000000000000..c053887ad444d955b819de124334a9b4efd50ae9 GIT binary patch literal 24128 zcmbV!cRZKh`~Pi4B4wAPlvQTKXgIQql2tS*6p51TLJ=vm%v6ZVUfJW=d+)vXeA%1t z>-TuO@B8!jum8Fp*W-Df=Q`Is_j&npL-YDgZ4$%+vE|XzfA!dm=fW|bO9tXR{Kt3< zOwG*AbWI+bn(68PKe@cF@hg3z{MA$47y88i^NI)y3-cd4%g2B0&9VQ#f7B7R!wVh~ zn;^1!NYbpLA0WH`)IV0O^9Q$sp_xW~F3%}U9$n6<)}T5IRM$Q7&rx>5|I`WKhRC3y z_B4r9D1ZCl$C^w`zvFVEom2l7xNP=izFoK)lfU~MA>Ac63om?|zs6cozAGnyao;nI z98YFIr{;CFEiYQy()khG939s^J zcgDb0)R%i}u|*h;SMnDsR&EFqR>Bkik*#6+m9IH9uRQGn9nuN2%Sz)I{*dpm*W{gM z5coP-CWvYplV1$H{u++1fJ1`rYfAsBFq~AQ1gQ||3fQ90t1;HOj_Esjldxak+XUFW z^60nUS7Q1E=eK@&8k0~ia+~(vVEV3{04`iem~u8I&#QjBed&}G0h1f>nSOdXF$4l` z2};eFj9@rMUS@|aA}0aU^vTKjv<`^+L#xoe&N?^>wbV9#YkAJ?$_XH+7#?}eODZKo zHOd>Lzk%Vu@-;X@M{r24o*!0*8Yzpebc-D4qh^jUhf_$nIO2v99|DM0J0HC_<(Qj@z&Z^GSXUpYn(CxBN-}j>SU3~&r@~jS<)@pr4ezt!M$TDi`web%^G;ff*MST9yum-z4=c`Z5g<$%>Dk$se zFIK{1znOQgYEEMMoS_fjME;qAC;l>I=ItHDAi* zKs#)f|HG9sfbviIlV2l8nFtPlVGcUbVSVu~rybqZan>1nzI)zxsR|U52?|b$#lccRs z`X-xTVEY1w^Q49HjuOQLO#i(lA7hQ?69TYEbaU+Zu>pmo$1a!tOvUt7(#7tMq?CX| z{m=B2gfcLE=jm&YOHl)KKOlJ2L;8#e*gN^wZC=+T#A;{Zi+?m5#)>P+d)swZSdz#u$O!gnx7HGP9SAP zp=%yU^;voRO~Cg>9bYn0LRkup!fP%GJ;(R;V^qwc`uH5QY)w8{96OBdOPP_ogsPVS zEWZh5pLLDel_Q8E-HW*2)$OpVFTje46OYqA!BJ06&jm1ZM6)vcYxORF1kvD5;mZG8 z0dl>|NIk^yI;bfj`YuFj78Hz^KU20u^8kWev%X1^QquysJuY!-oFT#DS2yXbCI}t@ zEQRYH;)l~P`&5}VPEk2JgK>xF8_;kR**s`ln0Ow6K14J_0!SZc;N<#_M4hV8RhcV>Q@oiu^NoW2p2CZuda(A4k=9$+L ztk))G%8&ht@%h#_#xSP72uC)Ixv6B(zJdU>UsOpdQ_KRVjB4*3bM(9j;PBZqH0;Im zaJ{_SO8(dYhEvoa?LvQO2b}YMoG+S)<{1R3_}rSCWz_*KE>Xb*v3iXEBTD5v+>@iw z2?%~ZR@H~eujuD=J}BA-%aRT+RTUO@fumXfwuDvK->BZ!CPjWjt6Z^fT z3oVbI7vu3*DT6#v-JSxedw!UF=E8A?Io3|`c}~JBm%p_zxT1Xs0Ze7b$cA`s1CQCY z@ar=x5RG>^+WWyP-1YE!r;0lwN{*fHlhd?&Gh)ZVTGkYEpz0>Zr=7j|?x4yH2yf5h z-{)Gqi-RD3B*J44S_@%j_k4S-5T4(7j|7%}{5=RLUN2Yq2rpqcD?4>wY44^%@x?du zG1Mg(PWf%)BX?QaKwN!k8udfme}q{+D{}eI!Pjw@ZSKUN^^^cUdGdN|GB<*-i0dqw zm+<&1y-oV8Pkad|D|o)TE5oju-|l9q+yW>D6QFe&z7?HF-4v?k^TqyH(-+@Pl6t-&`Xrpy|YFUj)TJY+bVR zs~`Eb+JT*WNgBEJ@~#{~4okQm5PUNS7ycmPKDh;${^HH?N#U*)*!@mrMKBrn!$RJo z2ir%xz*kS1N&S0^7#|L%p6?!-#Xy9X-@iq96~lig7h@)%HU!3Uobu>%aJvcU37n5R zKqe(hXu4NuF^1_6WPREvATa~RgvDCO1@Qf@?1^4%IX?|=2g;WH%0&H{*yn%zW2M|X z2gE1-Xv&<<-{pfKIy#lpi?mBH^~$Hk3b%z_d;)k(?Z^}4)(s!Gm^o6%m+k5kz#Ugs z9!0A}s8e$-Sb{DTlcx$4@9!|&0G=~FHK*UMVsd^jmjhJIqi|%2Ij-3p?@u%(mM;pK zO+YbGs!ru?{CwLZzsm}p%Ycs5^Db(^=v;;%TNg*iB9dhoUxXXCD%08IWYGE?1~htZ9xIdVF9KWu07{Jtbb z7Zg628QWQf^Y4B`?oim<3muO)4etQ2m#7BSdLP#;rx}&ONVvSR$;g;q4Jynp09!~(7K!^T>v{qK^!Z| zs67ebk?h5&a>XWa$U%>!KntH+-w%D7im{vq3)0O;i{jBf7D2MPO`@p`ngE@Ej{C{? zxPK1CQd99hn*_`6fU8i^pRe*yEp^@ zYs9Kw@^r)Z7V|=0^#RCV(PTFo`^0WNU zr>l0Bpm~65yOWFuh9ev^_QN?~2xeE_c)0H(S~m&6gWTEQrJN*B$?nSYiu4vte{50h z@i55{SR02j>m32;e7yYo#tjmwtu^x$4OS!%ts_l0ypBJ6 zv;@XGTvxfLRul2r>O^B{1nD>mc@x_{Z|ci;_i>teuCF8L#Bg~=K|({ z!SHO6@m;jPLXbl-so|3=L*V7|=sb7!5Qgu5RqV);?GPv=OZlTCfyV_O5A$`}-Zl8Z z6yXv!!RN2EXVX6OV*${yL`8D4XL1*Z00?3#JM-U1;kC|h&6=z@&+3z&-HH?SAa)@u zoH7UPiwM9#;M>(l$1T2aAbZi==K6_|cb-EwpIjXc7a{YP}=p%_O zav5pI);&+jEgi@=4r1vKs#-YW@t$NCYg^mi4LKBC8n^_eF#Xj9cxP0627ZoAVEjXe z=fj*&?+#v(?0|O+rW@vwR7}6OBj8unb8omMa{byo9iAV$z1YGjWZS@rwMa5Da{Rmo z#6*1#Fpx_5*=^cJ9YyViARH%a(=)>&V2Af>&bwr|zj-)aT0P`H0|hfE@@E9_c%da9 z&-FUl0-dBrGscgg{!ajnx<;=&qPw9vH=PcmkN1H_e7({Bbv?lLt{J&8A3g_p?-3|I z{AU%8pZ>DqQ&nRJk7gyAols; z-;>zVt3ge!s`Z_K0t~0%TauqB`yi}KDV`8HKZeQ6{vJ6dZe9aF8ZCS$QO4_6r-;e89F6zk5yrO2!o8t4GNtg%rI%~jc-4Ob5)Ct`eNiT zsAfJplKW*2lM~KLX%6TQ!TwV|jyqnYQvcZ>K>|Lx6^t0SL*v`t9Z_pJ82)n!M}~`} z`S6x!8La>->i+~#w?uA3Le~h*oR4Q;&_e4eac+{>quEo^2=%jn<_|S?VmLS5NFR*X zt-}I$KLK+^9Otjxeu)Lsc`#tjY(xG6@8`uX&~Y7=p94t&J9Xh?ZM*ytq}M&fmXUG- zW-`#nsRs6h|`>cjKj0fNv ziMQRjHa?e~(<6K7e7XS4nrNCBrKA3UAYXkMU-y|d!8?xa1dY>peh7ND+RJo&1T>K5 zNRL)GW9wf3ec$8J#Su{aHL_~){30f=p;Oyt5g&yOoww+z#P{U=XMX}96A_e@=3IuK zNK=Cxp3ehRZZ7;zG);XCTC3fM%N_W<$@6Gh^MdOjAnX))|4#10@Ry=C6ul`HAV-Ps z3!~ON3`g^Z_DqXdA5^Zo;>!O9zn{rdjuaqa8ibk;7^<2c;dNF|BCP1pL=?QrR7^)N zj7a^*2SHw532{X>w&A#4$c5wlc%S)gPfElvUm4h{^W`UOwP5SEapkk674`uJr@G8x zRou^`8xNT`g{1PAp&KHgW5 zic-d}eQXD=P78KnlX(3aZ8AL~C%7k1716WOv0KLYXurAl&q!$pyybH3{XB@*qtL6( z8&_A_z?XwfmppYRFng*Q|6F$I8-k9TDW5#{;r<*T&3{9eq8;+6(yT8z;r=Y>5pcvZ zuoqAS>SYE|<2bqf%-x&mIWWF3Y~EK0&jYJ7pY2;Ek-Smgfzk&|OW3-EF|}q!Cq&9^ z&nqi}3W|>)Er%0r^^~)r5IEb;S&@S2r<))A8KOq)YsB=Z^Y`O^;ypK=F2^+jTc+A9 z3un=MLI8s+C)p6+Rgi1d8_FV$o+E;oXja>pt2cnD%7IgP{CK=L2*n-vZA+Zf?>jPa zh@*9c_#KLG-T!U|0a{<^Ht;0&DX8BH$sQeO*I57_K6myGoJI2zf*k+F_GycL1X$k? zc=|>^1Jhr8A)k7$X%Pfly_0hm$Iqo?a*0vt&=3?*c^E^_faBBDc&~7k_rSytpOzE+ z@%u4x=BU1BCZtl&d`WUYc@^#A6Yo!MJ5OhoGo zv7hvpkM!|shc-<`d8&_aJ`b992Djr{fk#8nbFo`H7@yiQ@nqN0McCQZXXthk_Y*cb zmc9R6*1?6fHSZe>c)T-?KTg>JxiH@RzL?KTJP$`!Br{SJ4Z|Vl)Or|}kKwzGZv~U8 z4}(<6rtF0f)D8&JL(|=6V%H64Pi*RN{m;F+)6Eu z_iB3bcpg*9*O=-FZG(H6sb6vL#qaUTWnXR^^Q41A48pCe2;P6byLXAz-ai5w#V%Q` zQ{nOZSK#uYd-odP^9hlUR}=AmDNp_6Pr&Xy2R$eXM=%_OCUSqQcRk>z3ypJXMe86j z?$kN7b>2GYB(Y9)bocN-D8uuN zoSG2jo7^$*>m%>!Go6_j{;3uFM*g{ZXzTAoMt3rQS55%;$+=iJq?s0 zH5d&J!{GeAQB&;rJ^##lCa1IeI$*4eV&iaI!t|fLZBzNqItOxNx9pxy;&IVlfegu! zcSE;QqtBlL@qA@&;YLwJIs!Xt{H0?jQU532Z-=Tr?=oJ1l@W$7GtLY_G_QVXeBsF; z-Ua7ME)~=XqkTRBPeYt7X|l8F~-<;r*;LZ{>5f8us4jJC~xBoEO(mUzz}Q6k#A zzX(Wsv)|U;K=S|rTpDL}0`CXFCD}yR=Rp(LzT&2yG3V&CfX7#~O6D%$_s`D~)R($n zjlmkSB*oI__;>b~$3{a%NLC>AyA0c`lN3^@pIDvm)#7Mg0%YGym=dr)5J76A?6Ozr zMqux}j6t6oDW-4gKG^8jI|LVdBc3M(;rThw{MmWm*ghz?l`x-ux(u^VPd%r-=$kEA z&{%vh@=*sShu5a$kNNjNzkN=Y@(p+%v)U7)9GBDsS8hEguTt2;^lK1n(M|JJ6ngZ<+kW8lglF|>qCpk$T8+mNV1k`9HTB{l1_mJfs=5fjO)8NSw-$S1> zQ9nnJM7QcUoB7RP!+Ex#Mq+B$&ICYnuPbqNzZYa>tG)S?8qXWXN>i#zfkmJ-@vkUn zAzt4rO-7z+J}HCw<7OorS8;!=!4k#jz4ic$Fu0hU|F>bx_X*m=DV8X=WGz6Ph6CA*Z>@p~L8Jo=-VQ}`YM3S1U8a{G$s6DwU$iASZs zK>O{M#}gx{eF(roWHE8naUFhR7Vz-RM)NB1dtT;;y6YcjAsI~rna&7aFPeWo>JIz7 z0y9VIa&Ol;VEeiq!jN6?U>ztdrrc4^K<$YjsswWxBZ^}1$x2LqZvwh61Q8u?AzbDe z0MvRPD6Ab&yCDe8b^hZ&9IBv6F;9`qMHGhs3TC~}g#2iMty)LKgsTcM`#jX1+soL! z2H77z{Vo%Q;uF8C3Jd?Rh;4w9vR^o74~Jts>n$%$oP00_Z<_ku{(B$KKW{Z6dykk; zf~xDTq;>;&82(eiNy^15^T5UT%P0K=ykGJ+l$T_)TmcHw>qA9>xS!l(kR@SLoCM82 z?34+2WsQ(^gKC`b3SLJFqR4Mt8mNYq8i7nJOrsdiKZ}`miHr>(y>M0`FuiM6jvz;P z9J@_Mhe2*e=`)crd>*g*BV8Vl)e6=szg%qhUI%DhFlURXWE3g@uj5M$pGM(xU-U)q z<0Go8fU$mCdE97E-hcW8z+=(F@Z$XjSg>?aFDJw2kB zT7U9jKT8*V{B|Iw-}lk$rz&kaxOzZPMT0n(qdY^bmNrcUX28>TPwXgfH0;U|B;2UE zt8;=PPj$@XzVXdjOrMN!cf49`00xRueSDUO-`jKv=}sJxuLmi2Y2sv^`{9512=ea; z*U75BZpiIQuNkB=jNv#i_~k}KFGCQi_ohgm9J5=F^}qL@YbF8Uu6U7;;C6of^~{`X zdOgTD4Cp#Wh4$?Pz+R*~bd_NiPAJrx8T^kwYsx*LTj6Q~pKtOu`2Wv$86@WYQIl&E zKs0iWp=M|t+pkW<<*A(5*I=>Z=ZnXu@$+h>=VZL_V;l^(w3tQj>A?8Ep&IGxr1}Tk zB)eRAE-Ya3=&&pY6Qy=&n4V;Cyte?8b8V8jY<`}Hyeh25dJ1@-=CISX*K;QYKGZp? zvwQ-LYXrHnPe;kWrwk-eFxKm94q-eAwARlX_Dq4BNoTGdypPV^1d!?#+ZAV-0}J~9 z#I25@@kjtWpS0PO<-37H&n>@|MM})hdzjqF4)@GJdamY`U&83R1aK@kLRop=IAs3E z_>*l0y)Pwz-oxDN*3#LKM@_(CU=F`$*f;l!wV;;(?|!gBw4%$f{i=U?qh|#sz?@cg zzQk3O4}xeta@U*w)eYEE_P(}@!0$^>avrz+*HsU?D)e;pI`RG1?(>-rN@xc30X}vO zhqN#|G|!U`vYtkyILzPtXmcHdsJ{iN=o5NKdfW+`vEL@0>yWoBt&GQ5>nx^nZE$R<8F zO}K~CFtGH&Q&0Fved@9?dzMm{y#J?6%%_K2yd~CA90cK?(TabSJp+sB-VT>3CSv;S zWUY~6y5rEMH-VmW0Oundxn~cn3Ta-bi7H30au{YG3f{NOYOHm@k?$G5q9cATrpY&| z0*@2(&jT^X6Ay4d84E`4i^}xFVg-hq-;C-poRkwmd;G%dVWjRuZPE*9J|WiS8YN?v z)Df7&Js8tk)rjf4o;}2C+^_+Px6}8&)-9n|-@AUoipwUS${KRo|9zl=~ za?N=*xxFAR+bH62CLX6>#f4q&#*cw!jbZDPllUCL6Z5?Ns}%u+F?Fz=Abz(&@fl5i z|Mg>>h7V)Ko9(6XyfNQ(ot1*O6Y^+yU2ILj^M?1~YYUHph~G~q8=E=E@VMZLXOT#e zq{w5HrIZ=^5Q*_Wr9S28FS-g&ix{uCdoS(E5k%X*h)g=UACv-n?Vf%3cjvr>s2A@H z{=gThTO747@wsMNbw@Ekq6+*|7aC*vSB~+opEZ8=-G%^F*pK9SFXM5MCgR^=rQZT6 z@3(88jmP)9s7jYt7&He2+o(G|SO<1-5G0P7Iw7fO5%3+G)l{~`=eU0V_Y>XT6jHPk z&aX0LNOtoN@jj=7nl*DO2_BoOS}0@1_hs!l7~M@j3?DH$D9eS`U_8}7%U1-M_rcdz zWrj0;wV2$Ri9%uIDgpkhTe!6)JBP_jLW1^Rn@NT@1KZ%h;Ps+|`&06j+DnvQ-#^EQHrExYK-ajZk=q$(sQ zow18g0OSYPzkOj^gK276UoKoO$K+)nYW_8K&qHpbWzUj1v>zkR*IGT7e>W_GHRZoH z#>BoJ^-mufq2ASqDIj4jbLselNlboGgG;SrX9+wOlDXr@DS$x8GJls>y};p?KV{G?JXg=ESC9)SJz# z(zf_~%~h*HFE_FdsdaR24gM+GU6%kt#0%-n&8DFJH{}JoR6NhWQ#c%zzEuOPoA^2$ zP6uOr;%FEx^Plr#tcQ1)&F1G%DAofp7VbqYJ&682)NR#q3bPGFS+?>pIhi z^W<+#TOVH}&C5wDpuA?cj_H4}@+DRMJqhd&o#*)(Rk*t@f{433bWNEc$=k?(xq#eR z!}MiEAq!tpJ9zvzb8j*6T?y*vf6`v`9SmFmY4s{hV^7e0g&?zg6Gxv0C4f8XlO+m# zXx=7%&r*eVzF9BAj-h)>_H+#x{^<&L@s6wk_^XcR0OE@GtKj8Px2oJa*gtDZHZq3u zxv$7UTcutK%arn5Ufx0THUUIzl+b-xErx8Z2dJSY6@w&maAFp3p zCjJst*H*#D<738@7q>9|(4LPeMu{%O_dcjS zzdY@)Q<9nke3T5uABS=OA>S~oJ#S5($N#!^Vdi)$hGX;V6PO;E2k#V2jwXEW+?68; z-IdM@9^ymLk8>!j^YJ`}?|Qn-HF%>4c+W6KSW`D*@)Acs5vRZDKz?f0irMY=uAKOt zPA6$3iFXFvZxcT@`UUT!4n3Mkl{;1sZyw=g&S}8wV4hS8$7AB}8FX(V2DGA;7>AOqV zfWzKPFl9C2;=wOy-$H!9$M%k=mN;L#WUExJiK6~Y07qBO=hCtuQUuM25UVF^7>;zp z-E*SNzhP>g+KQJij?;6yLww(A8QAeSc=5~$+z%Ucy=WC3H(`u|osIc0KEIq~?X=L; z7=@2Tm1m2n3o-s<_m(y)c*_CZpP!0ZU($EiC4Og@v<=q&*8*-CFDkGc`h)5JEZE=c znllJm&&5fpUHOOUJDs=Qa)@>WGFKtTXE8kPj%!oQh;{dYCgjc9;ST)1_rq*ZySwk(&>U0+8{~OOg zJI3q4;Q>{i^n{peO($C43BX4r^hfQZeo(Pd{&r9A9ELxwClsb|s0+MZv!^kl55wdg zR8K=D4^_dkv)c~Ey=dJefcQ)PWujVh@MJATyTLv*KM+9EkAw`b_bot-G_h6U3~Dz7 z`AO@@t}oXHZr-tu<#0j!1mZn*bZr$m6LCH|{&6PWYZ_a3L;s+5XHF$_ou$|1C)Oo2 zzo{Lq^KLDihCf@ZJS3TyF@5H(l>w5j66jWKtt85g&ie%LM_v-~tXYK(kdHOr8=n)u zdNUz+QaXUVxOh#35FU5&4KuQ`#P>y)o~0blWX1hwwasyr&!z_a-QnosyoSe%C->*o zhPTTw+LTxQckwF5hugx+C$E119)wn?KR~Pt4DEvQ8|@w86g6IJ+>96--}FzOk6&I8$NgdP>pDGC$08Kx{GgtffcFE^Ck-o4-X8?L0(RuO1)CVoi-MFr`hTZD zX>IJ41~S}FQi{Ml`Pyz#(JpmB(Ke0k6)FG0_vapFENmyAyR}6`CKU z8)F#g!*ICcZ?1b0-(6Sp__!FK8pZfXvciOBvo`qr;cKOxah$)yDT!-N|El2T!MoFf z_T!j-Mqj@a9r5>T#YbUce-U0^_-0gS4N2ResV%u-eQ-RcuRt<5Z2w~c7TfG)UJ2;J z;yX@*B>jm|SsUu|P0;9Z>HNg>sM4z7#}$E%6rg8|7OL&_q6+Ox)Zy``2u7cysOHS#9n_nE({ zzWj@B?gp!yQxg8SyDy+F$)Fd+<2TI5cITU9bdZ`tHg^2jY2hKHK26{E0L;l#xVFez6$i|BgK|Jhp!c zIKEN)%_qCGD<{5__wd%8`%L`an|)F*B=0ZA^PxbQ##-MjOfq{~5HqVu8Ze#fNV%!IuEeIg^PuDe)i+J4W-__4Ca7=<4 zzHj45E!r@gMAy@6r0`1`r<{J_u|^8&rBn@UybPMFx^vb0zRIB_sbUW{#?$3LX1DW58Pt? zTmPVT99vhIpyrj^R|T?k6(~kB@I1!AS1Qd^)&*?7b$&`x!tHP8a?*HBJQ2S5Qqmdz z5a(GjXfDgmx&-Ym5)$IM@Vrfq_$q$SXoW)`9o1>gaKC-g?z!JidIqX6<#yVzptbH2p-MjJC)VK+ge?KC{>q2}7M!wlIED?qFnFQcl5?%XssS8N-u_kVa;`==p zwosBzmjMzuWUQEt@jRAiwO*VP*9D)1U!|AyNBtZ@80XA2bxMbznDuHAy8zmM67MJL zxmP7N{o$;t!SNlglwCdu(s5!jL05JT%w5&XC;gxN|MN>e2%|#s-ZkqE-_-en@&C{? znUs694Y-L-{1Zk}AR4E!Ck20ZzP`yuXpFn{HxRC2%uQeIjDXq1GB^Q=c3zg zMazIqD>(Dh1H6tDnu~n4)our#En3e~h0u5>09)3OtHo0T@RKxCR02eCi1%PtE~`Jb z=D@Fk`IEI$czxU!rwusEKL&W06Yb;aQZajOTayY`=T8ESq3<7;KH%$`TJ^EzkqpD6 zm+!vMTchzp0CWug1Y-MQ(tdp~P;v=e90n4a5YNv? zEVm!{PG`VI17IG?h3AbxMP_f!hvjhom3fdsokH6irJvg8J-`>Aa z)#!t<+;fZhwq?&vni+jRI_A~Vka_!Y^PqoHeX zJ$I2$>L`BhYIb+8mOm!G*K)0T&n1iIZQ?t+N~0%FiS~}1cAj*??-)=kE9X7N; z{7$;C`t}LlpG2v1Z5xwJ0CnMsH-E6-AtDHsMxbC*#Q?a#__ZxR5|4K>k=S#?r~6>) zX;+6P;&(jMo})|L-OGn3U>W&zh)4rk*9d_A>0fu>%YBf>BAY>D4FA4ta&Pm3LlXh| z&n(U-aH97}1mLMZc}8}fOv>J+;6^$(de2Y%zD6T|=Gihy-d8Do)jZ8D%>Ik@fgyKk z20%emNo-qT3ue!7`ks^z?u+okXZO_3Pdykg zaXE$zoVA2TuNJeTzj+SG5LcNEaCo;xBPqIkb%~NQ7A@+@p%gkfG>+t9F^HSLq z{CrJbos+*9-v#PFex9xxn#SA#k4K2# zjZy!J(Q!Szq&NX)$i+>hzLa7(@#|uIH(w6})55G;+MZ^NXMx7St$|-7@Tds*lhzx6 z;T%7>SLDXMaBv*CbNYzQB8GG7O~F1x`ZcJ#E~?w2jh~mL=p|0>fi6h0qV)90-*yb= z&sYa16-hmKG##SFU-}+fxACc4#IF;hfaWE2yih*gN5z{QDw;eS2<+sSTldG{bJOkI zMpZWAxr_>)ED+Bs!}#zDoZWPYn1ioq*AE0qTLn`uC6Ee)3i&ZEMPZ3q+S^Y-BT|`G@#EcAoQCona*`m|80+uHVGg{YlSk zJ2=}5h1$c%c0S{ILvQl44r{kBXc|e)qCbr1LxokHKpn$r7$)c?Lpv3~CU4EJUUlWn`k%kwvS00pX=MHqj6-U& zPVs);LNw?DM&O|`DTIvVG@9}uQQ9!0=BQ*vW6jo$xjL`(4@;lxO9p@v&IL^mR?&O!w_90-@X*1si>%-E|SfhSNJT*=`aRPyQc+D(tgOj^C(q69hVHjr90GVI-*K_(;v77XTrQ(* zi^e+vT#X@Y#tC%7%Kc#{BaY#D`_A*H69+vf;K!xj&9|@XF&y_r3tqmj%fKfT(dp%% zz<4sP@*TXvH4bbA-$-Y**JFIRN`KQt7{@~OeMSDsqi7yO5J7!Gki54Y25%?FU(m<< zPm(VkS9aKDfO>UDd538yhBK|E%-TKE4x}<=0)-9nzER`%D-o8yUSM6r+3_j__aCRh z^GYWgf}lbS)jQ*}c^FQJ`K>owq#MvFX2#V^8b2>CTM{Opifow8N7J}Djo$YX|4xyu zhcYj9KYaKyblQgp??vgR$1wSewd?!&N;W`VsMlB&_X;NO{m+^a7mrSW*?(ucNoK)STO#HhPyJ3wS;@4n*q&NvF?seOU*DPW<^V+TK zh41_Uf%Ajy$qYPyN=UaPy=(jcJARVe>?y_hlydAB%l_;F#>*H(;x426CI0^Xx|*PL z%>@tymT}F06kqbVjtL1w;FHSXq zCVQW>IB|Txr+%~qDseOe>dq;f5?w!l#;>+?*GN_kiIfSivt?Xq#I77co^vvs_L*G* zvJ-ieNpJCfFJ19bN!ZOkDE-|eo zn|={74P-$egsDF1#P2MOe^P*J%ty<4kh!0iMKA@;0|f9uC1m<1V+rWBaYQ&Q2QmFh z>8I+7-o8NP+Tg`T2RyI37(E_qdz}GZo*o-N!99=Ze_WC5ILB5EZu1!tUiIR6k}s<` z*He}N-Jga%G!ALN^!23`j3#t*pyj#S1%_FuKO;!;Qy=fI1&yGTqo+IiB|c{@kZh2L z%`8Ew{$HZaqj4C{LOEB#^W;hJlOg(iTiysJSFE2{ym_5?&rm@7$(eE#pm|TPIm_tc zxgofAbVu-9RV}99f3e~i?@$f2w~IhjALI7u7v5r%* zv3fdM^x@SAN@P1Jlnix_Lxdq6e6S(?zBiZ0_kMtxu#}NSK#f-Wuk5^#R1ntuxpC7yH62NEQA^Myi;yVPJ zGbJyiaxwWS_x+Te&aEJ>&`i-^2Coao+Do*@@0Eedzc-l;!WS@od&V)p`&TC6B!^p? ziSw^rJ_Jz8T`1niw*}OyjlT*$MZaee@9ScobeAUd!9i2c{pp@~{_MZRK5e}?4<|#d zb5@PfyiEZ5LmP}p?hH^y%6|O0jr;jIOPMFFGE2Z9n(f3+FRstO)T21cOZ+|feh-6m zjzx@*GX2jJlh2mn>8FR|o+sjSH;KDQON<=>s2SufOFD(`>Jz`)DmR{;)Isvrs51&? z*P5{Xek`@2k_=gZOPU4>Th~@FeamNqA*xF=AnnTZVH?93Y+WuTGL_S}W8hPFQLQTl zxPGbQAsg+XI#~27mnL%y=kr=R;PB_KEg-*jr@F-ukFQZ{{k$U{h49lui)za$G;a{& zcdupUv(iQAzUcZX1mJP;@QVIMwb3Ft5+_rzERFk*(}lOQSI#bhzrE!@MB55?_eHF0 z7VM92DOP|Yeu9Gfcml?M*~C<;hV2iK>{E zKW0d*hqqsO>|2d#-o+#P z6CZMQTLC8Q^c5GL<9Y1tTA|RE?HUMF6<;qIM&pk79{47|g!$(>XdOinJxo4_t?QLz zvYwRH2<{6OurT@Kb(2A05QNeY|32sYn3#2S{QJVyPrSd3RvMx0pgZ&DxM6HxX<#Mp z@8AUdlY!08cyVDU^7K5aPn>szw4Vy3UQb4GVGBe_a{B|WdP4(bWssiFW8e| z!zHW2Oe(bRCB6&1V`5P2Q3o!0+h2~^LhXYftt`5vFMrj8uoh-Pt}drt`w+mWk7p98sKEyZT237A~w@BW}h z)fsry437G3%^;p@jm2iPu(dt;{5J@DP1*X6qj$G5>Y93tOds%4GP{K`iGsn-186X3!#(1 zJ6hqdWKIVrKQkXa;g&ZJoph_2rd#oRSU20zH&aafovaI=hh+5>rVsWU`1R6q5x&SQ zncrwe>k9!Sy*_3jtltW?vdNj4BP%dFuq>}dbdb-0e=hrme+sW*IAc_O$78n|K-$3F zZab^;T{-c7<_US{xCEg9>WL__H>GKcK>zl*mls|;;JmXG=;UbmY zc@%39orAXc+-aF;KZ28D=i+G@zF_!#ZRc0x+15buVZ-kCJ$Rg2j%hWEx89C}<gufm=Im8t!L!~NAJ#uIF9;ILU}LnbMKQGrTi3ym#;5WKB9&VgRM2z} z%|8gzd1q*nCae+8OFdP}2>!auhxmTRRky7N>44iCTz*?sX#Yw4zTT`Gcjfma4EuP_ zJ2<%$TUP@FUar$AfHfI&eq{Z4oEA1b=sEnV4dzyS8-Ejp=O6aVSvf~ZTY=X+znHod zu5Tao^zJFg0w{fg_RJG=4UGSxlx=xxo}KP|0mu@ zy>#98gq>^#%0@35(VWElG(YNZQo!Jy~gR7SN`49B|Kk91+Z9ex&g{hB=&^=AUe zd(zywM{pESRY(MCz6syeN051`HBzcm%iyrtRxva0B8H>oZ!5N}-w404Pdg{x_zqD2 zsft!+9+2n)TaF=Oy0%NZa^l|?dhaF^KTrZT7SAra$KrkRMTIoQp0DFjog1mFrpEQ( zNY+svYp4Njzuv5VNN&Q`O$bfAw&2hJBIJD*Ow@4yXSZ+2vNrAjZXurPXZB=Ze2$bS z`*7-)!}Q~F^}meK`bd1QwRYco^GG`I-;_4>R4BrDS{}Qm^qu(k>$1yE8()x{$JR|^ U{-%6;vW575M&w}j8g7UG1HPKrb^rhX literal 0 HcmV?d00001 diff --git a/examples/logistic_regression/simple_linear.py b/examples/logistic_regression/simple_linear.py new file mode 100644 index 0000000..c46b640 --- /dev/null +++ b/examples/logistic_regression/simple_linear.py @@ -0,0 +1,127 @@ +# %% +import torch +from nak_torch.algorithms import msip, msip_gs, svgd +import matplotlib.pyplot as plt +from nak_torch import LogisticRegressionModel +from nak_torch.tools import pyro_tools +from pyro.infer import mcmc + +from nak_torch.algorithms.msip import ( + MSIPFredholm, + MSIPQuadGradientInformed, +) +from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre + +import os + +if torch.cuda.is_available(): + torch.set_default_device("cuda") +else: + torch.set_default_device("cpu") +torch.set_default_dtype(torch.float64) + +# %% +data_path = os.path.join(os.path.dirname(__file__), "data", "simple_linear.npy") +regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01) +log_dens = regression_model.to_log_dens(use_compiled=False) + +plt.scatter(regression_model.data[1], regression_model.data[2], c=regression_model.labels, alpha=0.4) +plt.show() + +# %% +n_particles, state_dim = 20, regression_model.dim +coeff_init = torch.randn((n_particles, regression_model.dim - 1)) +alpha_init = torch.log(regression_model.hyperprior.sample((n_particles,))) +init_particles = torch.column_stack((coeff_init, alpha_init)) +log_dens(init_particles) # test eval + +# %% +kernel_length_scale = 0.05 +bounds = (-100.0, 100.0) +gradient_decay = 0.95 +lr_msip = 80e-2 +kernel_diag_infl = 1e-7 +n_steps = 1000 +grad_val_log_p = torch.vmap(torch.func.grad_and_value(log_dens)) + +@torch.compile(dynamic=False) +def mc_quad_rule(batch_size: int, N_quad: int = 500, dim: int = 4): + pts = torch.randn((batch_size, N_quad, dim), dtype=torch.get_default_dtype()) + wts = torch.ones((batch_size, N_quad), dtype=torch.get_default_dtype()).div_(N_quad) + return pts, wts + +@torch.compile(dynamic=False) +def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, dim: int = 4): + pts, wts = spherical_MC_radial_Laguerre(batch_size, N_spherical, dim, N_radial) + return pts, wts + + +msip_f = MSIPFredholm(gradient_decay, grad_val_log_p) +msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, gradient_decay) + +# %% +trajectories_msip, traj_wts_msip = msip( + msip_f, + n_particles, + n_steps, + dim=state_dim, + lr=lr_msip, + init_particles=init_particles[:n_particles], + kernel_length_scale=kernel_length_scale, + is_log_density_batched=True, + kernel_diag_infl=kernel_diag_infl, + bounds=bounds, + keep_all=True, + compile_step=False, + verbose=True, +) + +# %% +msip_idx = 999 +msip_final_pts, msip_final_wts = trajectories_msip[msip_idx], traj_wts_msip[msip_idx] +logit_out = msip_final_pts[:,:-1] @ regression_model.data +prob_out = torch.nn.functional.sigmoid(logit_out) + +fig, axs = plt.subplots(4,5,figsize=(5*1.25,4.5*1.25)) +sc_data = regression_model.data[1:] +for i in range(4): + for j in range(5): + ax = axs[i,j] + idx = (5*i + j) + prob_ij = prob_out[idx] + wt = msip_final_wts[idx] + ax.scatter(sc_data[0], sc_data[1], c=prob_ij, alpha=0.1) + ax.set_axis_off() + ax.set_title("{:.2e}".format(wt), fontdict={'fontsize': 10}) +fig.suptitle("Different regression outcomes, MSIP wt as title") +plt.show() + +# %% +plt.scatter(sc_data[0], sc_data[1], c=regression_model.labels) + +# %% +n_steps_hmc = 1000 +pyro_model = pyro_tools.pyro_model_factory(regression_model, 4) +pyro_data = torch.concat((regression_model.data, regression_model.labels.reshape(1,-1))) +hmc_kernel = mcmc.NUTS(pyro_model) +mcmc_setup = mcmc.MCMC(hmc_kernel, num_samples=n_steps_hmc, warmup_steps=100) +mcmc_setup.run(pyro_data) + +# %% +hmc_samples = mcmc_setup.get_samples()["theta"] +thin_samples = hmc_samples[::(len(hmc_samples) // 20)] +logit_out = thin_samples @ regression_model.data +prob_out = torch.nn.functional.sigmoid(logit_out) + +fig, axs = plt.subplots(4,5,figsize=(5*1.25,4.5*1.25)) +sc_data = regression_model.data[1:] +for i in range(4): + for j in range(5): + ax = axs[i,j] + idx = (5*i + j) + prob_ij = prob_out[idx] + ax.scatter(sc_data[0], sc_data[1], c=prob_ij, alpha=0.1) + ax.set_axis_off() + # ax.set_title("{:.2e}".format(wt), fontdict={'fontsize': 10}) +fig.suptitle("Different regression outcomes, MSIP wt as title") +plt.show() diff --git a/src/nak_torch/__init__.py b/src/nak_torch/__init__.py index 59ead4e..0674175 100644 --- a/src/nak_torch/__init__.py +++ b/src/nak_torch/__init__.py @@ -1,5 +1,5 @@ from . import algorithms, tools -from .tools import GaussianModel +from .tools import GaussianModel, LogisticRegressionModel -__all__ = ["algorithms", "tools", "GaussianModel"] +__all__ = ["algorithms", "tools", "GaussianModel", "LogisticRegressionModel"] diff --git a/src/nak_torch/tools/__init__.py b/src/nak_torch/tools/__init__.py index e5cdabd..0edabc3 100644 --- a/src/nak_torch/tools/__init__.py +++ b/src/nak_torch/tools/__init__.py @@ -8,7 +8,7 @@ from . import kernel, types, quadrature, adaptive_step from .average import recursive_weighted_average_alpha_v from .torchify import differentiable_density_factory -from .types import GaussianModel +from .types import GaussianModel, LogisticRegressionModel __all__ = [ "kernel", @@ -16,6 +16,7 @@ "recursive_weighted_average_alpha_v", "differentiable_density_factory", "GaussianModel", + "LogisticRegressionModel", "quadrature", "adaptive_step", ] diff --git a/src/nak_torch/tools/pyro_tools.py b/src/nak_torch/tools/pyro_tools.py index 3508e2b..4ba4cf1 100644 --- a/src/nak_torch/tools/pyro_tools.py +++ b/src/nak_torch/tools/pyro_tools.py @@ -1,9 +1,15 @@ from typing import Optional, Union import torch +from torch import Tensor +from jaxtyping import Float import pyro import pyro.distributions as dist -from nak_torch import GaussianModel +from .types import GaussianModel, LogisticRegressionModel, AbstractModel +from abc import ABC, abstractmethod +from typing import TypeVar, Generic + +__all__ = ["pyro_model_factory"] DeviceLike = Union[str, torch.device, int] @@ -25,7 +31,39 @@ def get_pyro_std_from_prec( raise ValueError(f"Expected precision to be ndim=0,1,2, got {nd}.") -class PyroModel: +ModelT = TypeVar("ModelT", bound=AbstractModel) + + +class PyroModel(ABC, Generic[ModelT]): + @abstractmethod + def __init__( + self, + model: ModelT, + param_dim: Optional[int] = None, + device: Optional[DeviceLike] = None, + ): + pass + + @abstractmethod + def __call__(self, data: Float[Tensor, "batch dim"]) -> Float[Tensor, " batch"]: + pass + + +def pyro_model_factory( + model: AbstractModel, + param_dim: Optional[int] = None, + device: Optional[DeviceLike] = None, +): + match model: + case GaussianModel(): + return PyroGaussianModel(model, param_dim, device) + case LogisticRegressionModel(): + return PyroLogisticRegressionModel(model, param_dim, device) + case _ as unreachable: + raise ValueError(f"Unexpected model type {unreachable}") + + +class PyroGaussianModel(PyroModel[GaussianModel]): def __init__( self, model: GaussianModel, @@ -62,3 +100,57 @@ def __call__(self, data): mean_out = self.forward_model(theta.unsqueeze(0)) with pyro.plate("data", len(data)): return pyro.sample("obs", dist.Normal(mean_out, self.like_std), obs=data) + + +class PyroLogisticRegressionModel(PyroModel[LogisticRegressionModel]): + concentration: Float + rate: Float + param_dim: int + prior_mean: Float | Float[Tensor, " dim"] + + def __init__( + self, + model: LogisticRegressionModel, + param_dim: Optional[int] = None, + device: Optional[DeviceLike] = None, + ): + self.concentration, self.rate = ( + model.hyperprior.concentration, + model.hyperprior.rate, + ) + prior_mean = model.prior_mean + if prior_mean is None: + prior_mean = 0.0 + prior_mean = torch.as_tensor(prior_mean, device=device) + if param_dim is not None: + coeff_dim = param_dim - 1 + if prior_mean.numel() == 1: + self.prior_mean = prior_mean.item() * torch.ones(coeff_dim) + elif prior_mean.numel() == coeff_dim: + self.prior_mean = prior_mean.flatten() + else: + raise ValueError( + f"Unexpected arguments: prior_mean.size = {prior_mean.shape}, coeff_dim = {coeff_dim}" + ) + self.param_dim = param_dim + else: + self.param_dim = prior_mean.shape[0] + 1 + self.prior_mean = prior_mean + + def __call__(self, data): + # Data should be dimension equiv to all but alpha, plus labels as last row + if data.shape[0] != (self.param_dim - 1) + 1: + raise ValueError( + f"Got data.shape[0] = {data.shape[0]}, expected {self.param_dim}" + ) + prior_precision = pyro.sample( + "alpha", dist.Gamma(self.concentration, self.rate) + ) + prior_std = 1 / prior_precision.sqrt() + theta = pyro.sample("theta", dist.Normal(self.prior_mean, prior_std)) + dataset, labels = data[:-1], data[-1] + with pyro.plate("data", dataset.shape[1]): + logits = theta @ dataset + return pyro.sample( + "obs", dist.Bernoulli(logits=logits, validate_args=True), obs=labels + ) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 1c1c216..fe5a69e 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -1,8 +1,10 @@ import torch +import numpy as np from torch import Tensor from jaxtyping import Float -from typing import Callable, Optional, Protocol +from typing import Callable, Optional, Protocol, Self from dataclasses import dataclass +from abc import ABC, abstractmethod BatchType = Float[Tensor, "batch"] PtType = Float[Tensor, " d"] @@ -45,8 +47,16 @@ def __call__( BatchForwardModel = Callable[[Float[Tensor, "batch dim"]], Float[Tensor, "batch obs"]] +class AbstractModel(ABC): + @abstractmethod + def to_log_dens( + self: Self, use_compiled: bool = True + ) -> Callable[[BatchPtType], BatchType]: + pass + + @dataclass -class GaussianModel: +class GaussianModel(AbstractModel): forward_model: BatchForwardModel likelihood_precision: float | Float[Tensor, "obs obs"] prior_precision: float | Float[Tensor, "dim dim"] @@ -71,10 +81,103 @@ def __init__( self.true_obs = true_obs self.prior_mean = prior_mean - def to_log_dens(self, use_compiled: bool = True): + def to_log_dens(self, use_compiled=True): return gaussian_log_dens_factory(self, use_compiled) +def bernoulli_loglikelihood_logit(logits, labels): + # If logit(p) = log(p / (1-p)), bernoulli log-likelihood is + # log pi(y | p) = y*logit(p) + log(1-p) + # If q = logit(p), then log(1-p) = -softplus(q) and + # log pi(y | q) = y * q - softplus(q) + constant_term = -torch.nn.functional.softplus(logits) + return torch.sum(labels * logits + constant_term) + + +bernoulli_loglikelihood_logit_v = torch.vmap(bernoulli_loglikelihood_logit, (0, None)) + + +@dataclass +class LogisticRegressionModel(AbstractModel): + """Assumes a gaussian prior and linear model for logits""" + + dim: int + prior_mean: float | Float[Tensor, " dim"] | None + data: Float | Float[Tensor, "dim labels"] + labels: Float | Float[Tensor, " labels"] + hyperprior: torch.distributions.Gamma + + def __init__( + self, + data_or_fname: Float[Tensor, "dim-1 labels"] | str, + labels: Optional[Float[Tensor, " labels"]], + prior_mean: float | Float[Tensor, " dim"] | None = None, + dtype=None, + device=None, + hyperprior_a=1.0, + hyperprior_b=0.1, + ): + data: torch.Tensor + dtype = torch.get_default_dtype() if dtype is None else dtype + device = torch.get_default_device() if device is None else device + + def as_tensor(t): + return torch.as_tensor(t, dtype=dtype, device=device) + + self.prior_mean = prior_mean if prior_mean is None else as_tensor(prior_mean) + if isinstance(data_or_fname, str): + data = as_tensor(np.load(data_or_fname)) + if labels is None: # Split labels from data + labels = data[:, -1] + data = data[:, :-1] + elif isinstance(data_or_fname, torch.Tensor): + data = data_or_fname + else: + raise ValueError( + f"Expected data_or_fname to be str or tensor, got {type(data_or_fname)}" + ) + if labels is None or labels.shape[0] != data.shape[0]: + raise ValueError("Unexpected type or size of argument `labels`.") + constant = as_tensor(torch.ones(data.shape[0])) + self.data = torch.column_stack((constant, data)).T + self.dim = self.data.shape[0] + 1 + self.labels = labels + self.prior_mean = prior_mean + self.hyperprior = torch.distributions.Gamma( + as_tensor(hyperprior_a), as_tensor(hyperprior_b) + ) + + def to_log_dens(self, use_compiled: bool = True): + def log_hyperprior(t): + return self.hyperprior.log_prob(t) + + def log_dens(params: BatchPtType) -> BatchType: + is_batch = params.ndim == 2 + if not is_batch: + params = params.unsqueeze(0) + if params.shape[1] != self.dim: + raise ValueError( + f"Got params.shape[1] = {params.shape[1]}, expected {self.dim}" + ) + prior_diff = params.clone() + if self.prior_mean is not None: + prior_diff -= self.prior_mean + coeffs = params[:, :-1] + alpha = torch.exp(params[:, -1]) + hyperprior_term = log_hyperprior(alpha) + prior_term = -torch.sum(torch.square_(prior_diff), dim=-1).mul_(2 * alpha) + logits = coeffs @ self.data + likelihood = bernoulli_loglikelihood_logit_v(logits, self.labels) + # print("alpha:",alpha,"\n\n") + # print("likely:",likelihood,"\n\n") + # print("prior:",prior_term,"\n\n") + # print("hyperprior:",hyperprior_term,"\n\n") + post = likelihood + prior_term + hyperprior_term + return post if is_batch else post[0] + + return torch.compile(log_dens) if use_compiled else log_dens + + def gaussian_log_dens_factory( model: GaussianModel, use_compiled: bool = True ) -> BatchLogDensity: From 15fc8129bf45113be7c282a55337cd16ba4ad954 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Wed, 15 Apr 2026 14:08:55 -0400 Subject: [PATCH 05/60] Keep working on covtype --- examples/logistic_regression/covertype.py | 155 +++++++++++++++++++ examples/logistic_regression/data/.gitignore | 2 + pyproject.toml | 1 + src/nak_torch/tools/types.py | 52 +++++-- uv.lock | 63 ++++++++ 5 files changed, 259 insertions(+), 14 deletions(-) create mode 100644 examples/logistic_regression/covertype.py create mode 100644 examples/logistic_regression/data/.gitignore diff --git a/examples/logistic_regression/covertype.py b/examples/logistic_regression/covertype.py new file mode 100644 index 0000000..488aa34 --- /dev/null +++ b/examples/logistic_regression/covertype.py @@ -0,0 +1,155 @@ +# %% +import os +from urllib.request import urlretrieve +import torch +from nak_torch.algorithms import msip, msip_gs, svgd +import matplotlib.pyplot as plt +from nak_torch import LogisticRegressionModel +from nak_torch.tools import pyro_tools +from pyro.infer import mcmc + +from nak_torch.algorithms.msip import ( + MSIPFredholm, + MSIPQuadGradientInformed, +) +from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre +import scipy.io +import numpy as np + +if torch.cuda.is_available(): + torch.set_default_device("cuda") +else: + torch.set_default_device("cpu") +torch.set_default_dtype(torch.float64) + +# %% +DATA_URL = "https://raw.githubusercontent.com/DartML/Stein-Variational-Gradient-Descent/refs/heads/master/data/covertype.mat" +DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "covertype.npy") + +def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): + urlretrieve(data_url, data_path) + data_mat = scipy.io.loadmat(data_path) + data_arr = data_mat['covtype'] + # Flip first col to be (0,1) instead of (2,1) (where 2 is false) + covariates = data_arr[:, 1:] + labels = -1 * (data_arr[:, 0] - 2) + data_arr = np.column_stack((covariates, labels)) + # Save + np.save(data_path, data_arr) + +if not os.path.isfile(DATA_PATH): + download_file() + +# %% +data_path = DATA_PATH +regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01, train_proportion=0.8, sum_bernoulli=True) +log_dens = regression_model.to_log_dens(use_compiled=False) + +# %% +N_plot = 10000 +plt.scatter(regression_model.train_data[2,:N_plot], regression_model.train_data[3, :N_plot], c=regression_model.train_labels[:N_plot], alpha=0.2) +plt.show() + +# %% +n_particles, state_dim = 20, regression_model.dim +coeff_init = torch.randn((n_particles, regression_model.dim - 1)) +alpha_init = torch.log(regression_model.hyperprior.sample((n_particles,))) +init_particles = torch.column_stack((coeff_init, alpha_init)) +log_dens(init_particles) # test eval + +# %% +kernel_length_scale = 0.05 +bounds = (-100.0, 100.0) +gradient_decay = 0.75 +lr_msip = 1e-1 +kernel_diag_infl = 1e-6 +n_steps = 20 +grad_val_log_p = torch.vmap(torch.func.grad_and_value(log_dens)) + +@torch.compile(dynamic=False) +def mc_quad_rule(batch_size: int, N_quad: int = 500, dim: int = 56): + pts = torch.randn((batch_size, N_quad, dim), dtype=torch.get_default_dtype()) + wts = torch.ones((batch_size, N_quad), dtype=torch.get_default_dtype()).div_(N_quad) + return pts, wts + +@torch.compile(dynamic=False) +def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, dim: int = 56): + pts, wts = spherical_MC_radial_Laguerre(batch_size, N_spherical, dim, N_radial) + return pts, wts + + +msip_f = MSIPFredholm(gradient_decay, grad_val_log_p) +msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, gradient_decay) + +# %% +trajectories_msip, traj_wts_msip = msip( + msip_gi, + n_particles, + n_steps, + dim=state_dim, + lr=lr_msip, + init_particles=init_particles[:n_particles], + kernel_length_scale=kernel_length_scale, + is_log_density_batched=True, + kernel_diag_infl=kernel_diag_infl, + bounds=bounds, + keep_all=True, + compile_step=True, + verbose=True, +) +trajectories_msip[-1] + +# %% +msip_end = trajectories_msip[-1] +dist_end = torch.sqrt(torch.sum(torch.square_(msip_end[None,:] - msip_end[:,None]), -1)) +lower_tri_idx = torch.tril_indices(*dist_end.shape, -1) +lower_tri_dist = dist_end[*lower_tri_idx] +plt.hist(lower_tri_dist) + +# %% +bce_logit_v = torch.vmap(torch.nn.functional.binary_cross_entropy_with_logits, in_dims=(0,None)) + +# @torch.compile +def bce_logit_t(traj_t): + logits_t = traj_t[:,:-1] @ regression_model.test_data + return bce_logit_v(logits_t, regression_model.test_labels) +# bce_logit_traj = torch.vmap(bce_logit_t) +bce_traj = torch.stack([bce_logit_t(trajectories_msip[j]) for j in range(trajectories_msip.shape[0])]) +# logits_t = trajectories_msip[:,:,:-1].reshape(-1, trajectories_msip.shape[-1] - 1) @ regression_model.data +# bce_traj = bce_logit_v(logits_t, regression_model.labels).reshape(*trajectories_msip.shape[:2], -1) +# print("BCE t=0: {}, BCE t=T: {}".format(bce_0.mean(), bce_T.mean())) + +fig, ax = plt.subplots() +for particle_idx in range(n_particles): + ax.plot(bce_traj[:,particle_idx], alpha= 0.4) +plt.show() + +# %% +def accuracy(coeffs): + data, labels = regression_model.test_data, regression_model.test_labels + prob = torch.sigmoid(coeffs[:-1] @ data) + pred_labels = prob > 0.5 + print(pred_labels.sum()) + N_true = torch.sum(pred_labels == labels) + return N_true / data.shape[1] + +accuracy_v = torch.vmap(accuracy) +accuracy_v(trajectories_msip[-1]) + +# %% + +trajectories_msip, traj_wts_msip = svgd( + msip_f, + n_particles, + n_steps, + dim=state_dim, + lr=lr_msip, + init_particles=init_particles[:n_particles], + kernel_length_scale=kernel_length_scale, + is_log_density_batched=True, + kernel_diag_infl=kernel_diag_infl, + bounds=bounds, + keep_all=True, + compile_step=True, + verbose=True, +) diff --git a/examples/logistic_regression/data/.gitignore b/examples/logistic_regression/data/.gitignore new file mode 100644 index 0000000..ce8d242 --- /dev/null +++ b/examples/logistic_regression/data/.gitignore @@ -0,0 +1,2 @@ +*.npy +!simple_linear.npy \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 4153269..22bcd09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ examples = [ "ipykernel>=7.2.0", "matplotlib>=3.10.8", "pyro-ppl>=1.9.1", + "scipy>=1.17.1", ] [build-system] diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index fe5a69e..33aad6d 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -103,8 +103,11 @@ class LogisticRegressionModel(AbstractModel): dim: int prior_mean: float | Float[Tensor, " dim"] | None - data: Float | Float[Tensor, "dim labels"] - labels: Float | Float[Tensor, " labels"] + train_data: Float | Float[Tensor, "dim labels"] + test_data: Optional[Float | Float[Tensor, "dim labels"]] + train_labels: Float | Float[Tensor, " labels"] + test_labels: Optional[Float | Float[Tensor, " labels"]] + sum_bernoulli: bool hyperprior: torch.distributions.Gamma def __init__( @@ -116,6 +119,8 @@ def __init__( device=None, hyperprior_a=1.0, hyperprior_b=0.1, + train_proportion=1.0, + sum_bernoulli=True, ): data: torch.Tensor dtype = torch.get_default_dtype() if dtype is None else dtype @@ -136,13 +141,24 @@ def as_tensor(t): raise ValueError( f"Expected data_or_fname to be str or tensor, got {type(data_or_fname)}" ) - if labels is None or labels.shape[0] != data.shape[0]: + N_pts = data.shape[0] + if labels is None or labels.shape[0] != N_pts: raise ValueError("Unexpected type or size of argument `labels`.") - constant = as_tensor(torch.ones(data.shape[0])) - self.data = torch.column_stack((constant, data)).T - self.dim = self.data.shape[0] + 1 - self.labels = labels + constant = as_tensor(torch.ones(N_pts)) + data = torch.column_stack((constant, data)).T + if train_proportion >= 1.0: + self.train_data, self.test_data = data, None + self.train_labels, self.test_labels = labels, None + else: + ridx = torch.randperm(N_pts) + num_train = int(np.floor(N_pts * train_proportion)) + self.train_data = data[:, ridx[:num_train]] + self.train_labels = labels[ridx[:num_train]] + self.test_data = data[:, ridx[num_train:]] + self.test_labels = labels[ridx[num_train:]] + self.dim = data.shape[0] + 1 self.prior_mean = prior_mean + self.sum_bernoulli = sum_bernoulli self.hyperprior = torch.distributions.Gamma( as_tensor(hyperprior_a), as_tensor(hyperprior_b) ) @@ -151,7 +167,7 @@ def to_log_dens(self, use_compiled: bool = True): def log_hyperprior(t): return self.hyperprior.log_prob(t) - def log_dens(params: BatchPtType) -> BatchType: + def log_dens(params: BatchPtType, use_train: bool = True) -> BatchType: is_batch = params.ndim == 2 if not is_batch: params = params.unsqueeze(0) @@ -166,12 +182,20 @@ def log_dens(params: BatchPtType) -> BatchType: alpha = torch.exp(params[:, -1]) hyperprior_term = log_hyperprior(alpha) prior_term = -torch.sum(torch.square_(prior_diff), dim=-1).mul_(2 * alpha) - logits = coeffs @ self.data - likelihood = bernoulli_loglikelihood_logit_v(logits, self.labels) - # print("alpha:",alpha,"\n\n") - # print("likely:",likelihood,"\n\n") - # print("prior:",prior_term,"\n\n") - # print("hyperprior:",hyperprior_term,"\n\n") + data: Float[Tensor, "dim-1 N_pts"] + labels: Float[Tensor, " N_pts"] + if use_train: + data = self.train_data + labels = self.train_data + elif self.test_data is not None and self.test_labels is not None: + data = self.test_data + labels = self.test_labels + else: + raise ValueError("Expected test_data and test_labels to be initialized") + logits = coeffs @ data + likelihood = bernoulli_loglikelihood_logit_v(logits, labels) + if not self.sum_bernoulli: + likelihood /= labels.numel() post = likelihood + prior_term + hyperprior_term return post if is_batch else post[0] diff --git a/uv.lock b/uv.lock index b0b133f..08c4634 100644 --- a/uv.lock +++ b/uv.lock @@ -754,6 +754,7 @@ examples = [ { name = "ipykernel" }, { name = "matplotlib" }, { name = "pyro-ppl" }, + { name = "scipy" }, ] [package.dev-dependencies] @@ -773,6 +774,7 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.10.8" }, { name = "numpy", specifier = ">=2.4.1" }, { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, + { name = "scipy", marker = "extra == 'examples'", specifier = ">=1.17.1" }, { name = "torch", specifier = ">=2.10" }, { name = "tqdm", specifier = ">=4.67.1" }, ] @@ -1399,6 +1401,67 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/15/e2/77be4fff062fa78d9b2a4dea85d14785dac5f1d0c1fb58ed52331f0ebe28/ruff-0.15.8-py3-none-win_arm64.whl", hash = "sha256:cf891fa8e3bb430c0e7fac93851a5978fc99c8fa2c053b57b118972866f8e5f2", size = 11048175, upload-time = "2026-03-26T18:40:01.06Z" }, ] +[[package]] +name = "scipy" +version = "1.17.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8", size = 31610954, upload-time = "2026-02-23T00:17:49.855Z" }, + { url = "https://files.pythonhosted.org/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76", size = 28172662, upload-time = "2026-02-23T00:18:01.64Z" }, + { url = "https://files.pythonhosted.org/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086", size = 20344366, upload-time = "2026-02-23T00:18:12.015Z" }, + { url = "https://files.pythonhosted.org/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b", size = 22704017, upload-time = "2026-02-23T00:18:21.502Z" }, + { url = "https://files.pythonhosted.org/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21", size = 32927842, upload-time = "2026-02-23T00:18:35.367Z" }, + { url = "https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458", size = 35235890, upload-time = "2026-02-23T00:18:49.188Z" }, + { url = "https://files.pythonhosted.org/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb", size = 35003557, upload-time = "2026-02-23T00:18:54.74Z" }, + { url = "https://files.pythonhosted.org/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea", size = 37625856, upload-time = "2026-02-23T00:19:00.307Z" }, + { url = "https://files.pythonhosted.org/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87", size = 36549682, upload-time = "2026-02-23T00:19:07.67Z" }, + { url = "https://files.pythonhosted.org/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3", size = 24547340, upload-time = "2026-02-23T00:19:12.024Z" }, + { url = "https://files.pythonhosted.org/packages/76/27/07ee1b57b65e92645f219b37148a7e7928b82e2b5dbeccecb4dff7c64f0b/scipy-1.17.1-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:5e3c5c011904115f88a39308379c17f91546f77c1667cea98739fe0fccea804c", size = 31590199, upload-time = "2026-02-23T00:19:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6fac755ca3d2c3edcb22f479fceaa241704111414831ddd3bc6056e18516892f", size = 28154001, upload-time = "2026-02-23T00:19:22.241Z" }, + { url = "https://files.pythonhosted.org/packages/5b/58/3ce96251560107b381cbd6e8413c483bbb1228a6b919fa8652b0d4090e7f/scipy-1.17.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7ff200bf9d24f2e4d5dc6ee8c3ac64d739d3a89e2326ba68aaf6c4a2b838fd7d", size = 20325719, upload-time = "2026-02-23T00:19:26.329Z" }, + { url = "https://files.pythonhosted.org/packages/b2/83/15087d945e0e4d48ce2377498abf5ad171ae013232ae31d06f336e64c999/scipy-1.17.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:4b400bdc6f79fa02a4d86640310dde87a21fba0c979efff5248908c6f15fad1b", size = 22683595, upload-time = "2026-02-23T00:19:30.304Z" }, + { url = "https://files.pythonhosted.org/packages/b4/e0/e58fbde4a1a594c8be8114eb4aac1a55bcd6587047efc18a61eb1f5c0d30/scipy-1.17.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b64ca7d4aee0102a97f3ba22124052b4bd2152522355073580bf4845e2550b6", size = 32896429, upload-time = "2026-02-23T00:19:35.536Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5f/f17563f28ff03c7b6799c50d01d5d856a1d55f2676f537ca8d28c7f627cd/scipy-1.17.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:581b2264fc0aa555f3f435a5944da7504ea3a065d7029ad60e7c3d1ae09c5464", size = 35203952, upload-time = "2026-02-23T00:19:42.259Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a5/9afd17de24f657fdfe4df9a3f1ea049b39aef7c06000c13db1530d81ccca/scipy-1.17.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:beeda3d4ae615106d7094f7e7cef6218392e4465cc95d25f900bebabfded0950", size = 34979063, upload-time = "2026-02-23T00:19:47.547Z" }, + { url = "https://files.pythonhosted.org/packages/8b/13/88b1d2384b424bf7c924f2038c1c409f8d88bb2a8d49d097861dd64a57b2/scipy-1.17.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6609bc224e9568f65064cfa72edc0f24ee6655b47575954ec6339534b2798369", size = 37598449, upload-time = "2026-02-23T00:19:53.238Z" }, + { url = "https://files.pythonhosted.org/packages/35/e5/d6d0e51fc888f692a35134336866341c08655d92614f492c6860dc45bb2c/scipy-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:37425bc9175607b0268f493d79a292c39f9d001a357bebb6b88fdfaff13f6448", size = 36510943, upload-time = "2026-02-23T00:20:50.89Z" }, + { url = "https://files.pythonhosted.org/packages/2a/fd/3be73c564e2a01e690e19cc618811540ba5354c67c8680dce3281123fb79/scipy-1.17.1-cp313-cp313-win_arm64.whl", hash = "sha256:5cf36e801231b6a2059bf354720274b7558746f3b1a4efb43fcf557ccd484a87", size = 24545621, upload-time = "2026-02-23T00:20:55.871Z" }, + { url = "https://files.pythonhosted.org/packages/6f/6b/17787db8b8114933a66f9dcc479a8272e4b4da75fe03b0c282f7b0ade8cd/scipy-1.17.1-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:d59c30000a16d8edc7e64152e30220bfbd724c9bbb08368c054e24c651314f0a", size = 31936708, upload-time = "2026-02-23T00:19:58.694Z" }, + { url = "https://files.pythonhosted.org/packages/38/2e/524405c2b6392765ab1e2b722a41d5da33dc5c7b7278184a8ad29b6cb206/scipy-1.17.1-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:010f4333c96c9bb1a4516269e33cb5917b08ef2166d5556ca2fd9f082a9e6ea0", size = 28570135, upload-time = "2026-02-23T00:20:03.934Z" }, + { url = "https://files.pythonhosted.org/packages/fd/c3/5bd7199f4ea8556c0c8e39f04ccb014ac37d1468e6cfa6a95c6b3562b76e/scipy-1.17.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:2ceb2d3e01c5f1d83c4189737a42d9cb2fc38a6eeed225e7515eef71ad301dce", size = 20741977, upload-time = "2026-02-23T00:20:07.935Z" }, + { url = "https://files.pythonhosted.org/packages/d9/b8/8ccd9b766ad14c78386599708eb745f6b44f08400a5fd0ade7cf89b6fc93/scipy-1.17.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:844e165636711ef41f80b4103ed234181646b98a53c8f05da12ca5ca289134f6", size = 23029601, upload-time = "2026-02-23T00:20:12.161Z" }, + { url = "https://files.pythonhosted.org/packages/6d/a0/3cb6f4d2fb3e17428ad2880333cac878909ad1a89f678527b5328b93c1d4/scipy-1.17.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:158dd96d2207e21c966063e1635b1063cd7787b627b6f07305315dd73d9c679e", size = 33019667, upload-time = "2026-02-23T00:20:17.208Z" }, + { url = "https://files.pythonhosted.org/packages/f3/c3/2d834a5ac7bf3a0c806ad1508efc02dda3c8c61472a56132d7894c312dea/scipy-1.17.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74cbb80d93260fe2ffa334efa24cb8f2f0f622a9b9febf8b483c0b865bfb3475", size = 35264159, upload-time = "2026-02-23T00:20:23.087Z" }, + { url = "https://files.pythonhosted.org/packages/4d/77/d3ed4becfdbd217c52062fafe35a72388d1bd82c2d0ba5ca19d6fcc93e11/scipy-1.17.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:dbc12c9f3d185f5c737d801da555fb74b3dcfa1a50b66a1a93e09190f41fab50", size = 35102771, upload-time = "2026-02-23T00:20:28.636Z" }, + { url = "https://files.pythonhosted.org/packages/bd/12/d19da97efde68ca1ee5538bb261d5d2c062f0c055575128f11a2730e3ac1/scipy-1.17.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:94055a11dfebe37c656e70317e1996dc197e1a15bbcc351bcdd4610e128fe1ca", size = 37665910, upload-time = "2026-02-23T00:20:34.743Z" }, + { url = "https://files.pythonhosted.org/packages/06/1c/1172a88d507a4baaf72c5a09bb6c018fe2ae0ab622e5830b703a46cc9e44/scipy-1.17.1-cp313-cp313t-win_amd64.whl", hash = "sha256:e30bdeaa5deed6bc27b4cc490823cd0347d7dae09119b8803ae576ea0ce52e4c", size = 36562980, upload-time = "2026-02-23T00:20:40.575Z" }, + { url = "https://files.pythonhosted.org/packages/70/b0/eb757336e5a76dfa7911f63252e3b7d1de00935d7705cf772db5b45ec238/scipy-1.17.1-cp313-cp313t-win_arm64.whl", hash = "sha256:a720477885a9d2411f94a93d16f9d89bad0f28ca23c3f8daa521e2dcc3f44d49", size = 24856543, upload-time = "2026-02-23T00:20:45.313Z" }, + { url = "https://files.pythonhosted.org/packages/cf/83/333afb452af6f0fd70414dc04f898647ee1423979ce02efa75c3b0f2c28e/scipy-1.17.1-cp314-cp314-macosx_10_14_x86_64.whl", hash = "sha256:a48a72c77a310327f6a3a920092fa2b8fd03d7deaa60f093038f22d98e096717", size = 31584510, upload-time = "2026-02-23T00:21:01.015Z" }, + { url = "https://files.pythonhosted.org/packages/ed/a6/d05a85fd51daeb2e4ea71d102f15b34fedca8e931af02594193ae4fd25f7/scipy-1.17.1-cp314-cp314-macosx_12_0_arm64.whl", hash = "sha256:45abad819184f07240d8a696117a7aacd39787af9e0b719d00285549ed19a1e9", size = 28170131, upload-time = "2026-02-23T00:21:05.888Z" }, + { url = "https://files.pythonhosted.org/packages/db/7b/8624a203326675d7746a254083a187398090a179335b2e4a20e2ddc46e83/scipy-1.17.1-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:3fd1fcdab3ea951b610dc4cef356d416d5802991e7e32b5254828d342f7b7e0b", size = 20342032, upload-time = "2026-02-23T00:21:09.904Z" }, + { url = "https://files.pythonhosted.org/packages/c9/35/2c342897c00775d688d8ff3987aced3426858fd89d5a0e26e020b660b301/scipy-1.17.1-cp314-cp314-macosx_14_0_x86_64.whl", hash = "sha256:7bdf2da170b67fdf10bca777614b1c7d96ae3ca5794fd9587dce41eb2966e866", size = 22678766, upload-time = "2026-02-23T00:21:14.313Z" }, + { url = "https://files.pythonhosted.org/packages/ef/f2/7cdb8eb308a1a6ae1e19f945913c82c23c0c442a462a46480ce487fdc0ac/scipy-1.17.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:adb2642e060a6549c343603a3851ba76ef0b74cc8c079a9a58121c7ec9fe2350", size = 32957007, upload-time = "2026-02-23T00:21:19.663Z" }, + { url = "https://files.pythonhosted.org/packages/0b/2e/7eea398450457ecb54e18e9d10110993fa65561c4f3add5e8eccd2b9cd41/scipy-1.17.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eee2cfda04c00a857206a4330f0c5e3e56535494e30ca445eb19ec624ae75118", size = 35221333, upload-time = "2026-02-23T00:21:25.278Z" }, + { url = "https://files.pythonhosted.org/packages/d9/77/5b8509d03b77f093a0d52e606d3c4f79e8b06d1d38c441dacb1e26cacf46/scipy-1.17.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:d2650c1fb97e184d12d8ba010493ee7b322864f7d3d00d3f9bb97d9c21de4068", size = 35042066, upload-time = "2026-02-23T00:21:31.358Z" }, + { url = "https://files.pythonhosted.org/packages/f9/df/18f80fb99df40b4070328d5ae5c596f2f00fffb50167e31439e932f29e7d/scipy-1.17.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:08b900519463543aa604a06bec02461558a6e1cef8fdbb8098f77a48a83c8118", size = 37612763, upload-time = "2026-02-23T00:21:37.247Z" }, + { url = "https://files.pythonhosted.org/packages/4b/39/f0e8ea762a764a9dc52aa7dabcfad51a354819de1f0d4652b6a1122424d6/scipy-1.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:3877ac408e14da24a6196de0ddcace62092bfc12a83823e92e49e40747e52c19", size = 37290984, upload-time = "2026-02-23T00:22:35.023Z" }, + { url = "https://files.pythonhosted.org/packages/7c/56/fe201e3b0f93d1a8bcf75d3379affd228a63d7e2d80ab45467a74b494947/scipy-1.17.1-cp314-cp314-win_arm64.whl", hash = "sha256:f8885db0bc2bffa59d5c1b72fad7a6a92d3e80e7257f967dd81abb553a90d293", size = 25192877, upload-time = "2026-02-23T00:22:39.798Z" }, + { url = "https://files.pythonhosted.org/packages/96/ad/f8c414e121f82e02d76f310f16db9899c4fcde36710329502a6b2a3c0392/scipy-1.17.1-cp314-cp314t-macosx_10_14_x86_64.whl", hash = "sha256:1cc682cea2ae55524432f3cdff9e9a3be743d52a7443d0cba9017c23c87ae2f6", size = 31949750, upload-time = "2026-02-23T00:21:42.289Z" }, + { url = "https://files.pythonhosted.org/packages/7c/b0/c741e8865d61b67c81e255f4f0a832846c064e426636cd7de84e74d209be/scipy-1.17.1-cp314-cp314t-macosx_12_0_arm64.whl", hash = "sha256:2040ad4d1795a0ae89bfc7e8429677f365d45aa9fd5e4587cf1ea737f927b4a1", size = 28585858, upload-time = "2026-02-23T00:21:47.706Z" }, + { url = "https://files.pythonhosted.org/packages/ed/1b/3985219c6177866628fa7c2595bfd23f193ceebbe472c98a08824b9466ff/scipy-1.17.1-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:131f5aaea57602008f9822e2115029b55d4b5f7c070287699fe45c661d051e39", size = 20757723, upload-time = "2026-02-23T00:21:52.039Z" }, + { url = "https://files.pythonhosted.org/packages/c0/19/2a04aa25050d656d6f7b9e7b685cc83d6957fb101665bfd9369ca6534563/scipy-1.17.1-cp314-cp314t-macosx_14_0_x86_64.whl", hash = "sha256:9cdc1a2fcfd5c52cfb3045feb399f7b3ce822abdde3a193a6b9a60b3cb5854ca", size = 23043098, upload-time = "2026-02-23T00:21:56.185Z" }, + { url = "https://files.pythonhosted.org/packages/86/f1/3383beb9b5d0dbddd030335bf8a8b32d4317185efe495374f134d8be6cce/scipy-1.17.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6e3dcd57ab780c741fde8dc68619de988b966db759a3c3152e8e9142c26295ad", size = 33030397, upload-time = "2026-02-23T00:22:01.404Z" }, + { url = "https://files.pythonhosted.org/packages/41/68/8f21e8a65a5a03f25a79165ec9d2b28c00e66dc80546cf5eb803aeeff35b/scipy-1.17.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a9956e4d4f4a301ebf6cde39850333a6b6110799d470dbbb1e25326ac447f52a", size = 35281163, upload-time = "2026-02-23T00:22:07.024Z" }, + { url = "https://files.pythonhosted.org/packages/84/8d/c8a5e19479554007a5632ed7529e665c315ae7492b4f946b0deb39870e39/scipy-1.17.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:a4328d245944d09fd639771de275701ccadf5f781ba0ff092ad141e017eccda4", size = 35116291, upload-time = "2026-02-23T00:22:12.585Z" }, + { url = "https://files.pythonhosted.org/packages/52/52/e57eceff0e342a1f50e274264ed47497b59e6a4e3118808ee58ddda7b74a/scipy-1.17.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:a77cbd07b940d326d39a1d1b37817e2ee4d79cb30e7338f3d0cddffae70fcaa2", size = 37682317, upload-time = "2026-02-23T00:22:18.513Z" }, + { url = "https://files.pythonhosted.org/packages/11/2f/b29eafe4a3fbc3d6de9662b36e028d5f039e72d345e05c250e121a230dd4/scipy-1.17.1-cp314-cp314t-win_amd64.whl", hash = "sha256:eb092099205ef62cd1782b006658db09e2fed75bffcae7cc0d44052d8aa0f484", size = 37345327, upload-time = "2026-02-23T00:22:24.442Z" }, + { url = "https://files.pythonhosted.org/packages/07/39/338d9219c4e87f3e708f18857ecd24d22a0c3094752393319553096b98af/scipy-1.17.1-cp314-cp314t-win_arm64.whl", hash = "sha256:200e1050faffacc162be6a486a984a0497866ec54149a01270adc8a59b7c7d21", size = 25489165, upload-time = "2026-02-23T00:22:29.563Z" }, +] + [[package]] name = "setuptools" version = "80.10.2" From 491f4a5f63e6426b7e90572ff18f7b7235012895 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 17 Apr 2026 18:40:27 -0400 Subject: [PATCH 06/60] Fix __init__ --- src/nak_torch/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nak_torch/__init__.py b/src/nak_torch/__init__.py index 93ca127..e77b948 100644 --- a/src/nak_torch/__init__.py +++ b/src/nak_torch/__init__.py @@ -1,6 +1,6 @@ from . import algorithms, tools -from .tools import GaussianModel, metrics +from .tools import GaussianModel, LogisticRegressionModel, metrics __all__ = ["algorithms", "tools", "GaussianModel", "LogisticRegressionModel", "metrics"] From 09274f73a87aa18f0998a314c52047bcf69ae977 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 18 Apr 2026 10:09:56 -0400 Subject: [PATCH 07/60] Fix logistic regression prior --- src/nak_torch/tools/types.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 01feeb9..17199bd 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -166,8 +166,7 @@ def as_tensor(t): ) def to_log_dens(self, use_compiled: bool = True): - def log_hyperprior(t): - return self.hyperprior.log_prob(t) + log_hyperprior = self.hyperprior.log_prob def log_dens(params: BatchPtType, use_train: bool = True) -> BatchType: is_batch = params.ndim == 2 @@ -181,9 +180,12 @@ def log_dens(params: BatchPtType, use_train: bool = True) -> BatchType: if self.prior_mean is not None: prior_diff -= self.prior_mean coeffs = params[:, :-1] - alpha = torch.exp(params[:, -1]) + log_alpha = params[:, -1] + alpha = torch.exp(log_alpha) hyperprior_term = log_hyperprior(alpha) - prior_term = -torch.sum(torch.square_(prior_diff), dim=-1).mul_(2 * alpha) + prior_term = prior_diff.square_().sum(dim=-1).mul_(0.5 * alpha).neg_() + # log-normalization constant of prior w.r.t. alpha = precision + prior_term += 0.5 * self.dim * log_alpha data: Float[Tensor, "dim-1 N_pts"] labels: Float[Tensor, " N_pts"] if use_train: From 5aabe44a55185b98cedcdeb825ef358913ed882b Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 18 Apr 2026 10:10:12 -0400 Subject: [PATCH 08/60] Clean up simple logreg --- examples/logistic_regression/simple_linear.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/logistic_regression/simple_linear.py b/examples/logistic_regression/simple_linear.py index c46b640..41c08b3 100644 --- a/examples/logistic_regression/simple_linear.py +++ b/examples/logistic_regression/simple_linear.py @@ -1,3 +1,4 @@ + # %% import torch from nak_torch.algorithms import msip, msip_gs, svgd @@ -25,7 +26,7 @@ regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01) log_dens = regression_model.to_log_dens(use_compiled=False) -plt.scatter(regression_model.data[1], regression_model.data[2], c=regression_model.labels, alpha=0.4) +plt.scatter(regression_model.train_data[1], regression_model.train_data[2], c=regression_model.train_labels, alpha=0.4) plt.show() # %% @@ -79,11 +80,11 @@ def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, di # %% msip_idx = 999 msip_final_pts, msip_final_wts = trajectories_msip[msip_idx], traj_wts_msip[msip_idx] -logit_out = msip_final_pts[:,:-1] @ regression_model.data +logit_out = msip_final_pts[:,:-1] @ regression_model.train_data prob_out = torch.nn.functional.sigmoid(logit_out) fig, axs = plt.subplots(4,5,figsize=(5*1.25,4.5*1.25)) -sc_data = regression_model.data[1:] +sc_data = regression_model.train_data[1:] for i in range(4): for j in range(5): ax = axs[i,j] @@ -96,9 +97,6 @@ def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, di fig.suptitle("Different regression outcomes, MSIP wt as title") plt.show() -# %% -plt.scatter(sc_data[0], sc_data[1], c=regression_model.labels) - # %% n_steps_hmc = 1000 pyro_model = pyro_tools.pyro_model_factory(regression_model, 4) From 696c1f6404ff8c1184a4b1ced7c812c95f95efb9 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 18 Apr 2026 11:25:56 -0400 Subject: [PATCH 09/60] Work on covtype logreg --- examples/logistic_regression/covertype.py | 31 +++++++++++++---------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/logistic_regression/covertype.py b/examples/logistic_regression/covertype.py index 488aa34..7e394b3 100644 --- a/examples/logistic_regression/covertype.py +++ b/examples/logistic_regression/covertype.py @@ -42,7 +42,7 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): # %% data_path = DATA_PATH -regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01, train_proportion=0.8, sum_bernoulli=True) +regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01, train_proportion=0.8, sum_bernoulli=False) log_dens = regression_model.to_log_dens(use_compiled=False) # %% @@ -52,18 +52,19 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): # %% n_particles, state_dim = 20, regression_model.dim -coeff_init = torch.randn((n_particles, regression_model.dim - 1)) -alpha_init = torch.log(regression_model.hyperprior.sample((n_particles,))) -init_particles = torch.column_stack((coeff_init, alpha_init)) +alpha_init = regression_model.hyperprior.sample((n_particles,1)) +log_alpha_init = alpha_init.log() +coeff_init = torch.randn((n_particles, regression_model.dim - 1)) / alpha_init.sqrt() +init_particles = torch.column_stack((coeff_init, log_alpha_init)) log_dens(init_particles) # test eval # %% kernel_length_scale = 0.05 bounds = (-100.0, 100.0) -gradient_decay = 0.75 -lr_msip = 1e-1 -kernel_diag_infl = 1e-6 -n_steps = 20 +gradient_decay = 0.9 +lr_msip = 0.05 +kernel_diag_infl = 1e-5 +n_steps = 1000 grad_val_log_p = torch.vmap(torch.func.grad_and_value(log_dens)) @torch.compile(dynamic=False) @@ -83,7 +84,7 @@ def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, di # %% trajectories_msip, traj_wts_msip = msip( - msip_gi, + msip_f, n_particles, n_steps, dim=state_dim, @@ -107,21 +108,26 @@ def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, di plt.hist(lower_tri_dist) # %% +from tqdm import tqdm bce_logit_v = torch.vmap(torch.nn.functional.binary_cross_entropy_with_logits, in_dims=(0,None)) # @torch.compile def bce_logit_t(traj_t): logits_t = traj_t[:,:-1] @ regression_model.test_data return bce_logit_v(logits_t, regression_model.test_labels) -# bce_logit_traj = torch.vmap(bce_logit_t) -bce_traj = torch.stack([bce_logit_t(trajectories_msip[j]) for j in range(trajectories_msip.shape[0])]) +bce_logit_traj = torch.vmap(bce_logit_t) +bse_traj_list = [] +for j in tqdm(range(trajectories_msip.shape[0])): + bse_traj_list.append(bce_logit_t(trajectories_msip[j])) +bce_traj = torch.stack(bse_traj_list) # logits_t = trajectories_msip[:,:,:-1].reshape(-1, trajectories_msip.shape[-1] - 1) @ regression_model.data # bce_traj = bce_logit_v(logits_t, regression_model.labels).reshape(*trajectories_msip.shape[:2], -1) # print("BCE t=0: {}, BCE t=T: {}".format(bce_0.mean(), bce_T.mean())) +# %% fig, ax = plt.subplots() for particle_idx in range(n_particles): - ax.plot(bce_traj[:,particle_idx], alpha= 0.4) + ax.loglog(bce_traj[:,particle_idx], alpha= 0.4) plt.show() # %% @@ -137,7 +143,6 @@ def accuracy(coeffs): accuracy_v(trajectories_msip[-1]) # %% - trajectories_msip, traj_wts_msip = svgd( msip_f, n_particles, From d57250daea12998a9197d6f51d82ce5f6f0b6acb Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 21 Apr 2026 14:44:05 -0400 Subject: [PATCH 10/60] Start creating core loop --- examples/pyro_tools.py | 6 +- src/nak_torch/algorithms/loop.py | 84 ++++++++++++++++++++ src/nak_torch/algorithms/msip/estimators.py | 22 +++--- src/nak_torch/tools/func.py | 57 ++++++++++++++ src/nak_torch/tools/types.py | 86 ++++++++++++--------- src/nak_torch/tools/util.py | 15 +++- 6 files changed, 214 insertions(+), 56 deletions(-) create mode 100644 src/nak_torch/algorithms/loop.py create mode 100644 src/nak_torch/tools/func.py diff --git a/examples/pyro_tools.py b/examples/pyro_tools.py index 3508e2b..aa6104a 100644 --- a/examples/pyro_tools.py +++ b/examples/pyro_tools.py @@ -1,12 +1,10 @@ -from typing import Optional, Union +from typing import Optional import torch import pyro import pyro.distributions as dist from nak_torch import GaussianModel - -DeviceLike = Union[str, torch.device, int] - +from nak_torch.tools.types import DeviceLike def get_pyro_std_from_prec( prec: torch.Tensor, dim: Optional[int] = None diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py new file mode 100644 index 0000000..afc25cd --- /dev/null +++ b/src/nak_torch/algorithms/loop.py @@ -0,0 +1,84 @@ +from typing import Any, Optional + +from tqdm import tqdm +import numpy as np +import torch +from torch import Tensor + +from nak_torch.tools.util import initialize_particles +from nak_torch.tools.types import ( + BatchDensityEvaluator, +) + +from nak_torch.tools.func import NAKAlgorithm, WeightedNAKAlgorithm + + +def nak( + log_density: BatchDensityEvaluator, + algorithm: NAKAlgorithm, + n_particles: int, + n_steps: int, + lr: float, + seed: Optional[int] = None, + init_particles: Optional[Tensor | np.ndarray] = None, + bounds: Optional[tuple[float, float]] = None, + keep_all: bool = True, + target_args: Any = None, + verbose: bool = False, +) -> Tensor | tuple[Tensor, Tensor]: + r""" + TODO: Document + """ + + if n_steps < 0: + raise ValueError("Expected positive number of steps.") + + if seed is not None: + torch.manual_seed(seed) + + dim, device, dtype = algorithm.dim, algorithm.device, algorithm.dtype + particles = initialize_particles( + n_particles, dim, init_particles, device, dtype, bounds + ) + is_weighted = isinstance(algorithm, WeightedNAKAlgorithm) + if keep_all: + trajectories = torch.empty( + (n_steps + 1, *particles.shape), device=device, dtype=dtype + ) + trajectories[0].copy_(particles) + if is_weighted: + traj_wts = torch.empty( + (n_steps + 1, particles.shape[0]), device=device, dtype=dtype + ) + else: + traj_wts = torch.empty(()) + else: + trajectories = torch.empty(()) + traj_wts = torch.empty(()) + particle_wts = torch.empty(()) + for idx in tqdm(range(n_steps + 1), disable=not verbose): + algorithm_args = algorithm.update(particles) + + if keep_all and is_weighted: + particle_wts = algorithm.get_weights(particles, target_args) + traj_wts[idx].copy_(particle_wts) + + if idx < n_steps: + particles, algorithm_args = algorithm( + lr, log_density, particles, algorithm_args, target_args + ) + + if bounds is not None: + particles.clamp_(bounds[0], bounds[1]) + + if keep_all: + trajectories[idx + 1].copy_(particles) + + if not keep_all: + trajectories = particles.unsqueeze_(0) + if is_weighted: + traj_wts = particle_wts.unsqueeze_(0) + if is_weighted: + return trajectories.detach(), traj_wts.detach() + + return trajectories.detach() diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index aa33d50..b70097f 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -1,12 +1,12 @@ import torch -from abc import ABC, abstractmethod +from abc import abstractmethod from nak_torch.tools.average import recursive_weighted_average_alpha_v from nak_torch.tools.types import ( - BatchPtType, MSIPEstimatorOutput, BatchLogDensityGradVal, BatchLogDensity, BatchQuadratureRule, + BatchDensityEvaluator, ) from jaxtyping import Float from torch import Tensor @@ -14,11 +14,9 @@ __all__ = ["MSIPFredholm", "MSIPQuadGradientFree", "MSIPQuadGradientInformed"] -class MSIPEstimator(ABC): +class MSIPEstimator(BatchDensityEvaluator[MSIPEstimatorOutput]): @abstractmethod - def get_v_evals( - self, particles: BatchPtType, kernel_length_scale: float - ) -> MSIPEstimatorOutput: + def __call__(self, particles, evaluator_args, *target_args) -> MSIPEstimatorOutput: r""" Function that returns estimation of $(\log(v_0), sigma^2 * \nabla \log v_0(y)$ Note that @@ -40,8 +38,8 @@ def __init__( self.gradient_decay = gradient_decay self.log_dens_grad_val = log_dens_grad_val - def get_v_evals(self, particles, kernel_length_scale): - grads, v0 = self.log_dens_grad_val(particles) + def __call__(self, particles, kernel_length_scale, *target_args): + grads, v0 = self.log_dens_grad_val(particles, *target_args) sigma_sq_log_v0 = grads.mul_(kernel_length_scale * self.gradient_decay) return v0, sigma_sq_log_v0 @@ -63,7 +61,7 @@ def __init__( self.quadrature = quadrature self.log_dens = log_dens - def get_v_evals(self, particles, kernel_length_scale): + def __call__(self, particles, kernel_length_scale, *args): n_particles, dim = particles.shape quad_pts, quad_wts = self.quadrature(n_particles) @@ -94,13 +92,13 @@ def __init__( self.quadrature, self.gradient_decay = quadrature, gradient_decay self.log_dens_grad_val = log_dens_grad_val - def get_v_evals(self, particles, kernel_length_scale): + def __call__(self, particles, kernel_length_scale, *target_args): quad_pts, quad_wts = self.quadrature(particles.shape[0]) particle_quad_pts = quad_pts.mul_(kernel_length_scale).add( particles.unsqueeze(1) ) # (N_part, N_quad, dim) log_dens_grads, log_dens_evals = self.log_dens_grad_val( - particle_quad_pts.reshape(-1, particles.shape[1]) + particle_quad_pts.reshape(-1, particles.shape[1]), *target_args ) log_dens_grads = log_dens_grads.reshape_as(particle_quad_pts) @@ -133,7 +131,7 @@ def __init__( self.covariances = covariances self.bandwidth = bandwidth - def get_v_evals(self, particles, kernel_length_scale) -> MSIPEstimatorOutput: + def __call__(self, particles, kernel_length_scale, *_) -> MSIPEstimatorOutput: N, D = particles.shape dtype, device = particles.dtype, particles.device sigma_sq = torch.as_tensor( diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py new file mode 100644 index 0000000..8cdb85a --- /dev/null +++ b/src/nak_torch/tools/func.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import Generic, Optional, TypeVar +import torch +from .types import ( + BatchPtType, + BatchType, + DeviceLike, + BatchDensityEvaluator, +) + +BatchDensityEvaluatorT = TypeVar("BatchDensityEvaluatorT", bound=BatchDensityEvaluator) +AlgorithmArgsT = TypeVar("AlgorithmArgsT") + + +class AdaptiveNAKAlgorithm(ABC, Generic[BatchDensityEvaluatorT, AlgorithmArgsT]): + dim: int + n_particles: int + device: Optional[DeviceLike] + dtype: Optional[torch.dtype] + + @abstractmethod + def __call__( + self, + lr: float, + target: BatchDensityEvaluatorT, + points: BatchPtType, + algorithm_args: AlgorithmArgsT, + target_args, + ) -> BatchPtType: + pass + + @abstractmethod + def update(self, particles: BatchPtType) -> AlgorithmArgsT: + pass + + def get_weights(self, points: BatchPtType, target_args) -> BatchType: + N_ens = points.shape[0] + return torch.ones(N_ens, dtype=points.dtype, device=points.device) / N_ens + + +class NAKAlgorithm(AdaptiveNAKAlgorithm[BatchDensityEvaluatorT, None]): + def update(self, particles: BatchPtType) -> None: + return None + + +class WeightedAdaptiveNAKAlgorithm( + AdaptiveNAKAlgorithm[BatchDensityEvaluatorT, AlgorithmArgsT] +): + @abstractmethod + def get_weights(self, points: BatchPtType, target_args) -> BatchType: + pass + + +class WeightedNAKAlgorithm(NAKAlgorithm[BatchDensityEvaluatorT]): + @abstractmethod + def get_weights(self, points: BatchPtType, target_args) -> BatchType: + pass diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 216ed00..1e251b0 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -1,8 +1,12 @@ +from typing import Any, Callable, Optional, Protocol, TypeVar, Generic +from dataclasses import dataclass +from abc import ABC, abstractmethod + import torch from torch import Tensor from jaxtyping import Float -from typing import Callable, Optional, Protocol -from dataclasses import dataclass + +DeviceLike = str | torch.device | int BatchType = Float[Tensor, "batch"] PtType = Float[Tensor, " d"] @@ -14,10 +18,21 @@ KernelMatrixType = Float[Tensor, "batch batch"] GradKernelMatrixType = Float[Tensor, "batch batch d"] +DensityGradValOutput = tuple[BatchPtType, BatchType] MSIPEstimatorOutput = tuple[BatchType, BatchPtType] KernelFunction = Callable[[PtType, PtType, float], Float] +EvaluatorOutput = TypeVar("EvaluatorOutput") + + +class BatchDensityEvaluator(ABC, Generic[EvaluatorOutput]): + @abstractmethod + def __call__( + self, particles: BatchPtType, evaluator_args, *target_args + ) -> EvaluatorOutput: + pass + class MatSelfKernelFunction(Protocol): def __call__( @@ -28,23 +43,25 @@ def __call__( ) -> KernelMatrixType: ... -LogDensity = Callable[[PtType], Float] +LogDensity = Callable[[PtType, Any], Float] -GradLogDensity = Callable[[PtType], PtType] +GradLogDensity = Callable[[PtType, Any], PtType] -LogDensityGradVal = Callable[[PtType], tuple[PtType, Float]] +LogDensityGradVal = Callable[[PtType, Any], tuple[PtType, Float]] -BatchLogDensity = Callable[[BatchPtType], BatchType] +BatchLogDensity = Callable[[BatchPtType, Any], BatchType] -BatchLogDensityGradVal = Callable[[BatchPtType], tuple[BatchPtType, BatchType]] +BatchLogDensityGradVal = Callable[[BatchPtType, Any], DensityGradValOutput] -BatchGradLogDensity = Callable[[BatchPtType], BatchPtType] +BatchGradLogDensity = Callable[[BatchPtType, Any], BatchPtType] BatchQuadratureRule = Callable[[int], tuple[BatchQuadrulePtType, BatchQuadruleWtType]] -ForwardModel = Callable[[Float[Tensor, " dim"]], Float[Tensor, " obs"]] +ForwardModel = Callable[[Float[Tensor, " dim"], Any], Float[Tensor, " obs"]] -BatchForwardModel = Callable[[Float[Tensor, "batch dim"]], Float[Tensor, "batch obs"]] +BatchForwardModel = Callable[ + [Float[Tensor, "batch dim"], Any], Float[Tensor, "batch obs"] +] @dataclass @@ -64,35 +81,32 @@ def __init__( prior_mean: float | Float[Tensor, " dim"] = 0.0, is_vectorized: bool = False, ): - if not is_vectorized: - forward_model = torch.vmap(forward_model) - self.forward_model = forward_model + batch_forward_model: BatchForwardModel + if is_vectorized: + batch_forward_model = forward_model # type: ignore + else: + batch_forward_model = torch.vmap(forward_model, in_dims=(0, None)) + self.forward_model = batch_forward_model self.prior_mean = prior_mean self.likelihood_precision = likelihood_precision self.prior_precision = prior_precision self.true_obs = true_obs self.prior_mean = prior_mean - def to_log_dens(self, use_compiled: bool = True): - return gaussian_log_dens_factory(self, use_compiled) - - -def gaussian_log_dens_factory( - model: GaussianModel, use_compiled: bool = True -) -> BatchLogDensity: - def log_dens(pts: BatchPtType) -> BatchType: - model_eval = model.forward_model(pts) - obs_error = model_eval.sub_(model.true_obs) - like_term = torch.square(torch.linalg.norm(obs_error, dim=-1)).mul_( - model.likelihood_precision - ) - like_term.mul_(model.likelihood_precision) - prior_diff = pts - if model.prior_mean != 0.0: - prior_diff -= model.prior_mean - prior_term = torch.square(torch.linalg.norm(prior_diff, dim=-1)).mul_( - model.prior_precision - ) - return -0.5 * (prior_term + like_term) - - return torch.compile(log_dens) if use_compiled else log_dens + def to_log_dens(self, use_compiled: bool = True) -> BatchLogDensity: + def log_dens(pts: BatchPtType, aux_args: Any) -> BatchType: + model_eval = self.forward_model(pts, aux_args) + obs_error = model_eval.sub_(self.true_obs) + like_term = torch.square(torch.linalg.norm(obs_error, dim=-1)).mul_( + self.likelihood_precision + ) + like_term.mul_(self.likelihood_precision) + prior_diff = pts + if self.prior_mean != 0.0: + prior_diff -= self.prior_mean + prior_term = torch.square(torch.linalg.norm(prior_diff, dim=-1)).mul_( + self.prior_precision + ) + return -0.5 * (prior_term + like_term) + + return torch.compile(log_dens) if use_compiled else log_dens diff --git a/src/nak_torch/tools/util.py b/src/nak_torch/tools/util.py index 27dc9d0..f762cd9 100644 --- a/src/nak_torch/tools/util.py +++ b/src/nak_torch/tools/util.py @@ -2,7 +2,7 @@ from torch import Tensor from jaxtyping import Float from typing import Optional, Callable -from .types import BatchGradLogDensity, BatchPtType +from .types import BatchGradLogDensity, BatchPtType, DeviceLike import numpy as np import inspect @@ -24,7 +24,8 @@ def initialize_particles( n_particles: int, dim: int, init_particles: Optional[Tensor | np.ndarray], - device: Optional[torch.device], + device: Optional[DeviceLike], + dtype: Optional[torch.dtype], bounds: Optional[tuple[float, float]], rng: Optional[torch.Generator] = None, ) -> BatchPtType: @@ -48,6 +49,12 @@ def initialize_particles( init_particles.device, torch.device(device) ) ) + if dtype is not None and init_particles.dtype != dtype: + raise ValueError( + "Unexpected dtype for init_particles: got {}, expected {}".format( + init_particles.dtype, dtype + ) + ) return torch.as_tensor(init_particles, device=device).clone() @@ -58,9 +65,9 @@ def batched_grad_log_density_factory( ) -> BatchGradLogDensity: if grad_log_density is None: if is_log_density_batched: - return torch.func.grad(lambda p: log_density(p).sum()) + return torch.func.grad(lambda p, a: log_density(p, a).sum()) else: - return torch.vmap(torch.func.grad(log_density)) + return torch.vmap(torch.func.grad(log_density), in_dims=(0, None)) else: return grad_log_density From bde9a8c5079a24964e69ca7ecc6ffb18ead66501 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 21 Apr 2026 14:46:26 -0400 Subject: [PATCH 11/60] Remove old algorithms --- .../algorithms/msip/msip_geom_greedy.py | 187 ------------------ src/nak_torch/algorithms/msip/msip_greedy.py | 113 ----------- src/nak_torch/algorithms/msip/msip_ni.py | 118 ----------- 3 files changed, 418 deletions(-) delete mode 100644 src/nak_torch/algorithms/msip/msip_geom_greedy.py delete mode 100644 src/nak_torch/algorithms/msip/msip_greedy.py delete mode 100644 src/nak_torch/algorithms/msip/msip_ni.py diff --git a/src/nak_torch/algorithms/msip/msip_geom_greedy.py b/src/nak_torch/algorithms/msip/msip_geom_greedy.py deleted file mode 100644 index c0fe881..0000000 --- a/src/nak_torch/algorithms/msip/msip_geom_greedy.py +++ /dev/null @@ -1,187 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - - -# This file contains the implementation of mean shift interacting particles_greedy -# Ayoub Belhadji -# 05/12/2025 - -import numpy as np -import torch -import copy - -from .msip_map import msip_map - - -def _geometric_safe_step( - particles, - idx, - target, - base_lr, - lr_max, - min_separation, - min_lr=1e-6, - shrink_factor=0.5, -): - """ - Purely geometric line-search on the segment from x_i to target. - We look for the largest step lambda in (0, min(base_lr, lr_max)] such that - the new point is at least 'min_separation' away from all other particles. - No objective or gradient evaluations are used. - """ - with torch.no_grad(): - x_i = particles[idx] - direction = target - x_i - dir_norm = direction.norm() - - # If direction is (numerically) zero, no move - if dir_norm < 1e-12: - return x_i.clone(), 0.0, False - - # Normalize direction once; we will scale by effective step. - # However, to stay consistent with your convex combination form - # we interpret "step" as the convex weight in [0,1]. - # direction_unit = direction / dir_norm - - # Effective step is in [0, min(base_lr, lr_max)] - step = float(min(base_lr, lr_max)) - min_sep_sq = float(min_separation**2) - - # Build the set of "other" particles - if particles.shape[0] > 1: - others = torch.cat([particles[:idx], particles[idx + 1 :]], dim=0) - else: - # Only one particle: trivially safe - new_pos = x_i + step * direction - return new_pos, step, True - - moved = False - while step >= min_lr: - cand = x_i + step * direction # keep your convex formulation - diff = cand.unsqueeze(0) - others - dist_sq = (diff**2).sum(dim=1) - if (dist_sq >= min_sep_sq).all(): - # Geometrically safe endpoint - moved = True - return cand, step, moved - # Otherwise shrink the step - step *= shrink_factor - - # Could not find a safe step above min_lr -> skip move - return x_i.clone(), 0.0, False - - -def update_one_particle( - objective_function, - particles, - idx, - lr=0.1, - kernel_bandwidth=1.0, - inner_tol=1e-4, - max_inner_steps=50, - min_separation=0.2, - lr_max=0.5, - min_lr=1e-6, - shrink_factor=0.5, -): - """ - Coordinate-wise MSIP update with geometric safety: - - All particles are kept fixed except particle `idx`. - - For particle `idx`, we iterate until the MSIP update is small (equilibrium) - or max_inner_steps is reached. - - The step size along the MSIP direction is adapted geometrically to ensure - the updated particle remains at least `min_separation` away from all others. - If no such step >= min_lr exists, we skip the update for this particle. - - The effective step is also capped by `lr_max`. - - Returns a list of snapshots of the particle system. - """ - new_list_particles = [copy.deepcopy(particles)] - - for _ in range(max_inner_steps): - # Compute full MSIP map given current particles - t_arr = msip_map( - objective_function, particles, kernel_bandwidth=kernel_bandwidth - ) - - with torch.no_grad(): - old_pos = particles[idx].clone() - target = t_arr[idx] - - # Purely geometric adaptive step: - new_pos, eff_lr, moved = _geometric_safe_step( - particles=particles, - idx=idx, - target=target, - base_lr=lr, - lr_max=lr_max, - min_separation=min_separation, - min_lr=min_lr, - shrink_factor=shrink_factor, - ) - - move_norm = (new_pos - old_pos).norm() - - if moved: - particles[idx].copy_(new_pos) - new_list_particles.append( - copy.deepcopy(particles.detach().cpu().numpy()) - ) - else: - # Could not move without violating separation; consider this an equilibrium - print(f"Particle {idx}: geometric blocking after {_} inner steps.") - break - - if move_norm.item() < inner_tol: - print(f"Particle {idx}: inner convergence at step {_}.") - break - - return new_list_particles - - -def msip_geom_greedy( - objective_function, - n_particles=50, - n_steps=10, # now interpreted as "epochs" (passes over all particles) - dim=2, - bounds=(-5, 5), - lr=0.1, - noise=0.05, # currently unused, kept for compatibility - kernel_bandwidth=1.0, - inner_tol=1e-4, # equilibrium tolerance for a particle - max_inner_steps=50, # max inner iterations per particle - min_separation=0.2, # NEW: minimal allowed distance between particles - lr_max=0.5, # NEW: global cap on effective learning rate - min_lr=1e-6, - shrink_factor=0.5, - seed=None, - device="cpu", -): - if seed is not None: - torch.manual_seed(seed) - - # Init particles - particles = torch.empty((n_particles, dim), device=device).uniform_(*bounds) - - trajectories = [particles.detach().cpu().numpy().copy()] - - # Outer loop: epochs - for _ in range(n_steps): - # Loop over particles, one at a time - for i in range(n_particles): - new_list_particles = update_one_particle( - objective_function, - particles, - idx=i, - lr=lr, - kernel_bandwidth=kernel_bandwidth, - inner_tol=inner_tol, - max_inner_steps=max_inner_steps, - min_separation=min_separation, - lr_max=lr_max, - min_lr=min_lr, - shrink_factor=shrink_factor, - ) - trajectories = trajectories + new_list_particles - - return np.array(trajectories), bounds diff --git a/src/nak_torch/algorithms/msip/msip_greedy.py b/src/nak_torch/algorithms/msip/msip_greedy.py deleted file mode 100644 index c805391..0000000 --- a/src/nak_torch/algorithms/msip/msip_greedy.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - - -# This file contains the implementation of mean shift interacting particles_greedy -# Ayoub Belhadji -# 05/12/2025 - -import numpy as np -import torch -from tqdm import tqdm -from typing import Optional - -from nak_torch.tools.util import initialize_particles -from .msip_map import msip_map - - -def update_one_particle( - objective_function, - particles: torch.Tensor, - idx: int, - lr: float = 0.1, - inner_tol: float = 1e-4, - max_inner_steps: int = 50, - # kernel_bandwidth: float = 1.0, - # bandwidth_factor: float = 0.5, - # bounds: tuple[float, float] = (-torch.inf, torch.inf), - # projection: bool = True, - # gradient_informed: bool = True, - # kernel_diag_infl: float = 0.0, - **msip_kwargs, -): - """ - Coordinate-wise MSIP update: - - All particles are kept fixed except particle `idx` - - For particle `idx`, we iterate until the MSIP update is small - or max_inner_steps is reached. - Mutates `particles` in-place and returns it. - """ - - new_list_particles = [] - for _ in range(max_inner_steps): - # Compute full MSIP map given current particles - - t_arr = msip_map( - objective_function, - particles, - # kernel_bandwidth, - # bandwidth_factor, - # bounds, - # projection, - # gradient_informed, - # kernel_diag_infl, - output_idx=idx, - **msip_kwargs, - ) - - with torch.no_grad(): - old_pos = particles[idx] - new_pos = (1.0 - lr) * old_pos + lr * t_arr - - move_norm = (new_pos - old_pos).norm() - particles[idx].copy_(new_pos) - new_list_particles.append(particles.detach().cpu().numpy().copy()) - - if move_norm.isnan(): - print("nan") - - if move_norm.item() < inner_tol: - break - - return new_list_particles - - -def msip_greedy( - log_density, - n_particles: int, - # now interpreted as "epochs" (passes over all particles) - n_steps: int, - dim: int, - lr: float, - noise: float = 0.05, # currently unused, kept for compatibility - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - **msip_kwargs, -): - - if seed is not None: - torch.manual_seed(seed) - - particles = initialize_particles(n_particles, dim, init_particles, device, bounds) - - trajectories = [particles.detach().cpu().numpy().copy()] - - # Outer loop: epochs - pbar = tqdm(total=n_steps * n_particles) - for _ in range(n_steps): - # Loop over particles, one at a time - for i in range(n_particles): - new_list_particles = update_one_particle( - log_density, particles, idx=i, lr=lr, bounds=bounds, **msip_kwargs - ) - # If you want a very fine-grained trajectory, record after each particle: - # trajectories.append(particles.detach().cpu().numpy().copy()) - trajectories = trajectories + new_list_particles - pbar.update() - - # If you prefer only one snapshot per epoch, move the append here instead: - # trajectories.append(particles.detach().cpu().numpy().copy()) - - return torch.tensor(trajectories) diff --git a/src/nak_torch/algorithms/msip/msip_ni.py b/src/nak_torch/algorithms/msip/msip_ni.py deleted file mode 100644 index c3773f6..0000000 --- a/src/nak_torch/algorithms/msip/msip_ni.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - - -# This file contains the implementation of mean shift interacting particles -# Ayoub Belhadji -# 05/12/2025 - -# TODO: FIX THIS CODE OR DISCARD IT - -import numpy as np -import torch -from nak_torch.tools.average import recursive_weighted_average_alpha_v - - -def update_particles_ni( - objective_function, - particles, - t, - lr=0.1, - kernel_bandwidth=1.0, - noise_level=0.01, - noise_injection=None, # <-- new argument: a map for noise -): - # Make sure this is a leaf with grad - particles = particles.detach().clone() - particles.requires_grad_(True) - - fitness = objective_function(particles) # shape (N,) - # If fitness is vector, you usually do sum() to get grads w.r.t. all particles - (grads,) = torch.autograd.grad(fitness.sum(), particles) - # Be aware that this is the gradient of the log(p), which is the grad of V - - # expf_times_y = torch.exp(fitness.unsqueeze(-1)) * particles - - # From here on, think of 'fitness' and 'grads' as just arrays - with torch.no_grad(): - diff = particles.unsqueeze(1) - particles.unsqueeze(0) - sigma2 = kernel_bandwidth**2 - kernel_matrix = torch.exp(-(diff**2).sum(dim=-1) / sigma2) - # print(kernel_matrix) - - K_minus_one = torch.linalg.inv(kernel_matrix) - # print(K_minus_one) - - N, d = particles.shape - t_list = [] - for i in range(N): - alpha_i = K_minus_one[i, :] - - # all these calls now use 'plain' tensors - # t1 = recursive_weighted_average_alpha_v( - # particles, alpha_i, log_v=torch.log(fitness) - # ) - - # t1_1 = recursive_weighted_average_alpha_v(expf_times_y, alpha_i) - # t2_1 = recursive_weighted_average_alpha_v(grads, alpha_i) - # t2_2 = recursive_weighted_average_alpha_v(fitness.unsqueeze(-1), alpha_i) - # t2 = t2_1 / t2_2 - t1 = recursive_weighted_average_alpha_v(particles, alpha_i, log_v=fitness) - t2 = recursive_weighted_average_alpha_v(grads, alpha_i, log_v=fitness) - # - # t1_1 / t2_2 - # print(t2) - t_list.append(t1 + sigma2 * t2) - - t_arr = torch.stack(t_list, dim=0) - - # deterministic transport step - particles_det = (1 - lr) * particles + lr * t_arr - - # optional noise injection map - if noise_injection is not None: - # expect: noise_injection(particles, t, noise_level) -> new_particles - particles_new = noise_injection(particles_det, t, noise_level) - else: - particles_new = particles_det - - return particles_new - - -def sqexp_noise_injection(particles, t, noise_level): - return particles + noise_level * torch.randn_like(particles) - - -def msip_ni( - objective_function, - n_particles=50, - n_steps=100, - dim=2, - bounds=(-5, 5), - lr=0.1, - noise_level_0=0.05, - kernel_bandwidth=1.0, - seed=None, - device="cpu", -): - if seed is not None: - torch.manual_seed(seed) - - particles = torch.empty((n_particles, dim), device=device).uniform_(*bounds) - - trajectories = [particles.detach().cpu().numpy().copy()] - - for t in range(n_steps): - noise_level = noise_level_0 / (t + 1) - particles = update_particles_ni( - objective_function, - particles, - 0, - lr, - kernel_bandwidth, - noise_level, - sqexp_noise_injection, - ) - trajectories.append(particles.detach().cpu().numpy().copy()) - - return np.array(trajectories), bounds From e12c4633235117da31c5ecd4921abd55de89d7a0 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 21 Apr 2026 18:51:00 -0400 Subject: [PATCH 12/60] First draft MSIP with new setup --- src/nak_torch/algorithms/loop.py | 45 ++--- src/nak_torch/algorithms/msip/estimators.py | 20 +- src/nak_torch/algorithms/msip/msip.py | 197 +++++++++----------- src/nak_torch/tools/func.py | 54 +++--- src/nak_torch/tools/kernel.py | 7 +- src/nak_torch/tools/types.py | 8 +- 6 files changed, 156 insertions(+), 175 deletions(-) diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index afc25cd..9a472e0 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -10,12 +10,14 @@ BatchDensityEvaluator, ) -from nak_torch.tools.func import NAKAlgorithm, WeightedNAKAlgorithm +from nak_torch.tools.func import ( + GeneralAdaptiveNAKAlgorithm, +) def nak( - log_density: BatchDensityEvaluator, - algorithm: NAKAlgorithm, + target: BatchDensityEvaluator, + algorithm: GeneralAdaptiveNAKAlgorithm, n_particles: int, n_steps: int, lr: float, @@ -40,45 +42,44 @@ def nak( particles = initialize_particles( n_particles, dim, init_particles, device, dtype, bounds ) - is_weighted = isinstance(algorithm, WeightedNAKAlgorithm) + + particle_wts, algorithm_args = algorithm.initialize(particles, target, target_args) + if keep_all: trajectories = torch.empty( (n_steps + 1, *particles.shape), device=device, dtype=dtype ) trajectories[0].copy_(particles) - if is_weighted: + if algorithm.is_weighted(): traj_wts = torch.empty( (n_steps + 1, particles.shape[0]), device=device, dtype=dtype ) + traj_wts[0].copy_(particle_wts) else: traj_wts = torch.empty(()) else: trajectories = torch.empty(()) traj_wts = torch.empty(()) - particle_wts = torch.empty(()) - for idx in tqdm(range(n_steps + 1), disable=not verbose): - algorithm_args = algorithm.update(particles) - if keep_all and is_weighted: - particle_wts = algorithm.get_weights(particles, target_args) - traj_wts[idx].copy_(particle_wts) + for idx in tqdm(range(n_steps), disable=not verbose): + if keep_all: + trajectories[idx + 1].copy_(particles) + if algorithm.is_weighted() and keep_all: + traj_wts[idx + 1].copy_(particle_wts) - if idx < n_steps: - particles, algorithm_args = algorithm( - lr, log_density, particles, algorithm_args, target_args - ) - - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) + particles, particle_wts, algorithm_args = algorithm.step( + lr, particles, algorithm_args, target, target_args + ) - if keep_all: - trajectories[idx + 1].copy_(particles) + if bounds is not None: + particles.clamp_(bounds[0], bounds[1]) if not keep_all: trajectories = particles.unsqueeze_(0) - if is_weighted: + if algorithm.is_weighted(): traj_wts = particle_wts.unsqueeze_(0) - if is_weighted: + + if algorithm.is_weighted(): return trajectories.detach(), traj_wts.detach() return trajectories.detach() diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index b70097f..1beba53 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -16,7 +16,7 @@ class MSIPEstimator(BatchDensityEvaluator[MSIPEstimatorOutput]): @abstractmethod - def __call__(self, particles, evaluator_args, *target_args) -> MSIPEstimatorOutput: + def __call__(self, particles, evaluator_args, target_args) -> MSIPEstimatorOutput: r""" Function that returns estimation of $(\log(v_0), sigma^2 * \nabla \log v_0(y)$ Note that @@ -38,8 +38,8 @@ def __init__( self.gradient_decay = gradient_decay self.log_dens_grad_val = log_dens_grad_val - def __call__(self, particles, kernel_length_scale, *target_args): - grads, v0 = self.log_dens_grad_val(particles, *target_args) + def __call__(self, particles, kernel_length_scale, target_args): + grads, v0 = self.log_dens_grad_val(particles, target_args) sigma_sq_log_v0 = grads.mul_(kernel_length_scale * self.gradient_decay) return v0, sigma_sq_log_v0 @@ -61,16 +61,16 @@ def __init__( self.quadrature = quadrature self.log_dens = log_dens - def __call__(self, particles, kernel_length_scale, *args): + def __call__(self, particles, kernel_length_scale, target_args): n_particles, dim = particles.shape quad_pts, quad_wts = self.quadrature(n_particles) particle_quad_pts = quad_pts.mul_(kernel_length_scale).add( particles.reshape(n_particles, 1, -1) ) - log_dens_evals = self.log_dens(particle_quad_pts.reshape(-1, dim)).reshape( - n_particles, -1 - ) + log_dens_evals = self.log_dens( + particle_quad_pts.reshape(-1, dim), target_args + ).reshape(n_particles, -1) sigma_sq_score_v0, log_v0 = vmap_recursive_weighted_average_alpha_v( quad_pts, quad_wts, log_dens_evals ) @@ -92,13 +92,13 @@ def __init__( self.quadrature, self.gradient_decay = quadrature, gradient_decay self.log_dens_grad_val = log_dens_grad_val - def __call__(self, particles, kernel_length_scale, *target_args): + def __call__(self, particles, kernel_length_scale, target_args): quad_pts, quad_wts = self.quadrature(particles.shape[0]) particle_quad_pts = quad_pts.mul_(kernel_length_scale).add( particles.unsqueeze(1) ) # (N_part, N_quad, dim) log_dens_grads, log_dens_evals = self.log_dens_grad_val( - particle_quad_pts.reshape(-1, particles.shape[1]), *target_args + particle_quad_pts.reshape(-1, particles.shape[1]), target_args ) log_dens_grads = log_dens_grads.reshape_as(particle_quad_pts) @@ -131,7 +131,7 @@ def __init__( self.covariances = covariances self.bandwidth = bandwidth - def __call__(self, particles, kernel_length_scale, *_) -> MSIPEstimatorOutput: + def __call__(self, particles, kernel_length_scale, _) -> MSIPEstimatorOutput: N, D = particles.shape dtype, device = particles.dtype, particles.device sigma_sq = torch.as_tensor( diff --git a/src/nak_torch/algorithms/msip/msip.py b/src/nak_torch/algorithms/msip/msip.py index b71cef0..e0fef64 100644 --- a/src/nak_torch/algorithms/msip/msip.py +++ b/src/nak_torch/algorithms/msip/msip.py @@ -1,126 +1,99 @@ -import warnings -from typing import Optional +from dataclasses import astuple, dataclass +from typing import Generic, Optional, TypeVar -from tqdm import tqdm -import numpy as np import torch +from nak_torch.tools.func import WeightedAdaptiveNAKAlgorithm from nak_torch.tools.kernel import default_kernel_matrix -from nak_torch.tools.util import initialize_particles, quantile_distance +from nak_torch.tools.util import quantile_distance from .msip_map import MSIPEstimatorOutput, msip_map, get_msip_wts from .estimators import MSIPEstimator -from .msip_tools import msip_map_used_keys, process_msip_density from nak_torch.tools.types import ( - LogDensity, - BatchLogDensity, - BatchType, + BatchPtType, + KernelMatrixType, MatSelfKernelFunction, ) - -def msip( - log_density: LogDensity | BatchLogDensity | MSIPEstimator, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - kernel_length_scale: float, - noise: float = 0.05, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - keep_all: bool = True, - get_kernel_matrix: Optional[MatSelfKernelFunction] = None, - kernel_diag_infl: float = 0.0, - verbose: bool = False, - use_quantile_length_scale: Optional[float] = None, - compile_step: bool = True, - **msip_kwargs, -): - r""" - TODO: Document - """ - - if n_steps < 0: - raise ValueError("Expected positive number of steps.") - - unused_kwargs = { - k: v for (k, v) in msip_kwargs.items() if k not in msip_map_used_keys - } - - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs: {}".format(unused_kwargs)) - - if seed is not None: - torch.manual_seed(seed) - if get_kernel_matrix is None: - get_kernel_matrix = default_kernel_matrix - - msip_estimator = process_msip_density(log_density, **msip_kwargs) - est_v = msip_estimator.get_v_evals - _msip_map = msip_map - _get_msip_wts = get_msip_wts - if compile_step: - _msip_map = torch.compile(msip_map) - _get_msip_wts = torch.compile(_get_msip_wts) - est_v = torch.compile(est_v) - - particles = initialize_particles(n_particles, dim, init_particles, device, bounds) - - if keep_all: - trajectories = torch.empty( - (n_steps + 1, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) - traj_wts = torch.empty( - (n_steps + 1, particles.shape[0]), device=device, dtype=particles.dtype - ) - else: - trajectories = torch.empty(()) - traj_wts = torch.empty(()) - - msip_estimator_out: MSIPEstimatorOutput - particle_wts: BatchType - for idx in tqdm(range(n_steps + 1), disable=not verbose): - if use_quantile_length_scale is not None: - kernel_length_scale = quantile_distance( - particles, use_quantile_length_scale +MSIPEstimatorOutputT = TypeVar("MSIPEstimatorOutputT", bound=MSIPEstimatorOutput) + + +@dataclass +class MSIPAlgorithmArgs(Generic[MSIPEstimatorOutputT]): + kernel_lengthscale: float + kernel_matrix_inverse: KernelMatrixType + msip_estimator_output: MSIPEstimatorOutputT + + +class MSIP(WeightedAdaptiveNAKAlgorithm[MSIPEstimator, MSIPAlgorithmArgs]): + kernel_diag_infl: float + default_kernel_lengthscale: float + kernel_lengthscale_quantile: Optional[float] + get_kernel_matrix: MatSelfKernelFunction + + def __init__( + self, + *_, + kernel_diag_infl: float = 0.0, + kernel_lengthscale: Optional[float] = None, + kernel_lengthscale_quantile: Optional[float] = None, + get_kernel_matrix: Optional[MatSelfKernelFunction] = None, + ): + self.kernel_diag_infl = kernel_diag_infl + if kernel_lengthscale is None and kernel_lengthscale_quantile is None: + raise ValueError( + "Must have either kernel_lengthscale" + "or kernel_lengthscale_quantile as value" ) - - kernel_matrix = get_kernel_matrix(particles, kernel_length_scale) - kernel_matrix[torch.arange(n_particles), torch.arange(n_particles)] += ( - kernel_diag_infl + self.kernel_lengthscale = kernel_lengthscale + self.kernel_lengthscale_quantile = kernel_lengthscale_quantile + if get_kernel_matrix is None: + self.get_kernel_matrix = default_kernel_matrix + else: + self.get_kernel_matrix = get_kernel_matrix + + def get_adaptive_lengthscale(self, particles: BatchPtType) -> float: + q = self.kernel_lengthscale_quantile + if q is None: + return self.default_kernel_lengthscale + return quantile_distance(particles, q) + + def get_infl_kernel_matrix(self, particles, kernel_lengthscale) -> KernelMatrixType: + kernel_matrix = self.get_kernel_matrix(particles, kernel_lengthscale) + if self.kernel_diag_infl is not None: + kernel_matrix[ + torch.arange(self.n_particles, device=self.device), + torch.arange(self.n_particles, device=self.device), + ] += self.kernel_diag_infl + return kernel_matrix + + def initialize(self, init_particles, target, target_args): + kernel_lengthscale = self.get_adaptive_lengthscale(init_particles) + estimator_output = target(init_particles, kernel_lengthscale, target_args) + kernel_matrix = self.get_infl_kernel_matrix(init_particles, kernel_lengthscale) + wts = get_msip_wts(init_particles, estimator_output, kernel_matrix) + return wts, MSIPAlgorithmArgs( + kernel_lengthscale, kernel_matrix, estimator_output ) - msip_estimator_out = est_v(particles, kernel_length_scale) - particle_wts = _get_msip_wts(particles, msip_estimator_out, kernel_matrix) - - if keep_all: - traj_wts[idx].copy_(particle_wts) - - if idx < n_steps: - if kernel_diag_infl > 0: - kernel_matrix_inverse = torch.linalg.inv(kernel_matrix) - else: - kernel_matrix_inverse = torch.linalg.pinv(kernel_matrix) - - particles_diff = _msip_map( - msip_estimator_out, - particles, - kernel_matrix_inverse, - output_idx=None, - ) - - with torch.no_grad(): - particles = (1.0 - lr) * particles + lr * particles_diff - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx + 1].copy_(particles) - - if not keep_all: - trajectories = particles.unsqueeze_(0) - traj_wts = particle_wts.unsqueeze_(0) # type: ignore + def step(self, lr, particles, target, algorithm_args, target_args): + kernel_lengthscale, kernel_matrix, estimator_output = astuple(algorithm_args) + kernel_matrix_inverse = torch.linalg.pinv(kernel_matrix) - return trajectories.detach(), traj_wts.detach() + # Update the particles + particles_diff = msip_map( + estimator_output, + particles, + kernel_matrix_inverse, + output_idx=None, + ) + new_particles = particles * (1 - lr) + lr * particles_diff + + # Update the parameters + kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) + kernel_matrix = self.get_infl_kernel_matrix(new_particles, kernel_lengthscale) + msip_estimator_output = target(particles, kernel_lengthscale, target_args) + algorithm_args = MSIPAlgorithmArgs( + kernel_lengthscale, kernel_matrix_inverse, msip_estimator_output + ) + new_weights = get_msip_wts(new_particles, estimator_output, kernel_matrix) + return new_particles, new_weights, algorithm_args diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py index 8cdb85a..5b7bc56 100644 --- a/src/nak_torch/tools/func.py +++ b/src/nak_torch/tools/func.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar import torch from .types import ( BatchPtType, @@ -10,48 +10,54 @@ BatchDensityEvaluatorT = TypeVar("BatchDensityEvaluatorT", bound=BatchDensityEvaluator) AlgorithmArgsT = TypeVar("AlgorithmArgsT") +WeightT = TypeVar("WeightT", bound=Optional[BatchType]) -class AdaptiveNAKAlgorithm(ABC, Generic[BatchDensityEvaluatorT, AlgorithmArgsT]): +class GeneralAdaptiveNAKAlgorithm( + ABC, Generic[BatchDensityEvaluatorT, WeightT, AlgorithmArgsT] +): dim: int n_particles: int device: Optional[DeviceLike] dtype: Optional[torch.dtype] @abstractmethod - def __call__( + def initialize( + self, + init_particles: BatchPtType, + target: BatchDensityEvaluatorT, + target_args: Any, + ) -> tuple[WeightT, AlgorithmArgsT]: + pass + + @abstractmethod + def step( self, lr: float, + particles: BatchPtType, target: BatchDensityEvaluatorT, - points: BatchPtType, algorithm_args: AlgorithmArgsT, - target_args, - ) -> BatchPtType: + target_args: Any, + ) -> tuple[BatchPtType, WeightT, AlgorithmArgsT]: pass + @classmethod @abstractmethod - def update(self, particles: BatchPtType) -> AlgorithmArgsT: + def is_weighted(cls) -> bool: pass - def get_weights(self, points: BatchPtType, target_args) -> BatchType: - N_ens = points.shape[0] - return torch.ones(N_ens, dtype=points.dtype, device=points.device) / N_ens - -class NAKAlgorithm(AdaptiveNAKAlgorithm[BatchDensityEvaluatorT, None]): - def update(self, particles: BatchPtType) -> None: - return None +class UnweightedAdaptiveNAKAlgorithm( + GeneralAdaptiveNAKAlgorithm[BatchDensityEvaluatorT, None, AlgorithmArgsT] +): + @classmethod + def is_weighted(cls) -> bool: + return False class WeightedAdaptiveNAKAlgorithm( - AdaptiveNAKAlgorithm[BatchDensityEvaluatorT, AlgorithmArgsT] + GeneralAdaptiveNAKAlgorithm[BatchDensityEvaluatorT, BatchType, AlgorithmArgsT] ): - @abstractmethod - def get_weights(self, points: BatchPtType, target_args) -> BatchType: - pass - - -class WeightedNAKAlgorithm(NAKAlgorithm[BatchDensityEvaluatorT]): - @abstractmethod - def get_weights(self, points: BatchPtType, target_args) -> BatchType: - pass + @classmethod + def is_weighted(cls) -> bool: + return True diff --git a/src/nak_torch/tools/kernel.py b/src/nak_torch/tools/kernel.py index c730d88..8c7a427 100644 --- a/src/nak_torch/tools/kernel.py +++ b/src/nak_torch/tools/kernel.py @@ -1,5 +1,5 @@ import torch -from typing import Optional, Callable +from typing import Any, Optional, Callable from jaxtyping import Float from torch import Tensor from .types import ( @@ -124,6 +124,7 @@ def process_kernel_jac(x, y, length_scale): def stein_kernel_mat_factory( grad_log_p: GradLogDensity | BatchGradLogDensity, kernel_fcn: KernelFunction, + target_args: Any, is_grad_vectorized: bool = False, use_compiled: bool = True, ) -> MatSelfKernelFunction: @@ -133,12 +134,12 @@ def stein_kernel_mat_factory( def stein_kernel_mat( pts: BatchPtType, kernel_length_scale: float, pts2: Optional[BatchPtType] = None ) -> KernelMatrixType: - grad_log_p_eval1 = grad_log_p_v(pts) + grad_log_p_eval1 = grad_log_p_v(pts, target_args) if pts2 is None: pts2 = pts grad_log_p_eval2 = grad_log_p_eval1 else: - grad_log_p_eval2 = grad_log_p_v(pts2) + grad_log_p_eval2 = grad_log_p_v(pts2, target_args) trace_kernel, grad1_kernel, eval_kernel = kernel_diffs( pts, pts2, kernel_length_scale ) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 1e251b0..2362824 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -23,14 +23,14 @@ KernelFunction = Callable[[PtType, PtType, float], Float] -EvaluatorOutput = TypeVar("EvaluatorOutput") +EvaluatorOutputT = TypeVar("EvaluatorOutputT") -class BatchDensityEvaluator(ABC, Generic[EvaluatorOutput]): +class BatchDensityEvaluator(ABC, Generic[EvaluatorOutputT]): @abstractmethod def __call__( - self, particles: BatchPtType, evaluator_args, *target_args - ) -> EvaluatorOutput: + self, particles: BatchPtType, evaluator_args, target_args + ) -> EvaluatorOutputT: pass From fd893542ea338bd58714bbea237c65ec62288312 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 23 Apr 2026 09:58:52 -0400 Subject: [PATCH 13/60] Consolidate reused MSIP code --- src/nak_torch/algorithms/msip/msip.py | 71 +-------- src/nak_torch/algorithms/msip/msip_gs.py | 161 ++++---------------- src/nak_torch/algorithms/msip/msip_tools.py | 78 +++++++++- 3 files changed, 113 insertions(+), 197 deletions(-) diff --git a/src/nak_torch/algorithms/msip/msip.py b/src/nak_torch/algorithms/msip/msip.py index e0fef64..0ce442e 100644 --- a/src/nak_torch/algorithms/msip/msip.py +++ b/src/nak_torch/algorithms/msip/msip.py @@ -1,71 +1,12 @@ -from dataclasses import astuple, dataclass -from typing import Generic, Optional, TypeVar +from dataclasses import astuple import torch -from nak_torch.tools.func import WeightedAdaptiveNAKAlgorithm -from nak_torch.tools.kernel import default_kernel_matrix -from nak_torch.tools.util import quantile_distance -from .msip_map import MSIPEstimatorOutput, msip_map, get_msip_wts -from .estimators import MSIPEstimator -from nak_torch.tools.types import ( - BatchPtType, - KernelMatrixType, - MatSelfKernelFunction, -) +from nak_torch.algorithms.msip.msip_tools import GeneralMSIPAlgorithm, MSIPAlgorithmArgs +from .msip_map import msip_map, get_msip_wts -MSIPEstimatorOutputT = TypeVar("MSIPEstimatorOutputT", bound=MSIPEstimatorOutput) - - -@dataclass -class MSIPAlgorithmArgs(Generic[MSIPEstimatorOutputT]): - kernel_lengthscale: float - kernel_matrix_inverse: KernelMatrixType - msip_estimator_output: MSIPEstimatorOutputT - - -class MSIP(WeightedAdaptiveNAKAlgorithm[MSIPEstimator, MSIPAlgorithmArgs]): - kernel_diag_infl: float - default_kernel_lengthscale: float - kernel_lengthscale_quantile: Optional[float] - get_kernel_matrix: MatSelfKernelFunction - - def __init__( - self, - *_, - kernel_diag_infl: float = 0.0, - kernel_lengthscale: Optional[float] = None, - kernel_lengthscale_quantile: Optional[float] = None, - get_kernel_matrix: Optional[MatSelfKernelFunction] = None, - ): - self.kernel_diag_infl = kernel_diag_infl - if kernel_lengthscale is None and kernel_lengthscale_quantile is None: - raise ValueError( - "Must have either kernel_lengthscale" - "or kernel_lengthscale_quantile as value" - ) - self.kernel_lengthscale = kernel_lengthscale - self.kernel_lengthscale_quantile = kernel_lengthscale_quantile - if get_kernel_matrix is None: - self.get_kernel_matrix = default_kernel_matrix - else: - self.get_kernel_matrix = get_kernel_matrix - - def get_adaptive_lengthscale(self, particles: BatchPtType) -> float: - q = self.kernel_lengthscale_quantile - if q is None: - return self.default_kernel_lengthscale - return quantile_distance(particles, q) - - def get_infl_kernel_matrix(self, particles, kernel_lengthscale) -> KernelMatrixType: - kernel_matrix = self.get_kernel_matrix(particles, kernel_lengthscale) - if self.kernel_diag_infl is not None: - kernel_matrix[ - torch.arange(self.n_particles, device=self.device), - torch.arange(self.n_particles, device=self.device), - ] += self.kernel_diag_infl - return kernel_matrix +class MSIP(GeneralMSIPAlgorithm[MSIPAlgorithmArgs]): def initialize(self, init_particles, target, target_args): kernel_lengthscale = self.get_adaptive_lengthscale(init_particles) estimator_output = target(init_particles, kernel_lengthscale, target_args) @@ -86,14 +27,14 @@ def step(self, lr, particles, target, algorithm_args, target_args): kernel_matrix_inverse, output_idx=None, ) - new_particles = particles * (1 - lr) + lr * particles_diff + new_particles = particles.mul(1 - lr).add_(particles_diff.mul_(lr)) # Update the parameters kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) kernel_matrix = self.get_infl_kernel_matrix(new_particles, kernel_lengthscale) msip_estimator_output = target(particles, kernel_lengthscale, target_args) algorithm_args = MSIPAlgorithmArgs( - kernel_lengthscale, kernel_matrix_inverse, msip_estimator_output + kernel_lengthscale, kernel_matrix, msip_estimator_output ) new_weights = get_msip_wts(new_particles, estimator_output, kernel_matrix) return new_particles, new_weights, algorithm_args diff --git a/src/nak_torch/algorithms/msip/msip_gs.py b/src/nak_torch/algorithms/msip/msip_gs.py index a51c565..693621e 100644 --- a/src/nak_torch/algorithms/msip/msip_gs.py +++ b/src/nak_torch/algorithms/msip/msip_gs.py @@ -1,139 +1,42 @@ -import warnings -from typing import Optional +from dataclasses import astuple -from tqdm import tqdm -import numpy as np import torch -from nak_torch.tools.kernel import default_kernel_matrix -from nak_torch.tools.util import initialize_particles, quantile_distance from .msip_map import msip_map, get_msip_wts -from .estimators import MSIPEstimator -from .msip_tools import msip_map_used_keys, process_msip_density - -from nak_torch.tools.types import ( - LogDensity, - BatchLogDensity, - BatchType, - MatSelfKernelFunction, -) - - -# Gauss-Seidel variant of MSIP. -def msip_gs( - log_density: LogDensity | BatchLogDensity | MSIPEstimator, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - kernel_length_scale: float, - noise: float = 0.05, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - keep_all: bool = True, - get_kernel_matrix: Optional[MatSelfKernelFunction] = None, - kernel_diag_infl: float = 0.0, - verbose: bool = False, - use_quantile_length_scale: Optional[float] = None, - compile_step: bool = True, - **msip_kwargs, -): - r""" - TODO: Document - """ - - if n_steps < 0: - raise ValueError("Expected positive number of steps.") - - unused_kwargs = { - k: v for (k, v) in msip_kwargs.items() if k not in msip_map_used_keys - } - - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs: {}".format(unused_kwargs)) - - if seed is not None: - torch.manual_seed(seed) - if get_kernel_matrix is None: - get_kernel_matrix = default_kernel_matrix - - msip_estimator = process_msip_density(log_density, **msip_kwargs) - est_v = msip_estimator.get_v_evals - _msip_map = msip_map - _get_msip_wts = get_msip_wts - if compile_step: - _msip_map = torch.compile(msip_map) - _get_msip_wts = torch.compile(_get_msip_wts) - est_v = torch.compile(est_v) - - particles = initialize_particles(n_particles, dim, init_particles, device, bounds) - - if keep_all: - trajectories = torch.empty( - (n_steps + 1, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) - traj_wts = torch.empty( - (n_steps + 1, particles.shape[0]), device=device, dtype=particles.dtype - ) - else: - trajectories = torch.empty(()) - traj_wts = torch.empty(()) - - particle_wts: BatchType = torch.tensor(()) - - if use_quantile_length_scale is not None: - kernel_length_scale = quantile_distance(particles, use_quantile_length_scale) - est_out = est_v(particles, kernel_length_scale) - - # est_out should keep references to est_out_0 and est_out_1 - est_out_0, est_out_1 = est_out - for step in tqdm(range(n_steps + 1), disable=not verbose): - if use_quantile_length_scale is not None: - kernel_length_scale = quantile_distance( - particles, use_quantile_length_scale - ) - - for i in range(n_particles): - km_i = get_kernel_matrix(particles, kernel_length_scale) - if kernel_diag_infl > 0: - km_i[torch.arange(n_particles), torch.arange(n_particles)] += ( - kernel_diag_infl - ) - - est_out_i_0, est_out_i_1 = est_v( - particles[i].unsqueeze(0), kernel_length_scale +from .msip_tools import GeneralMSIPAlgorithm, MSIPGSAlgorithmArgs + + +class MSIPGS(GeneralMSIPAlgorithm[MSIPGSAlgorithmArgs]): + def initialize(self, init_particles, target, target_args): + kernel_lengthscale = self.get_adaptive_lengthscale(init_particles) + estimator_output = target(init_particles, kernel_lengthscale, target_args) + kernel_matrix = self.get_infl_kernel_matrix(init_particles, kernel_lengthscale) + wts = get_msip_wts(init_particles, estimator_output, kernel_matrix) + return wts, MSIPGSAlgorithmArgs(kernel_lengthscale, estimator_output) + + def step(self, lr, particles, target, algorithm_args, target_args): + kernel_lengthscale, _, estimator_output = astuple(algorithm_args) + est_out_0, est_out_1 = estimator_output + new_particles = particles.clone() + for i in range(particles.shape[0]): + km_i = self.get_infl_kernel_matrix(particles, kernel_lengthscale) + km_inv_i = torch.linalg.pinv(km_i) + est_out_i_0, est_out_i_1 = target( + new_particles[i].unsqueeze(0), kernel_lengthscale, target_args ) est_out_0[i].copy_(est_out_i_0.squeeze()) est_out_1[i].copy_(est_out_i_1.squeeze()) - particle_wts = _get_msip_wts(particles, est_out, km_i) - - if keep_all and i == n_particles - 1: - traj_wts[step].copy_(particle_wts) - - if step >= n_steps: - continue - - if kernel_diag_infl > 0: - km_inv_i = torch.linalg.inv(km_i) - else: - km_inv_i = torch.linalg.pinv(km_i) - - target_i = _msip_map(est_out, particles, km_inv_i, output_idx=i) - - with torch.no_grad(): - particles[i].mul_(1.0 - lr).add_(target_i.mul_(lr)) - if bounds is not None: - particles[i].clamp_(bounds[0], bounds[1]) - - if keep_all and step < n_steps: - trajectories[step + 1].copy_(particles) + target_i = msip_map(estimator_output, particles, km_inv_i, output_idx=i) - if not keep_all: - trajectories = particles.unsqueeze_(0) - traj_wts = particle_wts.unsqueeze_(0) + new_particles[i].mul_(1.0 - lr).add_(target_i.mul_(lr)) - return trajectories.detach(), traj_wts.detach() + # Update the parameters + new_kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) + kernel_matrix = self.get_infl_kernel_matrix(new_particles, kernel_lengthscale) + if new_kernel_lengthscale != kernel_lengthscale: + estimator_output = target(particles, new_kernel_lengthscale, target_args) + kernel_lengthscale = new_kernel_lengthscale + algorithm_args = MSIPGSAlgorithmArgs(kernel_lengthscale, estimator_output) + new_weights = get_msip_wts(new_particles, estimator_output, kernel_matrix) + return new_particles, new_weights, algorithm_args diff --git a/src/nak_torch/algorithms/msip/msip_tools.py b/src/nak_torch/algorithms/msip/msip_tools.py index 5276439..7a880aa 100644 --- a/src/nak_torch/algorithms/msip/msip_tools.py +++ b/src/nak_torch/algorithms/msip/msip_tools.py @@ -1,16 +1,88 @@ +from dataclasses import dataclass +from typing import Generic, Optional, TypeVar + import torch -from nak_torch.tools.util import get_keywords +from nak_torch.tools.func import AlgorithmArgsT, WeightedAdaptiveNAKAlgorithm +from nak_torch.tools.kernel import default_kernel_matrix +from nak_torch.tools.util import get_keywords, quantile_distance from .msip_map import msip_map from .estimators import MSIPEstimator, MSIPFredholm from nak_torch.tools.types import ( + BatchPtType, + KernelMatrixType, LogDensity, BatchLogDensity, BatchLogDensityGradVal, + MSIPEstimatorOutput, + MatSelfKernelFunction, ) +MSIPEstimatorOutputT = TypeVar("MSIPEstimatorOutputT", bound=MSIPEstimatorOutput) +MSIPAlgorithmArgsT = TypeVar("MSIPAlgorithmArgsT") + + +@dataclass +class MSIPAlgorithmArgs(Generic[MSIPEstimatorOutputT]): + kernel_lengthscale: float + kernel_matrix: KernelMatrixType + msip_estimator_output: MSIPEstimatorOutputT + + +@dataclass +class MSIPGSAlgorithmArgs(Generic[MSIPEstimatorOutputT]): + kernel_lengthscale: float + msip_estimator_output: MSIPEstimatorOutputT + + +class GeneralMSIPAlgorithm(WeightedAdaptiveNAKAlgorithm[MSIPEstimator, AlgorithmArgsT]): + kernel_diag_infl: Optional[float] + default_kernel_lengthscale: float + kernel_lengthscale_quantile: Optional[float] + get_kernel_matrix: MatSelfKernelFunction + + def __init__( + self, + *_, + kernel_diag_infl: Optional[float] = None, + kernel_lengthscale: Optional[float] = None, + kernel_lengthscale_quantile: Optional[float] = None, + get_kernel_matrix: Optional[MatSelfKernelFunction] = None, + ): + self.kernel_diag_infl = kernel_diag_infl + if kernel_lengthscale is None and kernel_lengthscale_quantile is None: + raise ValueError( + "Must have either kernel_lengthscale" + "or kernel_lengthscale_quantile as value" + ) + if kernel_lengthscale is None: + self.default_kernel_lengthscale = 0.0 + else: + self.default_kernel_lengthscale = kernel_lengthscale + self.kernel_lengthscale_quantile = kernel_lengthscale_quantile + if get_kernel_matrix is None: + self.get_kernel_matrix = default_kernel_matrix + else: + self.get_kernel_matrix = get_kernel_matrix + + def get_adaptive_lengthscale(self, particles: BatchPtType) -> float: + q = self.kernel_lengthscale_quantile + if q is None: + return self.default_kernel_lengthscale + return quantile_distance(particles, q) + + def get_infl_kernel_matrix(self, particles, kernel_lengthscale) -> KernelMatrixType: + kernel_matrix = self.get_kernel_matrix(particles, kernel_lengthscale) + if self.kernel_diag_infl is not None: + kernel_matrix[ + torch.arange(self.n_particles, device=self.device), + torch.arange(self.n_particles, device=self.device), + ] += self.kernel_diag_infl + return kernel_matrix + + def process_msip_density( log_density: LogDensity | BatchLogDensity | MSIPEstimator, *_, @@ -23,8 +95,8 @@ def process_msip_density( log_density_grad_val: BatchLogDensityGradVal if is_log_density_batched: - def dens_eval(_p): - out = log_density(_p) + def dens_eval(_p, target_args): + out = log_density(_p, target_args) return out.sum(), out log_density_grad_val = torch.func.grad(dens_eval, has_aux=True) From a17e3776eb0380406f1e8430c174344ef5c72821 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 23 Apr 2026 14:56:25 -0400 Subject: [PATCH 14/60] Finish initial implementations for the most part --- src/nak_torch/algorithms/cbs.py | 105 +++++++------- src/nak_torch/algorithms/grad_aldi.py | 95 +++++-------- src/nak_torch/algorithms/loop.py | 4 +- src/nak_torch/algorithms/msip/estimators.py | 4 +- src/nak_torch/algorithms/msip/msip_tools.py | 6 + src/nak_torch/algorithms/svgd.py | 147 +++++++++++++------- src/nak_torch/tools/func.py | 26 +++- src/nak_torch/tools/kernel.py | 7 +- src/nak_torch/tools/types.py | 46 +++++- 9 files changed, 254 insertions(+), 186 deletions(-) diff --git a/src/nak_torch/algorithms/cbs.py b/src/nak_torch/algorithms/cbs.py index cde82c8..ccbdff4 100644 --- a/src/nak_torch/algorithms/cbs.py +++ b/src/nak_torch/algorithms/cbs.py @@ -1,10 +1,15 @@ +from dataclasses import astuple, dataclass + import torch from typing import Optional -from nak_torch.tools.types import BatchType, BatchPtType -import warnings -from tqdm import tqdm -import numpy as np -from nak_torch.tools.util import initialize_particles, sym_sqrtm +from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm +from nak_torch.tools.types import ( + BatchLogDensityEvaluator, + BatchType, + BatchPtType, + DeviceLike, +) +from nak_torch.tools.util import sym_sqrtm def cbs_step( @@ -28,61 +33,45 @@ def cbs_step( return drift_term, motion_term -def cbs( - log_density, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - inverse_temp: float, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - rng: Optional[torch.Generator] = None, - keep_all: bool = True, - is_log_density_batched: bool = False, - verbose: bool = False, - compile_step: bool = True, - **unused_kwargs, -): - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs:\n{}".format(unused_kwargs)) - - if rng is None: - rng = torch.default_generator - if seed is not None: - rng.manual_seed(seed) +@dataclass +class CBSAlgorithmArgs: + inverse_temp: float + motion_scaling_sq_div_lr: float - particles = initialize_particles( - n_particles, dim, init_particles, device, bounds, rng - ) - if keep_all: - trajectories = torch.empty( - (n_steps, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) - else: - trajectories = torch.empty(()) - _cbs_step = cbs_step - if compile_step: - _cbs_step = torch.compile(cbs_step) +class CBSAlgorithm( + UnweightedAdaptiveNAKAlgorithm[BatchLogDensityEvaluator, CBSAlgorithmArgs] +): + default_inverse_temp: float + rng: torch.Generator - log_p = log_density if is_log_density_batched else torch.vmap(log_density) - motion_scaling_sq = lr * 2 * (1 + inverse_temp) + def __init__( + self, + dim: int, + n_particles: int, + device: Optional[DeviceLike] = None, + dtype: Optional[torch.dtype] = None, + *_, + default_inverse_temp: float, + rng: torch.Generator, + ): + super().__init__(dim, n_particles, device, dtype) + self.default_inverse_temp = default_inverse_temp + self.rng = rng - for idx in tqdm(range(n_steps), disable=not verbose): - log_dens_eval = log_p(particles) - with torch.no_grad(): - particles_diff, particles_noise = _cbs_step( - particles, log_dens_eval, inverse_temp, motion_scaling_sq, rng - ) - particles_diff.mul_(lr) - particles = particles.add_(particles_diff).add_(particles_noise) - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx].copy_(particles) + def initialize(self, init_particles, target, target_args): + inverse_temp = self.default_inverse_temp + motion_scaling_sq_div_lr = 2 * (1 + inverse_temp) + alg_args = CBSAlgorithmArgs(inverse_temp, motion_scaling_sq_div_lr) + return None, alg_args - return trajectories.detach() if keep_all else particles.unsqueeze_(0) + def step(self, lr, particles, target, algorithm_args, target_args): + inverse_temp, motion_scaling_sq_div_lr = astuple(algorithm_args) + motion_scaling_sq = motion_scaling_sq_div_lr * lr + log_dens_eval = target(particles, None, target_args) + particles_diff, particles_noise = cbs_step( + particles, log_dens_eval, inverse_temp, motion_scaling_sq, self.rng + ) + particles_diff.mul_(lr) + new_particles = particles_diff.add_(particles).add_(particles_noise) + return new_particles, None, algorithm_args diff --git a/src/nak_torch/algorithms/grad_aldi.py b/src/nak_torch/algorithms/grad_aldi.py index 06079fc..34264a8 100644 --- a/src/nak_torch/algorithms/grad_aldi.py +++ b/src/nak_torch/algorithms/grad_aldi.py @@ -1,12 +1,12 @@ import torch from typing import Optional -from nak_torch.tools.types import BatchGradLogDensity, BatchPtType -import warnings -from tqdm import tqdm -import numpy as np +from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm +from nak_torch.tools.types import ( + BatchGradLogDensityEvaluator, + BatchPtType, + DeviceLike, +) from nak_torch.tools.util import ( - batched_grad_log_density_factory, - initialize_particles, sym_sqrtm, ) @@ -36,64 +36,35 @@ def grad_aldi_step( return drift_term, particles_noise -def grad_aldi( - log_density, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - rng: Optional[torch.Generator] = None, - keep_all: bool = True, - is_log_density_batched: bool = False, - grad_log_density: Optional[BatchGradLogDensity] = None, - verbose: bool = False, - compile_step: bool = True, - **unused_kwargs, +class GradALDIAlgorithm( + UnweightedAdaptiveNAKAlgorithm[BatchGradLogDensityEvaluator, None] ): - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs:\n{}".format(unused_kwargs)) + rng: torch.Generator - if rng is None: - rng = torch.default_generator - if seed is not None: - rng.manual_seed(seed) + def _sqrt(self, x: float): + return torch.sqrt_(torch.as_tensor(x, device=self.device, dtype=self.dtype)) - grad_log_p = batched_grad_log_density_factory( - log_density, is_log_density_batched, grad_log_density - ) - particles = initialize_particles( - n_particles, dim, init_particles, device, bounds, rng - ) + def __init__( + self, + dim: int, + n_particles: int, + device: Optional[DeviceLike] = None, + dtype: Optional[torch.dtype] = None, + *_, + rng: torch.Generator, + ): + super().__init__(dim, n_particles, device, dtype) + self.rng = rng - if keep_all: - trajectories = torch.empty( - (n_steps, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) - else: - trajectories = torch.empty(()) + def initialize(self, init_particles, target, target_args): + return None, None - sqrt_lr = torch.sqrt(torch.tensor(lr)) - g_aldi_step = grad_aldi_step - if compile_step: - g_aldi_step = torch.compile(g_aldi_step) - - for idx in tqdm(range(n_steps), disable=not verbose): - grad_log_dens_eval = grad_log_p(particles) - with torch.no_grad(): - particles_diff, particles_noise = g_aldi_step( - particles, grad_log_dens_eval, rng - ) - particles_diff.mul_(lr) - particles_noise.mul_(sqrt_lr) - particles.add_(particles_diff).add_(particles_noise) - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx].copy_(particles) - - return trajectories.detach() if keep_all else particles.unsqueeze_(0) + def step(self, lr, particles, target, algorithm_args, target_args): + grad_log_dens_evals = target(particles, None, target_args) + particles_diff, particles_noise = grad_aldi_step( + particles, grad_log_dens_evals, self.rng + ) + particles_diff.mul_(lr) + particles_noise.mul_(self._sqrt(lr)) + new_particles = particles_diff.add_(particles).add_(particles_noise) + return new_particles, None, algorithm_args diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 9a472e0..8389a98 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -7,7 +7,7 @@ from nak_torch.tools.util import initialize_particles from nak_torch.tools.types import ( - BatchDensityEvaluator, + BatchTargetEvaluator, ) from nak_torch.tools.func import ( @@ -16,7 +16,7 @@ def nak( - target: BatchDensityEvaluator, + target: BatchTargetEvaluator, algorithm: GeneralAdaptiveNAKAlgorithm, n_particles: int, n_steps: int, diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index 1beba53..8ded7c6 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -6,7 +6,7 @@ BatchLogDensityGradVal, BatchLogDensity, BatchQuadratureRule, - BatchDensityEvaluator, + BatchTargetEvaluator, ) from jaxtyping import Float from torch import Tensor @@ -14,7 +14,7 @@ __all__ = ["MSIPFredholm", "MSIPQuadGradientFree", "MSIPQuadGradientInformed"] -class MSIPEstimator(BatchDensityEvaluator[MSIPEstimatorOutput]): +class MSIPEstimator(BatchTargetEvaluator[MSIPEstimatorOutput]): @abstractmethod def __call__(self, particles, evaluator_args, target_args) -> MSIPEstimatorOutput: r""" diff --git a/src/nak_torch/algorithms/msip/msip_tools.py b/src/nak_torch/algorithms/msip/msip_tools.py index 7a880aa..8bedb8f 100644 --- a/src/nak_torch/algorithms/msip/msip_tools.py +++ b/src/nak_torch/algorithms/msip/msip_tools.py @@ -11,6 +11,7 @@ from nak_torch.tools.types import ( BatchPtType, + DeviceLike, KernelMatrixType, LogDensity, BatchLogDensity, @@ -45,12 +46,17 @@ class GeneralMSIPAlgorithm(WeightedAdaptiveNAKAlgorithm[MSIPEstimator, Algorithm def __init__( self, + dim: int, + n_particles: int, + device: Optional[DeviceLike] = None, + dtype: Optional[torch.dtype] = None, *_, kernel_diag_infl: Optional[float] = None, kernel_lengthscale: Optional[float] = None, kernel_lengthscale_quantile: Optional[float] = None, get_kernel_matrix: Optional[MatSelfKernelFunction] = None, ): + super().__init__(dim, n_particles, device, dtype) self.kernel_diag_infl = kernel_diag_infl if kernel_lengthscale is None and kernel_lengthscale_quantile is None: raise ValueError( diff --git a/src/nak_torch/algorithms/svgd.py b/src/nak_torch/algorithms/svgd.py index 533ecc3..1792a5c 100644 --- a/src/nak_torch/algorithms/svgd.py +++ b/src/nak_torch/algorithms/svgd.py @@ -6,14 +6,20 @@ # Ayoub Belhadji # 05/12/2025 -import warnings -import numpy as np -import torch +from dataclasses import astuple, dataclass from typing import Optional, Callable -from tqdm import tqdm -from nak_torch.tools.kernel import sqexp_kernel_elem, kernel_grad_and_value_factory -from nak_torch.tools.types import KernelFunction, BatchGradLogDensity, BatchPtType -from nak_torch.tools.util import batched_grad_log_density_factory, initialize_particles +import torch +from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm +from nak_torch.tools.kernel import kernel_grad_and_value_factory, default_kernel_elem +from nak_torch.tools.types import ( + BatchGradLogDensityEvaluator, + BatchKernelGradValFunction, + DeviceLike, + KernelFunction, + BatchGradLogDensity, + BatchPtType, +) +from nak_torch.tools.util import quantile_distance def create_svgd_step( @@ -39,52 +45,91 @@ def svgd_step_dir(points: BatchPtType): return svgd_step_dir -def svgd( - log_density, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - kernel_length_scale: float = 1.0, - kernel_elem: KernelFunction = sqexp_kernel_elem, - bounds: Optional[tuple[float, float]] = None, - keep_all: bool = True, - is_log_density_batched: bool = False, - grad_log_density: Optional[BatchGradLogDensity] = None, - verbose: bool = False, - **unused_kwargs, +def create_svgd_kernel_grad_val( + kernel_elem: KernelFunction, +) -> BatchKernelGradValFunction: + which_argnum = 1 + kernel_grad_val = torch.func.grad_and_value(kernel_elem, argnums=which_argnum) + kernel_grad_val_vec = torch.vmap( + torch.vmap(kernel_grad_val, in_dims=(None, 0, None)), in_dims=(0, None, None) + ) + return kernel_grad_val_vec + + +def svgd_step( + kernel_grad_val: BatchKernelGradValFunction, + points: BatchPtType, + grad_log_dens: BatchPtType, + kernel_elem_args, +) -> BatchPtType: + k_grad, k_eval = kernel_grad_val(points, points, kernel_elem_args) + # lpg[j,ell] = grad(x_j[ell]) log_p(x_j) + log_p_grad_ev = grad_log_dens + # term_1[i, ell] = sum_j k(i, j) grad(x_j[ell]) log_p(x_j) + term_1: BatchPtType = k_eval @ log_p_grad_ev + # term_2[i, ell] = sum_j grad(x_j[ell]) k(x_i, x_j) + term_2: BatchPtType = k_grad.sum(1) + return (term_1 + term_2) / points.shape[0] + + +@dataclass +class SVGDAlgorithmArgs: + kernel_lengthscale: float + + +class SVGDAlgorithm( + UnweightedAdaptiveNAKAlgorithm[BatchGradLogDensityEvaluator, SVGDAlgorithmArgs] ): - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs:\n{}".format(unused_kwargs)) + default_kernel_lengthscale: float + kernel_lengthscale_quantile: Optional[float] + kernel_grad_val: BatchKernelGradValFunction - if seed is not None: - torch.manual_seed(seed) + def get_adaptive_lengthscale(self, particles: BatchPtType) -> float: + q = self.kernel_lengthscale_quantile + if q is None: + return self.default_kernel_lengthscale + return quantile_distance(particles, q) - particles = initialize_particles(n_particles, dim, init_particles, device, bounds) + def __init__( + self, + dim: int, + n_particles: int, + device: Optional[DeviceLike] = None, + dtype: Optional[torch.dtype] = None, + *_, + default_kernel_lengthscale: Optional[float] = None, + kernel_lengthscale_quantile: Optional[float] = None, + kernel_elem: Optional[KernelFunction] = None, + ): + super().__init__(dim, n_particles, device, dtype) + if default_kernel_lengthscale is None and kernel_lengthscale_quantile is None: + raise ValueError( + "Must provide either default_kernel_lengthscale or kernel_lengthscale_quantile" + ) + if kernel_lengthscale_quantile is not None and ( + kernel_lengthscale_quantile < 0 or kernel_lengthscale_quantile > 1 + ): + raise ValueError( + f"Expected kernel_lengthscale_quantile in [0,1], given {kernel_lengthscale_quantile}" + ) + if default_kernel_lengthscale is None: + default_kernel_lengthscale = 0.0 + if kernel_elem is None: + kernel_elem = default_kernel_elem + self.default_kernel_lengthscale = default_kernel_lengthscale + self.kernel_lengthscale_quantile = kernel_lengthscale_quantile + self.kernel_grad_val = create_svgd_kernel_grad_val(kernel_elem) - if keep_all: - trajectories = torch.empty( - (n_steps + 1, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) - else: - trajectories = torch.empty(()) + def initialize(self, init_particles, target, target_args): + kernel_lengthscale = self.get_adaptive_lengthscale(init_particles) + return None, SVGDAlgorithmArgs(kernel_lengthscale) - grad_log_p = batched_grad_log_density_factory( - log_density, is_log_density_batched, grad_log_density - ) - step_fcn = create_svgd_step(kernel_elem, grad_log_p, kernel_length_scale) - - for idx in tqdm(range(n_steps), disable=not verbose): - particles_diff = step_fcn(particles) - with torch.no_grad(): - particles = particles + lr * particles_diff - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx + 1].copy_(particles) - - return trajectories.detach() if keep_all else particles.unsqueeze_(0) + def step(self, lr, particles, target, algorithm_args, target_args): + (kernel_lengthscale,) = astuple(algorithm_args) + grad_log_dens_eval = target(particles, None, target_args) + particles_diff = svgd_step( + self.kernel_grad_val, particles, grad_log_dens_eval, kernel_lengthscale + ) + new_particles = particles_diff.mul_(lr).add_(particles) + new_kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) + return new_particles, None, SVGDAlgorithmArgs(new_kernel_lengthscale) diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py index 5b7bc56..07c282a 100644 --- a/src/nak_torch/tools/func.py +++ b/src/nak_torch/tools/func.py @@ -5,27 +5,39 @@ BatchPtType, BatchType, DeviceLike, - BatchDensityEvaluator, + BatchTargetEvaluator, ) -BatchDensityEvaluatorT = TypeVar("BatchDensityEvaluatorT", bound=BatchDensityEvaluator) +BatchTargetEvaluatorT = TypeVar("BatchTargetEvaluatorT", bound=BatchTargetEvaluator) AlgorithmArgsT = TypeVar("AlgorithmArgsT") WeightT = TypeVar("WeightT", bound=Optional[BatchType]) class GeneralAdaptiveNAKAlgorithm( - ABC, Generic[BatchDensityEvaluatorT, WeightT, AlgorithmArgsT] + ABC, Generic[BatchTargetEvaluatorT, WeightT, AlgorithmArgsT] ): dim: int n_particles: int device: Optional[DeviceLike] dtype: Optional[torch.dtype] + def __init__( + self, + dim: int, + n_particles: int, + device: Optional[DeviceLike], + dtype: Optional[torch.dtype], + ): + self.dim = dim + self.n_particles = n_particles + self.device = device + self.dtype = dtype + @abstractmethod def initialize( self, init_particles: BatchPtType, - target: BatchDensityEvaluatorT, + target: BatchTargetEvaluatorT, target_args: Any, ) -> tuple[WeightT, AlgorithmArgsT]: pass @@ -35,7 +47,7 @@ def step( self, lr: float, particles: BatchPtType, - target: BatchDensityEvaluatorT, + target: BatchTargetEvaluatorT, algorithm_args: AlgorithmArgsT, target_args: Any, ) -> tuple[BatchPtType, WeightT, AlgorithmArgsT]: @@ -48,7 +60,7 @@ def is_weighted(cls) -> bool: class UnweightedAdaptiveNAKAlgorithm( - GeneralAdaptiveNAKAlgorithm[BatchDensityEvaluatorT, None, AlgorithmArgsT] + GeneralAdaptiveNAKAlgorithm[BatchTargetEvaluatorT, None, AlgorithmArgsT] ): @classmethod def is_weighted(cls) -> bool: @@ -56,7 +68,7 @@ def is_weighted(cls) -> bool: class WeightedAdaptiveNAKAlgorithm( - GeneralAdaptiveNAKAlgorithm[BatchDensityEvaluatorT, BatchType, AlgorithmArgsT] + GeneralAdaptiveNAKAlgorithm[BatchTargetEvaluatorT, BatchType, AlgorithmArgsT] ): @classmethod def is_weighted(cls) -> bool: diff --git a/src/nak_torch/tools/kernel.py b/src/nak_torch/tools/kernel.py index 8c7a427..1288ff5 100644 --- a/src/nak_torch/tools/kernel.py +++ b/src/nak_torch/tools/kernel.py @@ -38,9 +38,6 @@ def sqexp_kernel_matrix( ) -default_kernel_matrix = sqexp_kernel_matrix - - def sqexp_kernel_elem(x: PtType, y: PtType, kernel_length_scale: float) -> Float: torch._assert( x.shape == y.shape and y.ndim == 1, "Invalid input dimensions of x and y" @@ -51,6 +48,10 @@ def sqexp_kernel_elem(x: PtType, y: PtType, kernel_length_scale: float) -> Float return ret +default_kernel_elem = sqexp_kernel_elem +default_kernel_matrix = sqexp_kernel_matrix + + def inverse_multi_quadric_kernel_elem( x: PtType, y: PtType, kernel_length_scale: float ) -> Float: diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 2362824..76a14cd 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -18,15 +18,19 @@ KernelMatrixType = Float[Tensor, "batch batch"] GradKernelMatrixType = Float[Tensor, "batch batch d"] + DensityGradValOutput = tuple[BatchPtType, BatchType] MSIPEstimatorOutput = tuple[BatchType, BatchPtType] KernelFunction = Callable[[PtType, PtType, float], Float] +BatchKernelGradValFunction = Callable[ + [BatchPtType, BatchPtType, Any], tuple[BatchPtType, Float] +] EvaluatorOutputT = TypeVar("EvaluatorOutputT") -class BatchDensityEvaluator(ABC, Generic[EvaluatorOutputT]): +class BatchTargetEvaluator(ABC, Generic[EvaluatorOutputT]): @abstractmethod def __call__( self, particles: BatchPtType, evaluator_args, target_args @@ -64,6 +68,46 @@ def __call__( ] +class BatchLogDensityEvaluator(BatchTargetEvaluator[BatchType]): + log_density: BatchLogDensity + + def __init__(self, log_density: LogDensity | BatchLogDensity, is_batched: bool): + if not is_batched: + log_density = torch.vmap(log_density, in_dims=(0, None)) + self.log_density = log_density + + def __call__(self, pts, _, target_args): + return self.log_density(pts, target_args) + + +class BatchGradLogDensityEvaluator(BatchTargetEvaluator[BatchPtType]): + grad_log_density: BatchGradLogDensity + + def __init__( + self, + log_density_or_grad: LogDensity + | BatchLogDensity + | GradLogDensity + | BatchGradLogDensity, + is_grad: bool, + is_batched: bool, + ): + if is_batched: + if is_grad: + self.grad_log_density = log_density_or_grad + else: + self.grad_log_density = torch.func.grad( + lambda x, args: log_density_or_grad(x, args).sum() + ) + else: + if not is_grad: + log_density_or_grad = torch.func.grad(log_density_or_grad) + self.grad_log_density = torch.vmap(log_density_or_grad, in_dims=(0, None)) + + def __call__(self, pts, _, target_args): + return self.grad_log_density(pts, target_args) + + @dataclass class GaussianModel: forward_model: BatchForwardModel From 11b1eb4d8165e3add10e77cc32d2aba8bd85a176 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 23 Apr 2026 15:05:21 -0400 Subject: [PATCH 15/60] Work on file organization --- src/nak_torch/__init__.py | 3 ++- src/nak_torch/algorithms/__init__.py | 24 +++++++++--------- src/nak_torch/algorithms/cbs.py | 6 ++--- src/nak_torch/algorithms/grad_aldi.py | 8 +++--- src/nak_torch/algorithms/loop.py | 2 ++ src/nak_torch/algorithms/msip/__init__.py | 16 +++--------- src/nak_torch/algorithms/msip/msip.py | 2 ++ src/nak_torch/algorithms/msip/msip_gs.py | 2 ++ src/nak_torch/algorithms/svgd.py | 30 +++-------------------- 9 files changed, 34 insertions(+), 59 deletions(-) diff --git a/src/nak_torch/__init__.py b/src/nak_torch/__init__.py index 9846e2c..71a201c 100644 --- a/src/nak_torch/__init__.py +++ b/src/nak_torch/__init__.py @@ -1,5 +1,6 @@ from . import algorithms, tools +from .algorithms import nak from .tools import GaussianModel, metrics -__all__ = ["algorithms", "tools", "GaussianModel", "metrics"] +__all__ = ["algorithms", "tools", "GaussianModel", "metrics", "nak"] diff --git a/src/nak_torch/algorithms/__init__.py b/src/nak_torch/algorithms/__init__.py index b2c5b54..13eb1a8 100644 --- a/src/nak_torch/algorithms/__init__.py +++ b/src/nak_torch/algorithms/__init__.py @@ -8,27 +8,25 @@ # 05/12/2025 from .eks import eks -from .msip import msip, msip_gs, msip_ni, msip_greedy, msip_geom_greedy, msip_adapt -from .svgd import svgd +from .msip import MSIP, MSIPGS +from .svgd import SVGD from .deepensembles import deepensembles -from .grad_aldi import grad_aldi +from .grad_aldi import GradALDI from .gradfree_aldi import gradfree_aldi -from .cbs import cbs +from .cbs import CBS from .kfrflow import kfrflow +from .loop import nak __all__ = [ - "msip", - "msip_gs", - "msip_ni", - "msip_greedy", - "msip_geom_greedy", - "msip_adapt", - "svgd", + "nak", + "MSIP", + "MSIPGS", + "SVGD", "deepensembles", - "grad_aldi", + "GradALDI", "gradfree_aldi", "eks", - "cbs", + "CBS", "kfrflow", ] diff --git a/src/nak_torch/algorithms/cbs.py b/src/nak_torch/algorithms/cbs.py index ccbdff4..dcb60bf 100644 --- a/src/nak_torch/algorithms/cbs.py +++ b/src/nak_torch/algorithms/cbs.py @@ -11,6 +11,8 @@ ) from nak_torch.tools.util import sym_sqrtm +__all__ = ["CBS"] + def cbs_step( particles: BatchPtType, @@ -39,9 +41,7 @@ class CBSAlgorithmArgs: motion_scaling_sq_div_lr: float -class CBSAlgorithm( - UnweightedAdaptiveNAKAlgorithm[BatchLogDensityEvaluator, CBSAlgorithmArgs] -): +class CBS(UnweightedAdaptiveNAKAlgorithm[BatchLogDensityEvaluator, CBSAlgorithmArgs]): default_inverse_temp: float rng: torch.Generator diff --git a/src/nak_torch/algorithms/grad_aldi.py b/src/nak_torch/algorithms/grad_aldi.py index 34264a8..5dce78c 100644 --- a/src/nak_torch/algorithms/grad_aldi.py +++ b/src/nak_torch/algorithms/grad_aldi.py @@ -10,6 +10,8 @@ sym_sqrtm, ) +__all__ = ["GradALDI"] + def grad_aldi_step( particles: BatchPtType, @@ -36,9 +38,7 @@ def grad_aldi_step( return drift_term, particles_noise -class GradALDIAlgorithm( - UnweightedAdaptiveNAKAlgorithm[BatchGradLogDensityEvaluator, None] -): +class GradALDI(UnweightedAdaptiveNAKAlgorithm[BatchGradLogDensityEvaluator, None]): rng: torch.Generator def _sqrt(self, x: float): @@ -67,4 +67,4 @@ def step(self, lr, particles, target, algorithm_args, target_args): particles_diff.mul_(lr) particles_noise.mul_(self._sqrt(lr)) new_particles = particles_diff.add_(particles).add_(particles_noise) - return new_particles, None, algorithm_args + return new_particles, None, None diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 8389a98..6c69d48 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -14,6 +14,8 @@ GeneralAdaptiveNAKAlgorithm, ) +__all__ = ["nak"] + def nak( target: BatchTargetEvaluator, diff --git a/src/nak_torch/algorithms/msip/__init__.py b/src/nak_torch/algorithms/msip/__init__.py index 66ba1f8..175789d 100644 --- a/src/nak_torch/algorithms/msip/__init__.py +++ b/src/nak_torch/algorithms/msip/__init__.py @@ -1,9 +1,5 @@ -from .msip import msip -from .msip_gs import msip_gs -from .msip_greedy import msip_greedy -from .msip_ni import msip_ni -from .msip_geom_greedy import msip_geom_greedy -from .msip_adapt import msip_adapt +from .msip import MSIP +from .msip_gs import MSIPGS from .estimators import ( MSIPEstimator, MSIPQuadGradientFree, @@ -13,12 +9,8 @@ ) __all__ = [ - "msip", - "msip_gs", - "msip_greedy", - "msip_ni", - "msip_geom_greedy", - "msip_adapt", + "MSIP", + "MSIPGS", "MSIPEstimator", "MSIPQuadGradientFree", "MSIPFredholm", diff --git a/src/nak_torch/algorithms/msip/msip.py b/src/nak_torch/algorithms/msip/msip.py index 0ce442e..5cd85c5 100644 --- a/src/nak_torch/algorithms/msip/msip.py +++ b/src/nak_torch/algorithms/msip/msip.py @@ -5,6 +5,8 @@ from nak_torch.algorithms.msip.msip_tools import GeneralMSIPAlgorithm, MSIPAlgorithmArgs from .msip_map import msip_map, get_msip_wts +__all__ = ["MSIP"] + class MSIP(GeneralMSIPAlgorithm[MSIPAlgorithmArgs]): def initialize(self, init_particles, target, target_args): diff --git a/src/nak_torch/algorithms/msip/msip_gs.py b/src/nak_torch/algorithms/msip/msip_gs.py index 693621e..e0a7794 100644 --- a/src/nak_torch/algorithms/msip/msip_gs.py +++ b/src/nak_torch/algorithms/msip/msip_gs.py @@ -5,6 +5,8 @@ from .msip_map import msip_map, get_msip_wts from .msip_tools import GeneralMSIPAlgorithm, MSIPGSAlgorithmArgs +__all__ = ["MSIPGS"] + class MSIPGS(GeneralMSIPAlgorithm[MSIPGSAlgorithmArgs]): def initialize(self, init_particles, target, target_args): diff --git a/src/nak_torch/algorithms/svgd.py b/src/nak_torch/algorithms/svgd.py index 1792a5c..8246368 100644 --- a/src/nak_torch/algorithms/svgd.py +++ b/src/nak_torch/algorithms/svgd.py @@ -7,42 +7,20 @@ # 05/12/2025 from dataclasses import astuple, dataclass -from typing import Optional, Callable +from typing import Optional import torch from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm -from nak_torch.tools.kernel import kernel_grad_and_value_factory, default_kernel_elem +from nak_torch.tools.kernel import default_kernel_elem from nak_torch.tools.types import ( BatchGradLogDensityEvaluator, BatchKernelGradValFunction, DeviceLike, KernelFunction, - BatchGradLogDensity, BatchPtType, ) from nak_torch.tools.util import quantile_distance - -def create_svgd_step( - kernel_elem: KernelFunction, grad_log_p: BatchGradLogDensity, *kernel_elem_args -) -> Callable[[BatchPtType], BatchPtType]: - which_argnum = 1 - kernel_grad_val = kernel_grad_and_value_factory( - kernel_elem, which_argnum, *kernel_elem_args - ) - - def svgd_step_dir(points: BatchPtType): - # ASSUME SYMMETRY OF KERNEL - # kg[i,j,ell] = grad(x_j[ell]) k(x_i, x_j), k[i,j] = k(x_i, x_j) - k_grad, k_eval = kernel_grad_val(points, points) - # lpg[j,ell] = grad(x_j[ell]) log_p(x_j) - log_p_grad_ev = grad_log_p(points) - # term_1[i, ell] = sum_j k(i, j) grad(x_j[ell]) log_p(x_j) - term_1 = k_eval @ log_p_grad_ev - # term_2[i, ell] = sum_j grad(x_j[ell]) k(x_i, x_j) - term_2 = k_grad.sum(1) - return (term_1 + term_2) / points.shape[0] - - return svgd_step_dir +__all__ = ["SVGD"] def create_svgd_kernel_grad_val( @@ -77,7 +55,7 @@ class SVGDAlgorithmArgs: kernel_lengthscale: float -class SVGDAlgorithm( +class SVGD( UnweightedAdaptiveNAKAlgorithm[BatchGradLogDensityEvaluator, SVGDAlgorithmArgs] ): default_kernel_lengthscale: float From bb24bae3dc3cfd56e501f33ea1e581e86022bfec Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 23 Apr 2026 15:31:58 -0400 Subject: [PATCH 16/60] Work on himmelblau --- examples/functions/twodims/himmelblau.py | 2 +- examples/himmelblau.py | 79 +++++++-------------- src/nak_torch/algorithms/cbs.py | 3 +- src/nak_torch/algorithms/grad_aldi.py | 3 +- src/nak_torch/algorithms/loop.py | 9 ++- src/nak_torch/algorithms/msip/msip_tools.py | 3 +- src/nak_torch/algorithms/svgd.py | 15 ++-- src/nak_torch/tools/func.py | 7 ++ 8 files changed, 54 insertions(+), 67 deletions(-) diff --git a/examples/functions/twodims/himmelblau.py b/examples/functions/twodims/himmelblau.py index caaec88..03506ef 100644 --- a/examples/functions/twodims/himmelblau.py +++ b/examples/functions/twodims/himmelblau.py @@ -12,7 +12,7 @@ def himmelblau(T): - def himmelblau_aux(x): + def himmelblau_aux(x, _ = None): x1, x2 = x[...,0], x[...,1] t1 = torch.square(torch.square(x1)+x2-11) t2 = torch.square(x1 + torch.square(x2)-7) diff --git a/examples/himmelblau.py b/examples/himmelblau.py index 8a64255..6c1d726 100644 --- a/examples/himmelblau.py +++ b/examples/himmelblau.py @@ -2,16 +2,17 @@ import torch import matplotlib.pyplot as plt import nak_torch +from nak_torch.tools.types import BatchGradLogDensityEvaluator from viz_tools import animate_trajectories_box from functions import himmelblau -from nak_torch.algorithms import msip, svgd +from nak_torch import nak +from nak_torch.algorithms import MSIP, SVGD from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre from datetime import datetime from nak_torch.tools.kernel import kernel_optimal_weight_factory, default_kernel_matrix save_gif = False -algorithm_name = "msip_ni" function_name = "himmelblau" log_density = himmelblau(50.0) @@ -23,36 +24,33 @@ torch.manual_seed(19230182) init_particles = torch.randn((n_particles, 2)) + 8.0 params = { + "n_steps": 100, "bounds": (-15, 15), - "kernel_length_scale": 0.2, + "kernel_lengthscale": 0.18, "init_particles": init_particles, "n_particles": n_particles, "dim": 2, "lr": 0.8, "kernel_diag_infl": 1e-5, + "verbose": False } # %% -estimator_fredholm = MSIPFredholm( +msip = MSIP(**params) +svgd = SVGD(**params) + +# %% +target_msip_fr = MSIPFredholm( gradient_decay=0.95, - log_dens_grad_val=torch.vmap(torch.func.grad_and_value(log_density)) + log_dens_grad_val=torch.vmap( + torch.func.grad_and_value(log_density), + in_dims=(0,None) + ) ) -trajectories_fr, trajectories_wts_fr = msip( - estimator_fredholm, - # n_particles=n_particles, - # init_particles=init_particles, - n_steps=100, # now interpreted as "epochs" (passes over all particles) - # lr=0.6, - # noise=0.05, # currently unused, kept for compatibility - # kernel_length_scale=0.5, - # inner_tol=1e-4, # equilibrium tolerance for a particle - # max_inner_steps=1000, # max inner iterations per particle - # kernel_diag_infl=1e-8, - # seed=, - **params -) +trajectories_fr, trajectories_wts_fr = nak(target_msip_fr, msip, **params) +# %% Ngrid = 1000 x = y = torch.linspace(-5, 5, Ngrid) X,Y = torch.meshgrid(x,y,indexing="ij") @@ -70,25 +68,11 @@ plt.show() - # %% -trajectories_svgd = svgd( - log_density, - # n_particles=25, - # init_particles=init_particles, - n_steps=100, # now interpreted as "epochs" (passes over all particles) - # dim=2, - # bounds=(-20, 20), - # lr=0.6, - # noise=0.05, # currently unused, kept for compatibility - # kernel_length_scale=0.5, - # inner_tol=1e-4, # equilibrium tolerance for a particle - # max_inner_steps=1000, # max inner iterations per particle - # kernel_diag_infl=1e-8, - # seed=, - **params -) +target_svgd = BatchGradLogDensityEvaluator(log_density, is_grad=False, is_batched=False) +trajectories_svgd = nak(target_svgd, svgd, **params) +# %% plt.contourf(X,Y,Z, levels=20, cmap="Grays") pts_svgd = trajectories_svgd[-1] s = plt.scatter( @@ -101,29 +85,18 @@ # %% -estimator = MSIPQuadGradientFree( +target_msip_gf = MSIPQuadGradientFree( log_density, lambda b: spherical_MC_radial_Laguerre(b, N_spherical=5, d=2, N_radial=2) ) params_gf = params.copy() -params_gf['lr'] = 0.6 +params_gf['lr'] = 0.8 n_particles = 25 -trajectories_gf,w = msip( - estimator, - # n_particles=n_particles, - # init_particles=init_particles, - n_steps=100, # now interpreted as "epochs" (passes over all particles) - # dim=2, - # bounds=(-20, 20), - # lr=0.6, - # kernel_length_scale=0.5, - # kernel_diag_infl=1e-8, - seed=1, - **params_gf -) +trajectories_pts_gf,trajectories_wts_gf = nak(target_msip_gf, msip, **params_gf) -pts_gf = trajectories_gf[-1] -wts_gf = kernel_optimal_weight_factory(pts_gf, log_density(pts_gf), default_kernel_matrix(pts_gf, params["kernel_length_scale"])) +# %% +pts_gf = trajectories_pts_gf[-1] +wts_gf = trajectories_wts_gf[-1] plt.contourf(X,Y,Z, levels=20, cmap="Grays") plt.scatter(pts_gf[:,0], pts_gf[:,1], c=wts_gf) diff --git a/src/nak_torch/algorithms/cbs.py b/src/nak_torch/algorithms/cbs.py index dcb60bf..8c777cf 100644 --- a/src/nak_torch/algorithms/cbs.py +++ b/src/nak_torch/algorithms/cbs.py @@ -54,8 +54,9 @@ def __init__( *_, default_inverse_temp: float, rng: torch.Generator, + **kwargs, ): - super().__init__(dim, n_particles, device, dtype) + super().__init__(dim, n_particles, device, dtype, **kwargs) self.default_inverse_temp = default_inverse_temp self.rng = rng diff --git a/src/nak_torch/algorithms/grad_aldi.py b/src/nak_torch/algorithms/grad_aldi.py index 5dce78c..5b9c933 100644 --- a/src/nak_torch/algorithms/grad_aldi.py +++ b/src/nak_torch/algorithms/grad_aldi.py @@ -52,8 +52,9 @@ def __init__( dtype: Optional[torch.dtype] = None, *_, rng: torch.Generator, + **kwargs, ): - super().__init__(dim, n_particles, device, dtype) + super().__init__(dim, n_particles, device, dtype, **kwargs) self.rng = rng def initialize(self, init_particles, target, target_args): diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 6c69d48..07d0f8c 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -1,4 +1,5 @@ from typing import Any, Optional +import warnings from tqdm import tqdm import numpy as np @@ -28,12 +29,14 @@ def nak( bounds: Optional[tuple[float, float]] = None, keep_all: bool = True, target_args: Any = None, - verbose: bool = False, + **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: r""" TODO: Document """ - + verbose = algorithm.verbose + if verbose: + warnings.warn(f"Discarding kwargs {kwargs}") if n_steps < 0: raise ValueError("Expected positive number of steps.") @@ -70,7 +73,7 @@ def nak( traj_wts[idx + 1].copy_(particle_wts) particles, particle_wts, algorithm_args = algorithm.step( - lr, particles, algorithm_args, target, target_args + lr, particles, target, algorithm_args, target_args ) if bounds is not None: diff --git a/src/nak_torch/algorithms/msip/msip_tools.py b/src/nak_torch/algorithms/msip/msip_tools.py index 8bedb8f..fb05eba 100644 --- a/src/nak_torch/algorithms/msip/msip_tools.py +++ b/src/nak_torch/algorithms/msip/msip_tools.py @@ -55,8 +55,9 @@ def __init__( kernel_lengthscale: Optional[float] = None, kernel_lengthscale_quantile: Optional[float] = None, get_kernel_matrix: Optional[MatSelfKernelFunction] = None, + **kwargs, ): - super().__init__(dim, n_particles, device, dtype) + super().__init__(dim, n_particles, device, dtype, **kwargs) self.kernel_diag_infl = kernel_diag_infl if kernel_lengthscale is None and kernel_lengthscale_quantile is None: raise ValueError( diff --git a/src/nak_torch/algorithms/svgd.py b/src/nak_torch/algorithms/svgd.py index 8246368..18fe578 100644 --- a/src/nak_torch/algorithms/svgd.py +++ b/src/nak_torch/algorithms/svgd.py @@ -75,14 +75,15 @@ def __init__( device: Optional[DeviceLike] = None, dtype: Optional[torch.dtype] = None, *_, - default_kernel_lengthscale: Optional[float] = None, + kernel_lengthscale: Optional[float] = None, kernel_lengthscale_quantile: Optional[float] = None, kernel_elem: Optional[KernelFunction] = None, + **kwargs, ): - super().__init__(dim, n_particles, device, dtype) - if default_kernel_lengthscale is None and kernel_lengthscale_quantile is None: + super().__init__(dim, n_particles, device, dtype, **kwargs) + if kernel_lengthscale is None and kernel_lengthscale_quantile is None: raise ValueError( - "Must provide either default_kernel_lengthscale or kernel_lengthscale_quantile" + "Must provide either kernel_lengthscale or kernel_lengthscale_quantile" ) if kernel_lengthscale_quantile is not None and ( kernel_lengthscale_quantile < 0 or kernel_lengthscale_quantile > 1 @@ -90,11 +91,11 @@ def __init__( raise ValueError( f"Expected kernel_lengthscale_quantile in [0,1], given {kernel_lengthscale_quantile}" ) - if default_kernel_lengthscale is None: - default_kernel_lengthscale = 0.0 if kernel_elem is None: kernel_elem = default_kernel_elem - self.default_kernel_lengthscale = default_kernel_lengthscale + self.default_kernel_lengthscale = ( + 0.0 if kernel_lengthscale is None else kernel_lengthscale + ) self.kernel_lengthscale_quantile = kernel_lengthscale_quantile self.kernel_grad_val = create_svgd_kernel_grad_val(kernel_elem) diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py index 07c282a..4de405d 100644 --- a/src/nak_torch/tools/func.py +++ b/src/nak_torch/tools/func.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Any, Generic, Optional, TypeVar +import warnings import torch from .types import ( BatchPtType, @@ -20,6 +21,7 @@ class GeneralAdaptiveNAKAlgorithm( n_particles: int device: Optional[DeviceLike] dtype: Optional[torch.dtype] + verbose: bool def __init__( self, @@ -27,11 +29,16 @@ def __init__( n_particles: int, device: Optional[DeviceLike], dtype: Optional[torch.dtype], + verbose: bool = True, + **kwargs, ): self.dim = dim self.n_particles = n_particles self.device = device self.dtype = dtype + self.verbose = verbose + if verbose: + warnings.warn(f"Unused kwargs:\n{kwargs}") @abstractmethod def initialize( From 19217ad050cb9c568c443c614b838c6ebfd6984d Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Thu, 23 Apr 2026 17:46:32 -0400 Subject: [PATCH 17/60] Fix Gaussian example --- examples/gaussian.py | 252 ++++++++++++++------------ src/nak_torch/algorithms/loop.py | 21 ++- src/nak_torch/algorithms/msip/msip.py | 2 +- src/nak_torch/tools/func.py | 2 +- src/nak_torch/tools/util.py | 11 +- 5 files changed, 157 insertions(+), 131 deletions(-) diff --git a/examples/gaussian.py b/examples/gaussian.py index 2e3ac4d..3d0236f 100644 --- a/examples/gaussian.py +++ b/examples/gaussian.py @@ -1,5 +1,6 @@ # %% from functools import partial +import math from jaxtyping import Float import matplotlib.pyplot as plt @@ -7,9 +8,13 @@ from torch import Tensor import nak_torch -from nak_torch.algorithms import grad_aldi, eks, gradfree_aldi, cbs, msip, kfrflow +from nak_torch.algorithms import eks, gradfree_aldi, kfrflow +from nak_torch.algorithms import SVGD, MSIP, MSIPGS, GradALDI, CBS +from tqdm import tqdm from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientInformed, MSIPQuadGradientFree from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre +from nak_torch.tools.kernel import sqexp_kernel_elem as kernel_elem, sqexp_kernel_matrix +from nak_torch.tools.types import BatchGradLogDensityEvaluator, BatchLogDensityEvaluator if torch.cuda.is_available(): torch.set_default_device("cuda") @@ -46,11 +51,12 @@ def weighted_cov(pts: Tensor, wts: Tensor): return second_moment - mean.outer(mean) + # %% torch.manual_seed(1023921) obs_op = torch.randn(2, 5) obs_op.div_(obs_op.norm(dim=1, keepdim=True)) -forward_model = torch.compile(lambda particles: particles @ obs_op) +forward_model = torch.compile(lambda particles, _=None: particles @ obs_op) true_obs = torch.tensor([1.0, 2.0, 3.0, 2.0, 1.0]) + 20 model = nak_torch.GaussianModel( @@ -61,23 +67,23 @@ def weighted_cov(pts: Tensor, wts: Tensor): @torch.compile -def like_log_dens(pt): +def like_log_dens(pt, model): ll_term = model.likelihood_precision * \ torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1)**2 return -0.5 * ll_term.squeeze() @torch.compile -def post_log_dens(pt): +def post_log_dens(pt, model): ll_term = model.likelihood_precision * \ torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1)**2 prior_term = model.prior_precision * torch.linalg.norm(pt, dim=-1)**2 return -0.5 * (ll_term + prior_term).squeeze() -post_log_dens_batch = torch.vmap(post_log_dens) +post_log_dens_batch = torch.vmap(post_log_dens, in_dims=(0,None)) post_log_dens_grad_val = torch.func.grad_and_value(post_log_dens) -post_log_dens_grad_val_batch = torch.vmap(post_log_dens_grad_val) +post_log_dens_grad_val_batch = torch.vmap(post_log_dens_grad_val, in_dims=(0,None)) # %% mean_pr, cov_pr = torch.zeros(2), torch.eye(2) / model.prior_precision @@ -95,6 +101,7 @@ def post_log_dens(pt): # %% n_steps, n_particles = 1000, 500 lr = 0.1 +bounds = (-100., 100.) init_particles = torch.randn((n_particles, 2)) / \ model.prior_precision + model.prior_mean @@ -109,56 +116,65 @@ def post_log_dens(pt): # %% # init_particles = torch.randn((n_particles, 2)) + torch.tensor([3, -3]) # delta_ts = torch.ones(1000)/1000 -n_particles_kfr = 100 -init_kfr = init_particles[:n_particles_kfr] #torch.randn((n_particles_kfr,2)) + torch.tensor([3,-5]) -def imq(pt1,pt2,h): - return 1/torch.sqrt(1 + (torch.linalg.norm(pt1-pt2) / h)**2) -trajectories_kfr = kfrflow( - like_log_dens, - n_particles_kfr, - 10000, 2, - init_particles=init_kfr, - kernel_length_scale = 1e-2, - kernel_diag_infl=1e-5, - # bounds=(-10,10), - # kernel_elem=imq, - keep_all=False -) +# n_particles_kfr = 100 +# init_kfr = init_particles[:n_particles_kfr] #torch.randn((n_particles_kfr,2)) + torch.tensor([3,-5]) +# def imq(pt1,pt2,h): +# return 1/torch.sqrt(1 + (torch.linalg.norm(pt1-pt2) / h)**2) +# trajectories_kfr = kfrflow( +# like_log_dens, +# n_particles_kfr, +# 10000, 2, +# init_particles=init_kfr, +# kernel_length_scale = 1e-2, +# kernel_diag_infl=1e-5, +# # bounds=(-10,10), +# # kernel_elem=imq, +# keep_all=False +# ) # %% -from nak_torch.tools.kernel import sqexp_kernel_elem as kernel_elem, sqexp_kernel_matrix -from tqdm import tqdm -kernel_vec = torch.compile(torch.vmap(kernel_elem, in_dims=(None,0,None))) -jac_kernel_vec = torch.vmap(torch.func.grad(kernel_elem), in_dims = (None, 0, None)) -kernel_mat = sqexp_kernel_matrix -n_steps_kfr = 100 -delta_t = 1 / n_steps_kfr -particles = init_particles.clone() -kernel_length_scale = 1e-2 -grad_ks = torch.empty((n_particles, n_particles, 2)) -M_t = torch.empty((n_particles, n_particles)) -for n in tqdm(range(n_steps_kfr)): - log_likely_evals = like_log_dens(particles) - M_t.zero_() - for i in range(n_particles): - grad_K = jac_kernel_vec(particles[i], particles, kernel_length_scale) - grad_ks[i].copy_(grad_K) - M_t.add_(grad_K @ grad_K.T) - M_t = M_t.div_(n_particles) - M_t[torch.arange(n_particles), torch.arange(n_particles)] += 1e-4 - wts_shift = log_likely_evals.mean() - wts = log_likely_evals.sub_(wts_shift).div_(n_particles) - K_mat = kernel_mat(particles, kernel_length_scale) - kernelized_wts = K_mat @ wts - particles += torch.einsum("jid,i->jd", grad_ks, torch.linalg.solve(M_t, kernelized_wts)).mul_(delta_t) +# kernel_vec = torch.compile(torch.vmap(kernel_elem, in_dims=(None,0,None))) +# jac_kernel_vec = torch.vmap(torch.func.grad(kernel_elem), in_dims = (None, 0, None)) +# kernel_mat = sqexp_kernel_matrix +# n_steps_kfr = 100 +# delta_t = 1 / n_steps_kfr +# particles = init_particles.clone() +# kernel_length_scale = 1e-2 +# grad_ks = torch.empty((n_particles, n_particles, 2)) +# M_t = torch.empty((n_particles, n_particles)) +# for n in tqdm(range(n_steps_kfr)): +# log_likely_evals = like_log_dens(particles, model) +# M_t.zero_() +# for i in range(n_particles): +# grad_K = jac_kernel_vec(particles[i], particles, kernel_length_scale) +# grad_ks[i].copy_(grad_K) +# M_t.add_(grad_K @ grad_K.T) +# M_t = M_t.div_(n_particles) +# M_t[torch.arange(n_particles), torch.arange(n_particles)] += 1e-4 +# wts_shift = log_likely_evals.mean() +# wts = log_likely_evals.sub_(wts_shift).div_(n_particles) +# K_mat = kernel_mat(particles, kernel_length_scale) +# kernelized_wts = K_mat @ wts +# particles += torch.einsum("jid,i->jd", grad_ks, torch.linalg.solve(M_t, kernelized_wts)).mul_(delta_t) # %% -trajectories_galdi = grad_aldi( - post_log_dens, n_particles, n_steps, dim=2, - lr=lr, init_particles=init_particles, - keep_all=False +rng = torch.Generator() +rng.manual_seed(0) + +# %% +grad_aldi = GradALDI(dim=2, n_particles=n_particles, rng = rng) +grad_aldi_target = BatchGradLogDensityEvaluator(post_log_dens, is_grad=False, is_batched=True) +trajectories_galdi = nak_torch.nak(grad_aldi_target, grad_aldi, + n_steps=n_steps, lr=lr, + init_particles=init_particles, keep_all=False, + rng_or_seed=rng, target_args=model, bounds=bounds ) +# trajectories_galdi = grad_aldi( +# post_log_dens, n_particles, n_steps, dim=2, +# lr=lr, init_particles=init_particles, +# keep_all=False +# ) # %% trajectories_gfaldi = gradfree_aldi( @@ -168,102 +184,99 @@ def imq(pt1,pt2,h): ) # %% -trajectories_cbs = cbs( - post_log_dens, n_particles, n_steps, inverse_temp=0.95, dim=2, - lr=lr, init_particles=init_particles, - keep_all=True +cbs_target = BatchLogDensityEvaluator(post_log_dens, is_batched=True) +cbs = CBS(dim=2, n_particles=n_particles, default_inverse_temp=0.95, rng=rng) +trajectories_cbs = nak_torch.nak( + cbs_target, cbs, n_steps, lr, + rng_or_seed=rng, init_particles=init_particles, + target_args=model, bounds = bounds ) +# trajectories_cbs = cbs( +# post_log_dens, n_particles, n_steps, inverse_temp=0.95, dim=2, +# lr=lr, init_particles=init_particles, +# keep_all=True +# ) # %% -kernel_length_scale = 0.25 -bounds = (-100., 100.) -gradient_decay = 1.0 -n_particles_msip = 5 +kernel_lengthscale = 0.1 +gradient_decay = 0.95 +n_particles_msip = 10 n_steps_msip = 1000 -lr_msip = 1e-2 -kernel_diag_infl = 1e-8 -msip_fredholm = MSIPFredholm( +lr_msip = 5e-2 +kernel_diag_infl = 1e-5 +msip = MSIP( + dim=2, + n_particles=n_particles_msip, + kernel_diag_infl=kernel_diag_infl, + kernel_lengthscale=kernel_lengthscale +) + +msip_fredholm_target = MSIPFredholm( gradient_decay, post_log_dens_grad_val_batch ) -trajectories_msip, traj_wts_msip = msip( - msip_fredholm, n_particles_msip, n_steps_msip, dim=2, - lr=lr_msip, init_particles=init_particles[:n_particles_msip], - kernel_length_scale=kernel_length_scale, - is_log_density_batched=True, - kernel_diag_infl=kernel_diag_infl, - bounds=bounds, - gradient_decay=gradient_decay, - keep_all=True +trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( + msip_fredholm_target, msip, n_steps_msip, lr_msip, + rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], + target_args=model, keep_all=True, bounds=bounds ) # %% - - def mc_quad_rule(batch_size: int, N_quad: int = 5, dim: int = 2): - pts = torch.randn((batch_size, N_quad, dim)) + pts = torch.randn((batch_size, N_quad, dim), generator=rng) wts = torch.ones((batch_size, N_quad)).div_(N_quad) return pts, wts def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): pts, wts = spherical_MC_radial_Laguerre( - batch_size, N_spherical, 2, N_radial) + batch_size, N_spherical, 2, N_radial + ) return pts, wts # %% -# kernel_length_scale = 1e-3 -# gradient_decay = 1. -msip_quadgrad = MSIPQuadGradientInformed( +msip_quadgrad_target = MSIPQuadGradientInformed( post_log_dens_grad_val_batch, mc_quad_rule, gradient_decay ) -trajectories_msip_qg, traj_wts_msip_qg = msip( - msip_quadgrad, n_particles_msip, 100, dim=2, - lr=10., init_particles=init_particles[:n_particles_msip], - kernel_length_scale=kernel_length_scale, - # is_log_density_batched=True, - kernel_diag_infl=1e-8, - bounds=(-1000, 1000), - # gradient_decay=gradient_decay, - keep_all=False + +trajectories_pts_msip_qg, trajectories_wts_msip_qg = nak_torch.nak( + msip_quadgrad_target, msip, n_steps_msip, lr_msip, + rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, + keep_all=False, bounds=bounds ) # %% -# n_particles_msip = 500 -# kernel_length_scale = 1e-2 -msip_quadgf = MSIPQuadGradientFree( +msip_quadgf_target = MSIPQuadGradientFree( post_log_dens_batch, partial(mc_quad_rule, N_quad=100) ) -trajectories_msip_qgf, traj_wts_msip_qgf = msip( - msip_quadgf, n_particles_msip, 500, dim=2, - lr=1., init_particles=init_particles[:n_particles_msip], - kernel_length_scale=kernel_length_scale, - kernel_diag_infl=1e-8, - bounds=(-1000., 1000.), - keep_all=False +trajectories_pts_msip_qgf, trajectories_wts_msip_qgf = nak_torch.nak( + msip_quadgf_target, msip, n_steps, lr_msip, + rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, + keep_all=False, bounds=bounds ) # %% -pts_eks = trajectories_eks[-1] -pts_kfr = particles -pts_galdi = trajectories_galdi[-1] -pts_gfaldi = trajectories_gfaldi[-1] -pts_cbs = trajectories_cbs[-1] -idx_msip = 100 -pts_msip = trajectories_msip[idx_msip] -wts_msip = traj_wts_msip[idx_msip] +# pts_eks = trajectories_eks[-1] +# pts_kfr = particles +# pts_galdi = trajectories_galdi[-1] +# pts_gfaldi = trajectories_gfaldi[-1] +# pts_cbs = trajectories_cbs[-1] +idx_msip = -1 +alpha_msip = 2/math.sqrt(n_particles_msip) +pts_msip = trajectories_pts_msip_fr[idx_msip] +wts_msip = trajectories_wts_msip_fr[idx_msip] # wts_msip /= wts_msip.sum() -pts_msip_qg = trajectories_msip_qg[-1] -wts_msip_qg = traj_wts_msip_qg[-1] -wts_msip_qg = wts_msip_qg/wts_msip_qg.sum() -pts_msip_qgf = trajectories_msip_qgf[-1] -wts_msip_qgf = traj_wts_msip_qgf[-1] +pts_msip_qg = trajectories_pts_msip_qg[-1] +wts_msip_qg = trajectories_wts_msip_qg[-1] +# wts_msip_qg = wts_msip_qg/wts_msip_qg.sum() +pts_msip_qgf = trajectories_pts_msip_qgf[-1] +wts_msip_qgf = trajectories_wts_msip_qgf[-1] # wts_msip_qgf = wts_msip_qgf/wts_msip_qgf.sum() Ngrid = 100 @@ -275,23 +288,24 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): grid_pts = torch.stack((X.flatten(), Y.flatten()), 1) fig, ax = plt.subplots() -ax.contour(X, Y, post_log_dens(grid_pts).reshape(Ngrid, Ngrid), levels=10) +ax.contour(X, Y, post_log_dens(grid_pts, model).reshape(Ngrid, Ngrid), levels=10) +handles = [] # ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") # ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") # ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], # alpha=0.2, label="GradFree-ALDI") -ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") +# ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") # ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") # ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -# s = ax.scatter(pts_msip[:, 0], pts_msip[:, 1], - # c=wts_msip, alpha=0.15, label="MSIP") -# s = ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], -# c = wts_msip_qg, alpha=0.15, label="MSIP-QuadGrad") -# s = ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], -# c = wts_msip_qgf, alpha=0.15, label="MSIP-QuadGradFree") +handles.append(ax.scatter(pts_msip[:, 0], pts_msip[:, 1], alpha=alpha_msip, label="MSIP")) +handles.append(ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad")) +handles.append(ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], + s = 50*wts_msip_qgf/wts_msip_qgf.max(), alpha=alpha_msip, label="MSIP-QuadGradFree")) # plt.colorbar(s) -ax.set_aspect(1.0) -ax.legend() +# ax.set_aspect(1.0) +ax.legend(handles = handles) +# ax.set_xlim(xgrid.min(), xgrid.max()) +# ax.set_ylim(ygrid.min(), ygrid.max()) plt.show() # %% diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 07d0f8c..5cc4e32 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -21,10 +21,9 @@ def nak( target: BatchTargetEvaluator, algorithm: GeneralAdaptiveNAKAlgorithm, - n_particles: int, n_steps: int, lr: float, - seed: Optional[int] = None, + rng_or_seed: Optional[int | torch.Generator] = None, init_particles: Optional[Tensor | np.ndarray] = None, bounds: Optional[tuple[float, float]] = None, keep_all: bool = True, @@ -34,18 +33,24 @@ def nak( r""" TODO: Document """ - verbose = algorithm.verbose - if verbose: + verbose, n_particles = algorithm.verbose, algorithm.n_particles + if verbose and len(kwargs) > 0: warnings.warn(f"Discarding kwargs {kwargs}") if n_steps < 0: raise ValueError("Expected positive number of steps.") - if seed is not None: - torch.manual_seed(seed) - dim, device, dtype = algorithm.dim, algorithm.device, algorithm.dtype + rng: torch.Generator + if isinstance(rng_or_seed, int): + rng = torch.Generator(device) + rng.manual_seed(rng_or_seed) + elif rng_or_seed is not None: + rng = rng_or_seed + else: + rng = torch.default_generator + particles = initialize_particles( - n_particles, dim, init_particles, device, dtype, bounds + n_particles, dim, init_particles, device, dtype, bounds, rng=rng ) particle_wts, algorithm_args = algorithm.initialize(particles, target, target_args) diff --git a/src/nak_torch/algorithms/msip/msip.py b/src/nak_torch/algorithms/msip/msip.py index 5cd85c5..84d07e6 100644 --- a/src/nak_torch/algorithms/msip/msip.py +++ b/src/nak_torch/algorithms/msip/msip.py @@ -34,7 +34,7 @@ def step(self, lr, particles, target, algorithm_args, target_args): # Update the parameters kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) kernel_matrix = self.get_infl_kernel_matrix(new_particles, kernel_lengthscale) - msip_estimator_output = target(particles, kernel_lengthscale, target_args) + msip_estimator_output = target(new_particles, kernel_lengthscale, target_args) algorithm_args = MSIPAlgorithmArgs( kernel_lengthscale, kernel_matrix, msip_estimator_output ) diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py index 4de405d..def9b43 100644 --- a/src/nak_torch/tools/func.py +++ b/src/nak_torch/tools/func.py @@ -37,7 +37,7 @@ def __init__( self.device = device self.dtype = dtype self.verbose = verbose - if verbose: + if verbose and len(kwargs) > 0: warnings.warn(f"Unused kwargs:\n{kwargs}") @abstractmethod diff --git a/src/nak_torch/tools/util.py b/src/nak_torch/tools/util.py index f762cd9..00cd3d3 100644 --- a/src/nak_torch/tools/util.py +++ b/src/nak_torch/tools/util.py @@ -29,10 +29,17 @@ def initialize_particles( bounds: Optional[tuple[float, float]], rng: Optional[torch.Generator] = None, ) -> BatchPtType: - rng = torch.default_generator + if rng is None: + rng = torch.default_generator + if device is None: + device = torch.get_default_device() + elif not isinstance(device, torch.device): + device = torch.device(device) + if rng.device != device: + raise ValueError(f"Expected rng to be on device {device}. Got {rng.device}") if init_particles is None: if bounds is None: - return torch.randn((n_particles, dim), device=device) + return torch.randn((n_particles, dim), device=device, generator=rng) else: return torch.empty((n_particles, dim), device=device).uniform_( *bounds, generator=rng From 545b0ae2ff42c38764d12e86a76ad0264e4c24e3 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 14:18:06 -0400 Subject: [PATCH 18/60] Working GF-ALDI --- examples/gaussian.py | 82 +++---- examples/himmelblau.py | 10 +- src/nak_torch/algorithms/__init__.py | 5 +- src/nak_torch/algorithms/cbs.py | 2 +- src/nak_torch/algorithms/grad_aldi.py | 2 +- src/nak_torch/algorithms/gradfree_aldi.py | 244 +++++++++++--------- src/nak_torch/algorithms/loop.py | 14 +- src/nak_torch/algorithms/msip/estimators.py | 7 +- src/nak_torch/algorithms/msip/msip.py | 4 +- src/nak_torch/algorithms/svgd.py | 2 +- src/nak_torch/tools/func.py | 16 +- src/nak_torch/tools/metrics.py | 2 + src/nak_torch/tools/types.py | 16 +- 13 files changed, 220 insertions(+), 186 deletions(-) diff --git a/examples/gaussian.py b/examples/gaussian.py index 3d0236f..5039b24 100644 --- a/examples/gaussian.py +++ b/examples/gaussian.py @@ -8,8 +8,8 @@ from torch import Tensor import nak_torch -from nak_torch.algorithms import eks, gradfree_aldi, kfrflow -from nak_torch.algorithms import SVGD, MSIP, MSIPGS, GradALDI, CBS +from nak_torch.algorithms import eks, kfrflow +from nak_torch.algorithms import SVGD, MSIP, MSIPGS, GradALDI, CBS, GradFreeALDI from tqdm import tqdm from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientInformed, MSIPQuadGradientFree from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre @@ -24,8 +24,6 @@ torch.set_default_dtype(torch.float64) # %% - - def make_gaussian_post( forward_op: Float[Tensor, "obs dim"], mean_pr: Float[Tensor, " dim"], @@ -73,17 +71,20 @@ def like_log_dens(pt, model): return -0.5 * ll_term.squeeze() -@torch.compile -def post_log_dens(pt, model): - ll_term = model.likelihood_precision * \ - torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1)**2 - prior_term = model.prior_precision * torch.linalg.norm(pt, dim=-1)**2 - return -0.5 * (ll_term + prior_term).squeeze() - +# @torch.compile +# def post_log_dens(pt, model): +# ll_term = model.likelihood_precision * \ +# torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1)**2 +# prior_term = model.prior_precision * torch.linalg.norm(pt, dim=-1)**2 +# return -0.5 * (ll_term + prior_term).squeeze() -post_log_dens_batch = torch.vmap(post_log_dens, in_dims=(0,None)) -post_log_dens_grad_val = torch.func.grad_and_value(post_log_dens) -post_log_dens_grad_val_batch = torch.vmap(post_log_dens_grad_val, in_dims=(0,None)) +# post_log_dens_batch = torch.vmap(post_log_dens, in_dims=(0,None)) +# post_log_dens_grad_val = torch.func.grad_and_value(post_log_dens) +post_log_dens = model.to_log_dens() +def _tmp_post_log_dens(pts, args): + out = post_log_dens(pts,args) + return out.sum(), out +post_log_dens_grad_val_batch = torch.func.grad(_tmp_post_log_dens, has_aux=True) # %% mean_pr, cov_pr = torch.zeros(2), torch.eye(2) / model.prior_precision @@ -168,19 +169,20 @@ def post_log_dens(pt, model): trajectories_galdi = nak_torch.nak(grad_aldi_target, grad_aldi, n_steps=n_steps, lr=lr, init_particles=init_particles, keep_all=False, - rng_or_seed=rng, target_args=model, bounds=bounds + rng_or_seed=rng, target_args=None, bounds=bounds ) -# trajectories_galdi = grad_aldi( -# post_log_dens, n_particles, n_steps, dim=2, -# lr=lr, init_particles=init_particles, -# keep_all=False -# ) # %% -trajectories_gfaldi = gradfree_aldi( - model, n_particles, n_steps, dim=2, - lr=lr, init_particles=init_particles, - keep_all=True +gf_aldi = GradFreeALDI(dim=2, n_particles=n_particles) +# trajectories_gfaldi = gradfree_aldi( +# model, n_particles, n_steps, dim=2, +# lr=lr, init_particles=init_particles, +# keep_all=True +# ) +trajectories_gfaldi = nak_torch.nak(model, gf_aldi, + n_steps=n_steps, lr=1e-2, + init_particles=init_particles, keep_all=True, + rng_or_seed=rng, target_args=None, bounds=bounds ) # %% @@ -189,7 +191,7 @@ def post_log_dens(pt, model): trajectories_cbs = nak_torch.nak( cbs_target, cbs, n_steps, lr, rng_or_seed=rng, init_particles=init_particles, - target_args=model, bounds = bounds + target_args=None, bounds = bounds ) # trajectories_cbs = cbs( # post_log_dens, n_particles, n_steps, inverse_temp=0.95, dim=2, @@ -198,17 +200,18 @@ def post_log_dens(pt, model): # ) # %% -kernel_lengthscale = 0.1 +kernel_lengthscale = 0.15 gradient_decay = 0.95 -n_particles_msip = 10 +n_particles_msip = 25 n_steps_msip = 1000 -lr_msip = 5e-2 +lr_msip = 5e-3 kernel_diag_infl = 1e-5 msip = MSIP( dim=2, n_particles=n_particles_msip, kernel_diag_infl=kernel_diag_infl, - kernel_lengthscale=kernel_lengthscale + kernel_lengthscale=kernel_lengthscale, + # kernel_lengthscale_quantile=0.25 ) msip_fredholm_target = MSIPFredholm( @@ -216,6 +219,7 @@ def post_log_dens(pt, model): post_log_dens_grad_val_batch ) +# %% trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( msip_fredholm_target, msip, n_steps_msip, lr_msip, rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], @@ -251,11 +255,11 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): # %% msip_quadgf_target = MSIPQuadGradientFree( - post_log_dens_batch, partial(mc_quad_rule, N_quad=100) + post_log_dens, partial(spherical_quad, N_spherical=5, N_radial=4) ) trajectories_pts_msip_qgf, trajectories_wts_msip_qgf = nak_torch.nak( - msip_quadgf_target, msip, n_steps, lr_msip, + msip_quadgf_target, msip, 100, 8e-1, rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, keep_all=False, bounds=bounds ) @@ -265,7 +269,7 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): # pts_eks = trajectories_eks[-1] # pts_kfr = particles # pts_galdi = trajectories_galdi[-1] -# pts_gfaldi = trajectories_gfaldi[-1] +pts_gfaldi = trajectories_gfaldi[-1] # pts_cbs = trajectories_cbs[-1] idx_msip = -1 alpha_msip = 2/math.sqrt(n_particles_msip) @@ -292,18 +296,18 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): handles = [] # ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") # ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -# ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], -# alpha=0.2, label="GradFree-ALDI") +ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], + alpha=0.2, label="GradFree-ALDI") # ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") # ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") # ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -handles.append(ax.scatter(pts_msip[:, 0], pts_msip[:, 1], alpha=alpha_msip, label="MSIP")) -handles.append(ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad")) -handles.append(ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], - s = 50*wts_msip_qgf/wts_msip_qgf.max(), alpha=alpha_msip, label="MSIP-QuadGradFree")) +# ax.scatter(pts_msip[:, 0], pts_msip[:, 1], alpha=alpha_msip, label="MSIP") +# handles.append(ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad")) +# ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], +# s = 50*wts_msip_qgf.abs()/wts_msip_qgf.max(), alpha=alpha_msip, label="MSIP-QuadGradFree") # plt.colorbar(s) # ax.set_aspect(1.0) -ax.legend(handles = handles) +ax.legend() # ax.set_xlim(xgrid.min(), xgrid.max()) # ax.set_ylim(ygrid.min(), ygrid.max()) plt.show() diff --git a/examples/himmelblau.py b/examples/himmelblau.py index 6c1d726..3f194f4 100644 --- a/examples/himmelblau.py +++ b/examples/himmelblau.py @@ -26,11 +26,11 @@ params = { "n_steps": 100, "bounds": (-15, 15), - "kernel_lengthscale": 0.18, + "kernel_lengthscale": 0.15, "init_particles": init_particles, "n_particles": n_particles, "dim": 2, - "lr": 0.8, + "lr": 0.6, "kernel_diag_infl": 1e-5, "verbose": False } @@ -102,7 +102,7 @@ # %% batch_log_dens = torch.vmap(log_density) -batch_grad_log_dens = torch.vmap(torch.func.grad(log_density)) +batch_grad_log_dens = torch.vmap(torch.func.grad(log_density), in_dims=(0,None)) def kernel_elem(x: torch.Tensor, y: torch.Tensor, sigma: float): return torch.reciprocal(1 + (x - y).div(sigma).square().sum()) ksd_eval = nak_torch.metrics.KernelSteinDiscrepancy(batch_grad_log_dens, 0.25, kernel_elem=kernel_elem) @@ -130,8 +130,8 @@ def kernel_elem(x: torch.Tensor, y: torch.Tensor, sigma: float): title_weights = [None, None, 'heavy', 'heavy'] for (ax, title, pt, wt, title_wt) in zip(axs, titles, pt_list, wt_list, title_weights): ax.set_axis_off() - ax.set_xlim(g_min, g_max) - ax.set_ylim(g_min, 1.05*g_max) + # ax.set_xlim(g_min, g_max) + # ax.set_ylim(g_min, 1.05*g_max) ax.set_title(title, fontweight=title_wt, size=20) ax.contourf(X[:,40:],Y[:,40:],Z[:,40:], levels=20, cmap="Grays") s = 25 * (1. if wt is None else ((wt.abs()/wt.abs().max())).sqrt()) diff --git a/src/nak_torch/algorithms/__init__.py b/src/nak_torch/algorithms/__init__.py index 13eb1a8..f2eccde 100644 --- a/src/nak_torch/algorithms/__init__.py +++ b/src/nak_torch/algorithms/__init__.py @@ -12,7 +12,7 @@ from .svgd import SVGD from .deepensembles import deepensembles from .grad_aldi import GradALDI -from .gradfree_aldi import gradfree_aldi +from .gradfree_aldi import GradFreeALDI from .cbs import CBS from .kfrflow import kfrflow from .loop import nak @@ -25,7 +25,8 @@ "SVGD", "deepensembles", "GradALDI", - "gradfree_aldi", + # "gradfree_aldi", + "GradFreeALDI", "eks", "CBS", "kfrflow", diff --git a/src/nak_torch/algorithms/cbs.py b/src/nak_torch/algorithms/cbs.py index 8c777cf..072beaa 100644 --- a/src/nak_torch/algorithms/cbs.py +++ b/src/nak_torch/algorithms/cbs.py @@ -69,7 +69,7 @@ def initialize(self, init_particles, target, target_args): def step(self, lr, particles, target, algorithm_args, target_args): inverse_temp, motion_scaling_sq_div_lr = astuple(algorithm_args) motion_scaling_sq = motion_scaling_sq_div_lr * lr - log_dens_eval = target(particles, None, target_args) + log_dens_eval = target(particles, target_args) particles_diff, particles_noise = cbs_step( particles, log_dens_eval, inverse_temp, motion_scaling_sq, self.rng ) diff --git a/src/nak_torch/algorithms/grad_aldi.py b/src/nak_torch/algorithms/grad_aldi.py index 5b9c933..8e991ba 100644 --- a/src/nak_torch/algorithms/grad_aldi.py +++ b/src/nak_torch/algorithms/grad_aldi.py @@ -61,7 +61,7 @@ def initialize(self, init_particles, target, target_args): return None, None def step(self, lr, particles, target, algorithm_args, target_args): - grad_log_dens_evals = target(particles, None, target_args) + grad_log_dens_evals = target(particles, target_args) particles_diff, particles_noise = grad_aldi_step( particles, grad_log_dens_evals, self.rng ) diff --git a/src/nak_torch/algorithms/gradfree_aldi.py b/src/nak_torch/algorithms/gradfree_aldi.py index ae7b3ff..2ec5e37 100644 --- a/src/nak_torch/algorithms/gradfree_aldi.py +++ b/src/nak_torch/algorithms/gradfree_aldi.py @@ -1,118 +1,140 @@ +from dataclasses import astuple + import torch -from typing import Optional +from typing import Any, Optional from jaxtyping import Float from torch import Tensor -from nak_torch.tools.types import BatchPtType, GaussianModel -import warnings -from tqdm import tqdm -import numpy as np -from nak_torch.tools.util import initialize_particles, sym_sqrtm - - -def build_gradfree_aldi_step( - model: GaussianModel, rng: torch.Generator, compile_step: bool -): - prior_mean = model.prior_mean - likelihood_precision = model.likelihood_precision - prior_precision = model.prior_precision - true_obs = model.true_obs - if isinstance(true_obs, Tensor): - true_obs.reshape(1, -1) - - sqrt_2 = torch.sqrt(torch.tensor(2, dtype=true_obs.dtype, device=true_obs.device)) - - def gradfree_aldi_step( - particles: BatchPtType, forecast_observations: Float[Tensor, "batch obs"] - ) -> tuple[BatchPtType, Float[Tensor, "dim dim"]]: - N_batch, dim = particles.shape - particle_mean = particles.mean(0, True) - forecast_obs_mean = forecast_observations.mean(0, True) - prior_err = particles - if prior_mean != 0.0: - prior_err -= prior_mean - obs_error = forecast_observations - true_obs - obs_deviation = forecast_observations - forecast_obs_mean - forecast_deviation = particles - particle_mean - cov_forecast = (forecast_deviation.T @ forecast_deviation) / N_batch - cov_obs_forecast = (obs_deviation.T @ forecast_deviation) / N_batch - - if isinstance(likelihood_precision, float): - likely_term = obs_error @ cov_obs_forecast - likely_term.mul_(likelihood_precision) - else: - likely_term = torch.chain_matmul( - obs_error, likelihood_precision, cov_obs_forecast - ) +from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm +from nak_torch.tools.types import ( + BatchPtType, + CovType, + DeviceLike, + GaussianModel, + PtType, +) +from nak_torch.tools.util import sym_sqrtm + + +# def build_gradfree_aldi_step( +# model: GaussianModel, rng: torch.Generator, compile_step: bool +# ): +# prior_mean = model.prior_mean +# likelihood_precision = model.likelihood_precision +# prior_precision = model.prior_precision +# true_obs = model.true_obs +# if isinstance(true_obs, Tensor): +# true_obs.reshape(1, -1) + +# sqrt_2 = torch.sqrt(torch.tensor(2, dtype=true_obs.dtype, device=true_obs.device)) + + +def gradfree_aldi_step( + particles: BatchPtType, + forecast_observations: Float[Tensor, "batch obs"], + prior_mean: PtType, + likelihood_precision: CovType, + prior_precision: CovType, + true_observation: Float[Tensor, " obs"], + rng: torch.Generator, +) -> tuple[BatchPtType, Float[Tensor, "dim dim"]]: + + N_batch, dim = particles.shape + particle_mean = particles.mean(dim=0, keepdim=True) + forecast_obs_mean = forecast_observations.mean(dim=0, keepdim=True) + prior_err = particles + if prior_mean != 0.0: + prior_err -= prior_mean + obs_error = forecast_observations - true_observation + obs_deviation = forecast_observations - forecast_obs_mean + forecast_deviation = particles - particle_mean + cov_forecast = (forecast_deviation.T @ forecast_deviation) / N_batch + cov_obs_forecast = (obs_deviation.T @ forecast_deviation) / N_batch + + if isinstance(likelihood_precision, float): + likely_term = obs_error @ cov_obs_forecast + likely_term.mul_(likelihood_precision) + else: + likely_term = torch.chain_matmul( + obs_error, likelihood_precision, cov_obs_forecast + ) - sqrt_cov_forecast = sym_sqrtm(cov_forecast) - sqrt_cov_forecast.mul_(sqrt_2) + sqrt_cov_forecast = sym_sqrtm(cov_forecast) - if isinstance(prior_precision, float): - prior_term1 = prior_err @ cov_forecast - prior_term1.mul_(prior_precision) - else: - prior_term1 = torch.chain_matmul(cov_forecast, prior_precision, prior_err) - - prior_term2 = forecast_deviation.mul_((dim + 1) / N_batch) - particle_diff = prior_term2.sub_(prior_term1).sub_(likely_term) - noise = torch.normal(0.0, 1.0, particles.shape, generator=rng, out=prior_err) - motion = torch.matmul(noise, sqrt_cov_forecast, out=prior_term1) - - return particle_diff, motion - - return torch.compile(gradfree_aldi_step) if compile_step else gradfree_aldi_step - - -def gradfree_aldi( - model: GaussianModel, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - rng: Optional[torch.Generator] = None, - keep_all: bool = True, - verbose: bool = False, - compile_step: bool = True, - **unused_kwargs, -): - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs:\n{}".format(unused_kwargs)) - - if rng is None: - rng = torch.default_generator - if seed is not None: - rng.manual_seed(seed) - - particles = initialize_particles( - n_particles, dim, init_particles, device, bounds, rng - ) - - if keep_all: - trajectories = torch.empty( - (n_steps, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) + if isinstance(prior_precision, float): + prior_term1 = prior_err @ cov_forecast + prior_term1.mul_(prior_precision) else: - trajectories = torch.empty(()) - gradfree_aldi_step = build_gradfree_aldi_step(model, rng, compile_step) - sqrt_lr = torch.sqrt(torch.tensor(lr)) - - for idx in tqdm(range(n_steps), disable=not verbose): - forecast_observations = model.forward_model(particles) - with torch.no_grad(): - particles_diff, particles_noise = gradfree_aldi_step( - particles, forecast_observations + prior_term1 = torch.chain_matmul(cov_forecast, prior_precision, prior_err) + + prior_term2 = forecast_deviation.mul_((dim + 1) / N_batch) + particle_diff = prior_term2.sub_(prior_term1).sub_(likely_term) + noise = torch.normal(0.0, 1.0, particles.shape, generator=rng) + motion = torch.matmul(noise, sqrt_cov_forecast) + + return particle_diff, motion + + +class GradFreeALDI(UnweightedAdaptiveNAKAlgorithm[GaussianModel, None]): + rng: torch.Generator + + def sqrt_scalar(self, scalar: float) -> Float: + return torch.as_tensor(scalar, device=self.device, dtype=self.dtype).sqrt() + + def __init__( + self, + dim: int, + n_particles: int, + device: Optional[DeviceLike] = None, + dtype: Optional[torch.dtype] = None, + *_, + rng_or_seed: Optional[torch.Generator | int] = None, + **kwargs, + ): + super().__init__(dim, n_particles, device, dtype, **kwargs) + if isinstance(rng_or_seed, int): + self.rng = torch.Generator(self.device).set_state( + torch.default_generator.get_state() + ) + self.rng.manual_seed(rng_or_seed) + elif rng_or_seed is None: + self.rng = torch.Generator(self.device).set_state( + torch.default_generator.get_state() ) - particles_diff.mul_(lr) - particles_noise.mul_(sqrt_lr) - particles.add_(particles_diff).add_(particles_noise) - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx].copy_(particles) - - return trajectories.detach() if keep_all else particles.unsqueeze_(0) + else: + self.rng = rng_or_seed + if self.rng.device != self.device: + raise ValueError( + f"Expected rng to live on device {self.device}, got {self.rng.device}" + ) + + def initialize( + self, init_particles: Tensor, target: GaussianModel, target_args: Any + ) -> tuple[None, None]: + return None, None + + def step( + self, + lr: float, + particles: Tensor, + target: GaussianModel, + algorithm_args: None, + target_args: Any, + ) -> tuple[Tensor, None, None]: + forward_model, likelihood_precision, prior_precision, true_obs, prior_mean = ( + astuple(target) + ) + forecast_observations = forward_model(particles, target_args) + particles_diff, particles_noise = gradfree_aldi_step( + particles, + forecast_observations, + prior_mean, + likelihood_precision, + prior_precision, + true_obs, + self.rng, + ) + sqrt_lr = self.sqrt_scalar(2 * lr) + new_particles = ( + particles_diff.mul_(lr).add_(particles).add_(particles_noise.mul_(sqrt_lr)) + ) + return new_particles, None, None diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 5cc4e32..54049f6 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -8,7 +8,7 @@ from nak_torch.tools.util import initialize_particles from nak_torch.tools.types import ( - BatchTargetEvaluator, + NAKTarget, ) from nak_torch.tools.func import ( @@ -19,7 +19,7 @@ def nak( - target: BatchTargetEvaluator, + target: NAKTarget, algorithm: GeneralAdaptiveNAKAlgorithm, n_steps: int, lr: float, @@ -71,10 +71,10 @@ def nak( trajectories = torch.empty(()) traj_wts = torch.empty(()) - for idx in tqdm(range(n_steps), disable=not verbose): + for idx in tqdm(range(n_steps - 1), disable=not verbose): if keep_all: trajectories[idx + 1].copy_(particles) - if algorithm.is_weighted() and keep_all: + if algorithm.is_weighted(): traj_wts[idx + 1].copy_(particle_wts) particles, particle_wts, algorithm_args = algorithm.step( @@ -84,7 +84,11 @@ def nak( if bounds is not None: particles.clamp_(bounds[0], bounds[1]) - if not keep_all: + if keep_all: + trajectories[-1].copy_(particles) + if algorithm.is_weighted(): + traj_wts[-1].copy_(particle_wts) + else: trajectories = particles.unsqueeze_(0) if algorithm.is_weighted(): traj_wts = particle_wts.unsqueeze_(0) diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index 8ded7c6..21864d9 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -6,7 +6,7 @@ BatchLogDensityGradVal, BatchLogDensity, BatchQuadratureRule, - BatchTargetEvaluator, + NAKTarget, ) from jaxtyping import Float from torch import Tensor @@ -14,7 +14,7 @@ __all__ = ["MSIPFredholm", "MSIPQuadGradientFree", "MSIPQuadGradientInformed"] -class MSIPEstimator(BatchTargetEvaluator[MSIPEstimatorOutput]): +class MSIPEstimator(NAKTarget[MSIPEstimatorOutput]): @abstractmethod def __call__(self, particles, evaluator_args, target_args) -> MSIPEstimatorOutput: r""" @@ -24,7 +24,7 @@ def __call__(self, particles, evaluator_args, target_args) -> MSIPEstimatorOutpu \sigma^2 \nabla \log v_0(y) = \frac{v_1(y)}{v_0(y)} - y. $$ """ - pass + ... class MSIPFredholm(MSIPEstimator): @@ -71,6 +71,7 @@ def __call__(self, particles, kernel_length_scale, target_args): log_dens_evals = self.log_dens( particle_quad_pts.reshape(-1, dim), target_args ).reshape(n_particles, -1) + sigma_sq_score_v0, log_v0 = vmap_recursive_weighted_average_alpha_v( quad_pts, quad_wts, log_dens_evals ) diff --git a/src/nak_torch/algorithms/msip/msip.py b/src/nak_torch/algorithms/msip/msip.py index 84d07e6..6b7f992 100644 --- a/src/nak_torch/algorithms/msip/msip.py +++ b/src/nak_torch/algorithms/msip/msip.py @@ -20,7 +20,7 @@ def initialize(self, init_particles, target, target_args): def step(self, lr, particles, target, algorithm_args, target_args): kernel_lengthscale, kernel_matrix, estimator_output = astuple(algorithm_args) - kernel_matrix_inverse = torch.linalg.pinv(kernel_matrix) + kernel_matrix_inverse = torch.linalg.pinv(kernel_matrix, hermitian=True) # Update the particles particles_diff = msip_map( @@ -29,7 +29,7 @@ def step(self, lr, particles, target, algorithm_args, target_args): kernel_matrix_inverse, output_idx=None, ) - new_particles = particles.mul(1 - lr).add_(particles_diff.mul_(lr)) + new_particles = particles_diff.mul_(lr).add_(particles.mul(1 - lr)) # Update the parameters kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) diff --git a/src/nak_torch/algorithms/svgd.py b/src/nak_torch/algorithms/svgd.py index 18fe578..bde255d 100644 --- a/src/nak_torch/algorithms/svgd.py +++ b/src/nak_torch/algorithms/svgd.py @@ -105,7 +105,7 @@ def initialize(self, init_particles, target, target_args): def step(self, lr, particles, target, algorithm_args, target_args): (kernel_lengthscale,) = astuple(algorithm_args) - grad_log_dens_eval = target(particles, None, target_args) + grad_log_dens_eval = target(particles, target_args) particles_diff = svgd_step( self.kernel_grad_val, particles, grad_log_dens_eval, kernel_lengthscale ) diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py index def9b43..5db6760 100644 --- a/src/nak_torch/tools/func.py +++ b/src/nak_torch/tools/func.py @@ -6,17 +6,15 @@ BatchPtType, BatchType, DeviceLike, - BatchTargetEvaluator, + NAKTarget, ) -BatchTargetEvaluatorT = TypeVar("BatchTargetEvaluatorT", bound=BatchTargetEvaluator) +NAKTargetT = TypeVar("NAKTargetT", bound=NAKTarget) AlgorithmArgsT = TypeVar("AlgorithmArgsT") WeightT = TypeVar("WeightT", bound=Optional[BatchType]) -class GeneralAdaptiveNAKAlgorithm( - ABC, Generic[BatchTargetEvaluatorT, WeightT, AlgorithmArgsT] -): +class GeneralAdaptiveNAKAlgorithm(ABC, Generic[NAKTargetT, WeightT, AlgorithmArgsT]): dim: int n_particles: int device: Optional[DeviceLike] @@ -44,7 +42,7 @@ def __init__( def initialize( self, init_particles: BatchPtType, - target: BatchTargetEvaluatorT, + target: NAKTargetT, target_args: Any, ) -> tuple[WeightT, AlgorithmArgsT]: pass @@ -54,7 +52,7 @@ def step( self, lr: float, particles: BatchPtType, - target: BatchTargetEvaluatorT, + target: NAKTargetT, algorithm_args: AlgorithmArgsT, target_args: Any, ) -> tuple[BatchPtType, WeightT, AlgorithmArgsT]: @@ -67,7 +65,7 @@ def is_weighted(cls) -> bool: class UnweightedAdaptiveNAKAlgorithm( - GeneralAdaptiveNAKAlgorithm[BatchTargetEvaluatorT, None, AlgorithmArgsT] + GeneralAdaptiveNAKAlgorithm[NAKTargetT, None, AlgorithmArgsT] ): @classmethod def is_weighted(cls) -> bool: @@ -75,7 +73,7 @@ def is_weighted(cls) -> bool: class WeightedAdaptiveNAKAlgorithm( - GeneralAdaptiveNAKAlgorithm[BatchTargetEvaluatorT, BatchType, AlgorithmArgsT] + GeneralAdaptiveNAKAlgorithm[NAKTargetT, BatchType, AlgorithmArgsT] ): @classmethod def is_weighted(cls) -> bool: diff --git a/src/nak_torch/tools/metrics.py b/src/nak_torch/tools/metrics.py index ad6dcc7..df6d5db 100644 --- a/src/nak_torch/tools/metrics.py +++ b/src/nak_torch/tools/metrics.py @@ -143,6 +143,7 @@ def __init__( self, grad_log_dens: AnyLogDensGrad, kernel_length_scale: float, + target_args=None, kernel_elem=None, is_grad_vectorized: bool = True, use_compiled: bool = False, @@ -153,6 +154,7 @@ def __init__( self.stein_kernel_mat = stein_kernel_mat_factory( grad_log_dens, kernel_elem, + target_args=target_args, is_grad_vectorized=is_grad_vectorized, use_compiled=use_compiled, ) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 76a14cd..b115217 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -10,6 +10,7 @@ BatchType = Float[Tensor, "batch"] PtType = Float[Tensor, " d"] +CovType = Float[Tensor, "d d"] BatchPtType = Float[Tensor, "batch d"] QuadrulePtType = Float[Tensor, "quad d"] QuadruleWtType = Float[Tensor, "quad"] @@ -30,11 +31,12 @@ EvaluatorOutputT = TypeVar("EvaluatorOutputT") -class BatchTargetEvaluator(ABC, Generic[EvaluatorOutputT]): +class NAKTarget(ABC, Generic[EvaluatorOutputT]): ... + + +class BatchTargetEvaluator(NAKTarget[EvaluatorOutputT]): @abstractmethod - def __call__( - self, particles: BatchPtType, evaluator_args, target_args - ) -> EvaluatorOutputT: + def __call__(self, particles: BatchPtType, target_args) -> EvaluatorOutputT: pass @@ -76,7 +78,7 @@ def __init__(self, log_density: LogDensity | BatchLogDensity, is_batched: bool): log_density = torch.vmap(log_density, in_dims=(0, None)) self.log_density = log_density - def __call__(self, pts, _, target_args): + def __call__(self, pts, target_args): return self.log_density(pts, target_args) @@ -104,12 +106,12 @@ def __init__( log_density_or_grad = torch.func.grad(log_density_or_grad) self.grad_log_density = torch.vmap(log_density_or_grad, in_dims=(0, None)) - def __call__(self, pts, _, target_args): + def __call__(self, pts, target_args): return self.grad_log_density(pts, target_args) @dataclass -class GaussianModel: +class GaussianModel(NAKTarget): forward_model: BatchForwardModel likelihood_precision: float | Float[Tensor, "obs obs"] prior_precision: float | Float[Tensor, "dim dim"] From 4b8b6aa0cdc174ee2921ceb5b0ac81eb61100ab8 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 15:40:51 -0400 Subject: [PATCH 19/60] Fix EKS --- examples/gaussian.py | 61 +++--- src/nak_torch/algorithms/__init__.py | 4 +- src/nak_torch/algorithms/eks.py | 228 +++++++++++----------- src/nak_torch/algorithms/gradfree_aldi.py | 4 +- src/nak_torch/tools/func.py | 5 +- 5 files changed, 154 insertions(+), 148 deletions(-) diff --git a/examples/gaussian.py b/examples/gaussian.py index 5039b24..ebd9c31 100644 --- a/examples/gaussian.py +++ b/examples/gaussian.py @@ -8,8 +8,8 @@ from torch import Tensor import nak_torch -from nak_torch.algorithms import eks, kfrflow -from nak_torch.algorithms import SVGD, MSIP, MSIPGS, GradALDI, CBS, GradFreeALDI +from nak_torch.algorithms import kfrflow +from nak_torch.algorithms import SVGD, MSIP, MSIPGS, GradALDI, CBS, GradFreeALDI, EKS from tqdm import tqdm from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientInformed, MSIPQuadGradientFree from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre @@ -101,17 +101,21 @@ def _tmp_post_log_dens(pts, args): # %% n_steps, n_particles = 1000, 500 -lr = 0.1 +lr = 1e-2 bounds = (-100., 100.) -init_particles = torch.randn((n_particles, 2)) / \ +rng = torch.Generator(device=torch.get_default_device()) +rng.manual_seed(0) + +init_particles = torch.randn((n_particles, 2), generator=rng) / \ model.prior_precision + model.prior_mean # %% -trajectories_eks = eks( - model, n_particles=n_particles, - n_steps=n_steps, dim=2, lr=lr, - init_particles=init_particles, keep_all=False, +eks = EKS(dim=2, n_particles=n_particles, rng_or_seed=rng) +trajectories_eks = nak_torch.nak( + model, eks, n_steps=n_steps, lr=lr, + rng_or_seed=rng, target_args=None, bounds=bounds, + init_particles=init_particles ) # %% @@ -159,10 +163,6 @@ def _tmp_post_log_dens(pts, args): # particles += torch.einsum("jid,i->jd", grad_ks, torch.linalg.solve(M_t, kernelized_wts)).mul_(delta_t) -# %% -rng = torch.Generator() -rng.manual_seed(0) - # %% grad_aldi = GradALDI(dim=2, n_particles=n_particles, rng = rng) grad_aldi_target = BatchGradLogDensityEvaluator(post_log_dens, is_grad=False, is_batched=True) @@ -187,17 +187,12 @@ def _tmp_post_log_dens(pts, args): # %% cbs_target = BatchLogDensityEvaluator(post_log_dens, is_batched=True) -cbs = CBS(dim=2, n_particles=n_particles, default_inverse_temp=0.95, rng=rng) +cbs = CBS(dim=2, n_particles=n_particles, default_inverse_temp=0.5, rng=rng) trajectories_cbs = nak_torch.nak( - cbs_target, cbs, n_steps, lr, + cbs_target, cbs, 5000, lr=lr, rng_or_seed=rng, init_particles=init_particles, - target_args=None, bounds = bounds + target_args=None, bounds=bounds ) -# trajectories_cbs = cbs( -# post_log_dens, n_particles, n_steps, inverse_temp=0.95, dim=2, -# lr=lr, init_particles=init_particles, -# keep_all=True -# ) # %% kernel_lengthscale = 0.15 @@ -219,7 +214,8 @@ def _tmp_post_log_dens(pts, args): post_log_dens_grad_val_batch ) -# %% +msip.kernel_lengthscale_quantile = 0.01 + trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( msip_fredholm_target, msip, n_steps_msip, lr_msip, rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], @@ -246,9 +242,11 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): gradient_decay ) +msip.kernel_lengthscale_quantile = 0.5 +msip.kernel_diag_infl = 1e-1 trajectories_pts_msip_qg, trajectories_wts_msip_qg = nak_torch.nak( - msip_quadgrad_target, msip, n_steps_msip, lr_msip, + msip_quadgrad_target, msip, n_steps_msip, 1e-1, rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, keep_all=False, bounds=bounds ) @@ -257,20 +255,21 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): msip_quadgf_target = MSIPQuadGradientFree( post_log_dens, partial(spherical_quad, N_spherical=5, N_radial=4) ) +msip.kernel_lengthscale_quantile = 0.05 trajectories_pts_msip_qgf, trajectories_wts_msip_qgf = nak_torch.nak( - msip_quadgf_target, msip, 100, 8e-1, + msip_quadgf_target, msip, n_steps=n_steps_msip, lr=1e-1, rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, keep_all=False, bounds=bounds ) # %% -# pts_eks = trajectories_eks[-1] +pts_eks = trajectories_eks[-1] # pts_kfr = particles -# pts_galdi = trajectories_galdi[-1] +pts_galdi = trajectories_galdi[-1] pts_gfaldi = trajectories_gfaldi[-1] -# pts_cbs = trajectories_cbs[-1] +pts_cbs = trajectories_cbs[-1] idx_msip = -1 alpha_msip = 2/math.sqrt(n_particles_msip) pts_msip = trajectories_pts_msip_fr[idx_msip] @@ -296,13 +295,13 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): handles = [] # ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") # ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], - alpha=0.2, label="GradFree-ALDI") +# ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], +# alpha=0.2, label="GradFree-ALDI") # ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") # ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") # ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -# ax.scatter(pts_msip[:, 0], pts_msip[:, 1], alpha=alpha_msip, label="MSIP") -# handles.append(ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad")) +ax.scatter(pts_msip[:, 0], pts_msip[:, 1], alpha=alpha_msip, label="MSIP") +# ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad") # ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], # s = 50*wts_msip_qgf.abs()/wts_msip_qgf.max(), alpha=alpha_msip, label="MSIP-QuadGradFree") # plt.colorbar(s) @@ -312,6 +311,8 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): # ax.set_ylim(ygrid.min(), ygrid.max()) plt.show() +# %% + # %% print(f""" Covariances--- diff --git a/src/nak_torch/algorithms/__init__.py b/src/nak_torch/algorithms/__init__.py index f2eccde..ca00ac3 100644 --- a/src/nak_torch/algorithms/__init__.py +++ b/src/nak_torch/algorithms/__init__.py @@ -7,7 +7,7 @@ # Ayoub Belhadji # 05/12/2025 -from .eks import eks +from .eks import EKS from .msip import MSIP, MSIPGS from .svgd import SVGD from .deepensembles import deepensembles @@ -27,7 +27,7 @@ "GradALDI", # "gradfree_aldi", "GradFreeALDI", - "eks", + "EKS", "CBS", "kfrflow", ] diff --git a/src/nak_torch/algorithms/eks.py b/src/nak_torch/algorithms/eks.py index 8dcc9ed..a13ff7c 100644 --- a/src/nak_torch/algorithms/eks.py +++ b/src/nak_torch/algorithms/eks.py @@ -1,129 +1,131 @@ +from dataclasses import astuple + import torch -from typing import Optional +from typing import Any, Optional from jaxtyping import Float from torch import Tensor -from nak_torch.tools.types import BatchPtType, GaussianModel -import warnings -from tqdm import tqdm -import numpy as np -from nak_torch.tools.util import sym_sqrtm, initialize_particles +from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm +from nak_torch.tools.types import ( + BatchPtType, + CovType, + DeviceLike, + GaussianModel, + PtType, +) +from nak_torch.tools.util import sym_sqrtm + +__all__ = ["EKS"] -def build_eks_step( - eks_model: GaussianModel, +def eks_step( + particles: BatchPtType, + forecast_observations: Float[Tensor, "batch obs"], + prior_mean: PtType, + likelihood_precision: CovType, + prior_precision: CovType, + true_observation: Float[Tensor, " obs"], dt: float, - device: Optional[torch.device], - compile_step: bool, -): - likelihood_precision = torch.as_tensor( - eks_model.likelihood_precision, device=device + rng: torch.Generator, +) -> BatchPtType: + device, dtype = particles.device, particles.dtype + N_batch, dim = particles.shape + particle_mean = particles.mean(0, True) + forecast_obs_mean = forecast_observations.mean(0, True) + obs_diff = forecast_observations - true_observation + forecast_diff = forecast_observations - forecast_obs_mean + prior_ens_diff = particles - particle_mean + if prior_mean != 0.0: + prior_ens_diff -= prior_mean + cov_forecast = (prior_ens_diff.T @ prior_ens_diff) / N_batch + + if isinstance(likelihood_precision, float) or likelihood_precision.numel() == 1: + likely_term = torch.einsum("ko,jo,kd->jd", forecast_diff, obs_diff, particles) + likely_term.mul_(dt * likelihood_precision / N_batch) + else: + likely_term = torch.einsum( + "kp,pq,jq,kd->jd", + forecast_diff, + likelihood_precision, + obs_diff, + particles, + ) + likely_term.mul_(dt / N_batch) + # INPLACE + cov_forecast.mul_(dt) + sqrt_prior_cov = sym_sqrtm(cov_forecast) + sqrt_2 = torch.as_tensor(2.0, device=device, dtype=dtype).sqrt() + sqrt_prior_cov.mul_(sqrt_2) + if isinstance(prior_precision, float) or prior_precision.numel() == 1: + prior_term_premul = cov_forecast.mul_(prior_precision) + elif isinstance(prior_precision, Tensor): + prior_term_premul = torch.matmul(cov_forecast, prior_precision) + else: + raise ValueError() + + prior_term_premul.add_(torch.eye(dim, device=device)) + new_particles: BatchPtType = torch.linalg.solve( + prior_term_premul, particles - likely_term, left=False ) - prior_mean = torch.as_tensor(eks_model.prior_mean, device=device) - prior_precision = torch.as_tensor(eks_model.prior_precision, device=device) - true_obs = torch.as_tensor(eks_model.true_obs, device=device) - if isinstance(true_obs, Tensor): - true_obs.reshape(1, -1) + noise_tens = torch.randn(particles.shape, generator=rng) + noise_samp = noise_tens @ sqrt_prior_cov + return new_particles.add_(noise_samp) - sqrt_2 = torch.sqrt(torch.tensor(2, dtype=true_obs.dtype, device=device)) - def eks_step( - particles: BatchPtType, forecast_observations: Float[Tensor, "batch obs"] - ) -> tuple[BatchPtType, Float[Tensor, "dim dim"]]: - N_batch, dim = particles.shape - particle_mean = particles.mean(0, True) - forecast_obs_mean = forecast_observations.mean(0, True) - obs_diff = forecast_observations - true_obs - forecast_diff = forecast_observations - forecast_obs_mean - prior_ens_diff = particles - particle_mean - if prior_mean != 0.0: - prior_ens_diff -= prior_mean - cov_forecast = (prior_ens_diff.T @ prior_ens_diff) / N_batch +class EKS(UnweightedAdaptiveNAKAlgorithm[GaussianModel, None]): + rng: torch.Generator - if isinstance(likelihood_precision, float) or likelihood_precision.numel() == 1: - likely_term = torch.einsum( - "ko,jo,kd->jd", forecast_diff, obs_diff, particles + def __init__( + self, + dim: int, + n_particles: int, + device: Optional[DeviceLike] = None, + dtype: Optional[torch.dtype] = None, + *_, + rng_or_seed: Optional[torch.Generator | int] = None, + **kwargs, + ): + super().__init__(dim, n_particles, device, dtype, **kwargs) + if isinstance(rng_or_seed, int): + self.rng = torch.Generator(self.device).set_state( + torch.default_generator.get_state() ) - likely_term.mul_(dt * likelihood_precision / N_batch) - else: - likely_term = torch.einsum( - "kp,pq,jq,kd->jd", - forecast_diff, - likelihood_precision, - obs_diff, - particles, + self.rng.manual_seed(rng_or_seed) + elif rng_or_seed is None: + self.rng = torch.Generator(self.device).set_state( + torch.default_generator.get_state() ) - likely_term.mul_(dt / N_batch) - # INPLACE - cov_forecast.mul_(dt) - sqrt_prior_cov = sym_sqrtm(cov_forecast) - sqrt_prior_cov.mul_(sqrt_2) - if isinstance(prior_precision, float) or prior_precision.numel() == 1: - prior_term_premul = cov_forecast.mul_(prior_precision) - elif isinstance(prior_precision, Tensor): - prior_term_premul = torch.matmul(cov_forecast, prior_precision) else: - raise ValueError() - - prior_term_premul.add_(torch.eye(dim, device=device)) - new_particles = torch.linalg.solve( - prior_term_premul, particles - likely_term, left=False - ) - return new_particles, sqrt_prior_cov + self.rng = rng_or_seed + if self.rng.device != self.device: + raise ValueError( + f"Expected rng to live on device {self.device}, got {self.rng.device}" + ) - return torch.compile(eks_step) if compile_step else eks_step + def initialize( + self, init_particles: Tensor, target: GaussianModel, target_args: Any + ) -> tuple[None, None]: + return None, None - -def eks( - eks_model: GaussianModel, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - noise=None, - seed=None, - device=None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - bounds: Optional[tuple[float, float]] = None, - keep_all: bool = True, - rng: Optional[torch.Generator] = None, - verbose: bool = False, - compile_step: bool = True, - **unused_kwargs, -): - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs:\n{}".format(unused_kwargs)) - - if rng is None: - rng = torch.default_generator - if seed is not None: - rng.manual_seed(seed) - - particles = initialize_particles( - n_particles, dim, init_particles, device, bounds, rng - ) - - if keep_all: - trajectories = torch.empty( - (n_steps, *particles.shape), device=device, dtype=particles.dtype + def step( + self, + lr: float, + particles: Tensor, + target: GaussianModel, + algorithm_args: None, + target_args: Any, + ) -> tuple[Tensor, None, None]: + forward_model, likelihood_precision, prior_precision, true_obs, prior_mean = ( + astuple(target) ) - trajectories[0].copy_(particles) - else: - trajectories = torch.empty(()) - - eks_step = build_eks_step(eks_model, lr, device, compile_step) - noise_tens = torch.empty_like(particles) - for idx in tqdm(range(n_steps), disable=not verbose): - forecast_obs = eks_model.forward_model(particles) - with torch.no_grad(): - particles, noise_sqrt_cov = eks_step(particles, forecast_obs) - noise_tens = torch.normal( - mean=0.0, std=1.0, size=particles.shape, generator=rng, out=noise_tens - ) - noise_samp = noise_tens @ noise_sqrt_cov - particles = particles.add_(noise_samp) - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx].copy_(particles) - - return trajectories.detach() if keep_all else particles.unsqueeze_(0) + forecast_observations = forward_model(particles, target_args) + new_particles = eks_step( + particles, + forecast_observations, + prior_mean, + likelihood_precision, + prior_precision, + true_obs, + lr, + self.rng, + ) + return new_particles, None, algorithm_args diff --git a/src/nak_torch/algorithms/gradfree_aldi.py b/src/nak_torch/algorithms/gradfree_aldi.py index 2ec5e37..b83c497 100644 --- a/src/nak_torch/algorithms/gradfree_aldi.py +++ b/src/nak_torch/algorithms/gradfree_aldi.py @@ -36,7 +36,7 @@ def gradfree_aldi_step( prior_precision: CovType, true_observation: Float[Tensor, " obs"], rng: torch.Generator, -) -> tuple[BatchPtType, Float[Tensor, "dim dim"]]: +) -> tuple[BatchPtType, BatchPtType]: N_batch, dim = particles.shape particle_mean = particles.mean(dim=0, keepdim=True) @@ -137,4 +137,4 @@ def step( new_particles = ( particles_diff.mul_(lr).add_(particles).add_(particles_noise.mul_(sqrt_lr)) ) - return new_particles, None, None + return new_particles, None, algorithm_args diff --git a/src/nak_torch/tools/func.py b/src/nak_torch/tools/func.py index 5db6760..7692f5b 100644 --- a/src/nak_torch/tools/func.py +++ b/src/nak_torch/tools/func.py @@ -32,7 +32,10 @@ def __init__( ): self.dim = dim self.n_particles = n_particles - self.device = device + if device is None: + self.device = torch.get_default_device() + else: + self.device = device self.dtype = dtype self.verbose = verbose if verbose and len(kwargs) > 0: From 66e84b0e42b1845b75ec3962ad3041653a33a625 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 18:42:27 -0400 Subject: [PATCH 20/60] Remove redundant gaussian example --- examples/gaussian_mvp.py | 426 --------------------------------------- 1 file changed, 426 deletions(-) delete mode 100644 examples/gaussian_mvp.py diff --git a/examples/gaussian_mvp.py b/examples/gaussian_mvp.py deleted file mode 100644 index 79df5de..0000000 --- a/examples/gaussian_mvp.py +++ /dev/null @@ -1,426 +0,0 @@ -# %% -from functools import partial - -import matplotlib.pyplot as plt -import torch -from torch import Tensor -from jaxtyping import Float - -import nak_torch -from nak_torch.algorithms import grad_aldi, eks, gradfree_aldi, cbs, msip, kfrflow -from nak_torch.algorithms.msip import ( - MSIPFredholm, - MSIPQuadGradientInformed, - MSIPQuadGradientFree, -) -from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre - -from pyro.infer import mcmc -import pyro_tools - -if torch.cuda.is_available(): - torch.set_default_device("cuda") -else: - torch.set_default_device("cpu") - -torch.set_default_dtype(torch.float64) - - -# %% -def make_gaussian_post( - forward_op: Float[Tensor, "obs dim"], - mean_pr: Float[Tensor, " dim"], - cov_pr: Float[Tensor, "dim dim"], - mean_li: Float[Tensor, " obs"], - cov_li: Float[Tensor, "obs obs"], -): - forward_op = forward_op.T - cov_post = torch.linalg.inv( - forward_op.T @ torch.linalg.solve(cov_li, forward_op) + torch.linalg.inv(cov_pr) - ) - mean_post = cov_post @ ( - forward_op.T @ torch.linalg.solve(cov_li, mean_li) - + torch.linalg.solve(cov_pr, mean_pr) - ) - return mean_post, cov_post - - -def weighted_cov(pts: Tensor, wts: Tensor): - mean = wts @ pts - second_moment = torch.einsum("b,bi,bj", wts, pts, pts) - return second_moment - mean.outer(mean) - - -# %% Everything related to the definition of the distribution -torch.manual_seed(1023921) -obs_op = torch.randn(2, 5) -obs_op.div_(obs_op.norm(dim=1, keepdim=True)) - - -# forward_model = torch.compile(lambda particles: particles @ obs_op) -def forward_model(particles, _obs_op=obs_op): - return particles @ _obs_op - - -true_obs = torch.tensor([1.0, 2.0, 3.0, 2.0, 1.0]) + 20 - -model = nak_torch.GaussianModel( - forward_model, - likelihood_precision=10.0, - prior_precision=0.9, - true_obs=true_obs, - is_vectorized=True, -) - - -# @torch.compile -def like_log_dens(pt): - ll_term = ( - model.likelihood_precision - * torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1) ** 2 - ) - return -0.5 * ll_term.squeeze() - - -# @torch.compile -def post_log_dens(pt): - ll_term = ( - model.likelihood_precision - * torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1) ** 2 - ) - prior_term = model.prior_precision * torch.linalg.norm(pt, dim=-1) ** 2 - return -0.5 * (ll_term + prior_term).squeeze() - - -post_log_dens_batch = torch.vmap(post_log_dens) -post_log_dens_grad_val = torch.func.grad_and_value(post_log_dens) -post_log_dens_grad_val_batch = torch.vmap(post_log_dens_grad_val) - -# %% -mean_pr, cov_pr = torch.zeros(2), torch.eye(2) / model.prior_precision -mean_li, cov_li = ( - model.true_obs, - torch.eye(len(model.true_obs)) / model.likelihood_precision, -) - -mean_post, cov_post = make_gaussian_post(obs_op, mean_pr, cov_pr, mean_li, cov_li) -vals, vecs = torch.linalg.eigh(cov_post) -cov_post_sqrt = vecs @ torch.diag(torch.sqrt(vals)) @ vecs.T -samps = torch.randn(10000, 2) @ cov_post_sqrt + mean_post - -# %% Parameters that are common to all algorithms -n_steps, n_particles = 50000, 50 -lr = 0.1 - -# %% Initialization -init_particles = torch.randn((n_particles, 2)) / model.prior_precision + torch.tensor( - [3.2, -5.0] -) - -# init_particles = torch.randn((n_particles, 2)) + torch.tensor([3, -3]) -# torch.randn((n_particles_kfr,2)) + torch.tensor([3,-5]) -# %% EKS -trajectories_eks = eks( - model, - n_particles=n_particles, - n_steps=n_steps, - dim=2, - lr=lr, - init_particles=init_particles, - keep_all=False, - compile_step=False, - verbose=True, -) - -# %% -n_steps_hmc = n_steps - 100 -pyro_model = pyro_tools.PyroModel(model, 2) -hmc_kernel = mcmc.NUTS(pyro_model) -mcmc_setup = mcmc.MCMC(hmc_kernel, num_samples=n_steps_hmc, warmup_steps=100) -mcmc_setup.run(model.true_obs) - -# %% -hmc_samples = mcmc_setup.get_samples()["theta"] - -plt.scatter(hmc_samples[:, 0], hmc_samples[:, 1]) - -# %% KFR - -# delta_ts = torch.ones(1000)/1000 -# def imq(pt1,pt2,h): -# return 1/torch.sqrt(1 + (torch.linalg.norm(pt1-pt2) / h)**2) - -trajectories_kfr = kfrflow( - like_log_dens, - n_particles, - n_steps, - 2, - init_particles=init_particles, - kernel_length_scale=1e-2, - kernel_diag_infl=1e-5, - # bounds=(-10,10), - # kernel_elem=imq, - keep_all=False, - compile_step=False, - verbose=True, -) - - -# %% GI-ALDI - -trajectories_galdi = grad_aldi( - post_log_dens, - n_particles, - n_steps, - dim=2, - lr=lr, - init_particles=init_particles, - keep_all=False, - compile_step=False, - verbose=True, -) - -# %% GF-ALDI - -trajectories_gfaldi = gradfree_aldi( - model, - n_particles, - n_steps, - dim=2, - lr=lr, - init_particles=init_particles, - keep_all=True, - compile_step=False, - verbose=True, -) - -# %% CBS - -trajectories_cbs = cbs( - post_log_dens, - n_particles, - n_steps, - inverse_temp=0.95, - dim=2, - lr=lr, - init_particles=init_particles, - keep_all=True, - compile_step=False, - verbose=True, -) - -# %% F-MSIP - -kernel_length_scale = 0.03 -bounds = (-100.0, 100.0) -gradient_decay = 1.0 -lr_msip = 100e-2 -kernel_diag_infl = 1e-8 -msip_fredholm = MSIPFredholm(gradient_decay, post_log_dens_grad_val_batch) - -trajectories_msip, traj_wts_msip = msip( - msip_fredholm, - n_particles, - n_steps, - dim=2, - lr=lr_msip, - init_particles=init_particles[:n_particles], - kernel_length_scale=kernel_length_scale, - is_log_density_batched=True, - kernel_diag_infl=kernel_diag_infl, - bounds=bounds, - gradient_decay=gradient_decay, - keep_all=True, - compile_step=False, - verbose=True, -) - -# %% - - -def mc_quad_rule(batch_size: int, N_quad: int = 5, dim: int = 2): - pts = torch.randn((batch_size, N_quad, dim)) - wts = torch.ones((batch_size, N_quad)).div_(N_quad) - return pts, wts - - -def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): - pts, wts = spherical_MC_radial_Laguerre(batch_size, N_spherical, 2, N_radial) - return pts, wts - - -# %% -# kernel_length_scale = 1e-3 -# gradient_decay = 1. -msip_quadgrad = MSIPQuadGradientInformed( - post_log_dens_grad_val_batch, mc_quad_rule, gradient_decay -) - -trajectories_msip_qg, traj_wts_msip_qg = msip( - msip_quadgrad, - n_particles, - n_steps, - dim=2, - lr=10.0, - init_particles=init_particles[:n_particles], - kernel_length_scale=kernel_length_scale, - # is_log_density_batched=True, - kernel_diag_infl=1e-8, - bounds=(-1000, 1000), - # gradient_decay=gradient_decay, - keep_all=False, - compile_step=False, - verbose=True, -) - -# %% -# n_particles_msip = 500 -# kernel_length_scale = 1e-2 -msip_quadgf = MSIPQuadGradientFree( - post_log_dens_batch, partial(mc_quad_rule, N_quad=100) -) - -trajectories_msip_qgf, traj_wts_msip_qgf = msip( - msip_quadgf, - n_particles, - n_steps, - dim=2, - lr=1.0, - init_particles=init_particles[:n_particles], - kernel_length_scale=kernel_length_scale, - kernel_diag_infl=1e-8, - bounds=(-1000.0, 1000.0), - keep_all=False, - compile_step=False, - verbose=True, -) - - -# %% -pts_eks = trajectories_eks[-1] -# pts_kfr = particles_kfr -pts_galdi = trajectories_galdi[-1] -pts_gfaldi = trajectories_gfaldi[-1] -pts_cbs = trajectories_cbs[-1] -idx_msip = 100 -pts_msip = trajectories_msip[idx_msip] -wts_msip = traj_wts_msip[idx_msip] -# wts_msip /= wts_msip.sum() -pts_msip_qg = trajectories_msip_qg[-1] -wts_msip_qg = traj_wts_msip_qg[-1] -wts_msip_qg = wts_msip_qg / wts_msip_qg.sum() -pts_msip_qgf = trajectories_msip_qgf[-1] -wts_msip_qgf = traj_wts_msip_qgf[-1] -# wts_msip_qgf = wts_msip_qgf/wts_msip_qgf.sum() - -Ngrid = 100 -xgrid = torch.linspace(-1, 1, Ngrid) -xgrid = 3 * xgrid * cov_post_sqrt[0, 0] + mean_post[0] -ygrid = torch.linspace(-1, 1, Ngrid) -ygrid = 3 * ygrid * cov_post_sqrt[1, 1] + mean_post[1] -X, Y = torch.meshgrid(xgrid, ygrid, indexing="ij") -grid_pts = torch.stack((X.flatten(), Y.flatten()), 1) - -# fig, ax = plt.subplots() -# ax.contour(X, Y, post_log_dens(grid_pts).reshape(Ngrid, Ngrid), levels=10) -# # ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") -# # ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -# # ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], -# # alpha=0.2, label="GradFree-ALDI") -# ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") -# # ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") -# # ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -# # s = ax.scatter(pts_msip[:, 0], pts_msip[:, 1], -# # c=wts_msip, alpha=0.15, label="MSIP") -# # s = ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], -# # c = wts_msip_qg, alpha=0.15, label="MSIP-QuadGrad") -# # s = ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], -# # c = wts_msip_qgf, alpha=0.15, label="MSIP-QuadGradFree") -# # plt.colorbar(s) -# ax.set_aspect(1.0) -# ax.legend() -# plt.show() - - -# fig, ax = plt.subplots() -# ax.contour(X, Y, post_log_dens(grid_pts).reshape(Ngrid, Ngrid), levels=10) -# # ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") -# # ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -# # ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], -# # alpha=0.2, label="GradFree-ALDI") -# #ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") -# # ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") -# # ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -# # s = ax.scatter(pts_msip[:, 0], pts_msip[:, 1], -# # c=wts_msip, alpha=0.15, label="MSIP") -# s = ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], -# c = wts_msip_qg, alpha=0.15, label="MSIP-QuadGrad") -# # s = ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], -# # c = wts_msip_qgf, alpha=0.15, label="MSIP-QuadGradFree") -# # plt.colorbar(s) -# ax.set_aspect(1.0) -# ax.legend() -# plt.show() - - -fig, ax = plt.subplots() -ax.contour(X, Y, post_log_dens(grid_pts).reshape(Ngrid, Ngrid), levels=10) -# ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") -# ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -# ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], -# alpha=0.2, label="GradFree-ALDI") -# ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") -# ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") -# ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -s = ax.scatter(pts_msip[:, 0], pts_msip[:, 1], c=wts_msip, alpha=0.15, label="MSIP") -# s = ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], -# c = wts_msip_qg, alpha=0.15, label="MSIP-QuadGrad") -# s = ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], -# c = wts_msip_qgf, alpha=0.15, label="MSIP-QuadGradFree") -# plt.colorbar(s) -ax.set_aspect(1.0) -ax.legend() -plt.show() - - -fig, ax = plt.subplots() -ax.contour(X, Y, post_log_dens(grid_pts).reshape(Ngrid, Ngrid), levels=10) -# ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") -ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -# ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], -# alpha=0.2, label="GradFree-ALDI") -# ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") -# ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") -# ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -# s = ax.scatter(pts_msip[:, 0], pts_msip[:, 1], -# c=wts_msip, alpha=0.15, label="MSIP") -# s = ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], -# c = wts_msip_qg, alpha=0.15, label="MSIP-QuadGrad") -# s = ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], -# c = wts_msip_qgf, alpha=0.15, label="MSIP-QuadGradFree") -# plt.colorbar(s) -ax.set_aspect(1.0) -ax.legend() -plt.show() - - -# %% -print(f""" -Covariances--- -Truth: -{cov_post} -EKS: -{pts_eks.T.cov()} -Grad-ALDI: -{pts_galdi.T.cov()} -GradFree-ALDI: -{pts_gfaldi.T.cov()} -MSIP: -{weighted_cov(pts_msip, wts_msip)} -MSIP-QuadGrad: -{weighted_cov(pts_msip_qg, wts_msip_qg)} -MSIP-QuadGradFree: -{weighted_cov(pts_msip_qgf, wts_msip_qgf)} -""") - -# %% From 5b7f93aadf5b34783368d22264d8b55b2e251679 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 18:42:37 -0400 Subject: [PATCH 21/60] Fix gaussianmodel --- src/nak_torch/tools/types.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index b115217..c6bb005 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -143,16 +143,13 @@ def to_log_dens(self, use_compiled: bool = True) -> BatchLogDensity: def log_dens(pts: BatchPtType, aux_args: Any) -> BatchType: model_eval = self.forward_model(pts, aux_args) obs_error = model_eval.sub_(self.true_obs) - like_term = torch.square(torch.linalg.norm(obs_error, dim=-1)).mul_( - self.likelihood_precision - ) - like_term.mul_(self.likelihood_precision) - prior_diff = pts + like_sq_norm = obs_error.square().sum(dim=-1) + like_term = like_sq_norm.mul_(self.likelihood_precision) + prior_diff = pts.clone() if self.prior_mean != 0.0: - prior_diff -= self.prior_mean - prior_term = torch.square(torch.linalg.norm(prior_diff, dim=-1)).mul_( - self.prior_precision - ) + prior_diff.sub_(self.prior_mean) + prior_sq_norm = prior_diff.square().sum(dim=-1) + prior_term = prior_sq_norm.mul_(self.prior_precision) return -0.5 * (prior_term + like_term) return torch.compile(log_dens) if use_compiled else log_dens From 0227b6e9596263018b1cae3500deef9302c6add8 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 18:43:06 -0400 Subject: [PATCH 22/60] Fix sigma_sq --- src/nak_torch/algorithms/msip/estimators.py | 32 ++++++++++++--------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index 21864d9..49ab25b 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -38,10 +38,11 @@ def __init__( self.gradient_decay = gradient_decay self.log_dens_grad_val = log_dens_grad_val - def __call__(self, particles, kernel_length_scale, target_args): - grads, v0 = self.log_dens_grad_val(particles, target_args) - sigma_sq_log_v0 = grads.mul_(kernel_length_scale * self.gradient_decay) - return v0, sigma_sq_log_v0 + def __call__(self, particles, kernel_lengthscale, target_args): + grads, log_v0 = self.log_dens_grad_val(particles, target_args) + sigma_sq = kernel_lengthscale * kernel_lengthscale + sigma_sq_log_v0 = grads.mul_(sigma_sq * self.gradient_decay) + return log_v0, sigma_sq_log_v0 vmap_recursive_weighted_average_alpha_v = torch.vmap( @@ -93,23 +94,28 @@ def __init__( self.quadrature, self.gradient_decay = quadrature, gradient_decay self.log_dens_grad_val = log_dens_grad_val - def __call__(self, particles, kernel_length_scale, target_args): - quad_pts, quad_wts = self.quadrature(particles.shape[0]) - particle_quad_pts = quad_pts.mul_(kernel_length_scale).add( + def __call__(self, particles, kernel_lengthscale, target_args): + n_particles, dim = particles.shape + quad_pts, quad_wts = self.quadrature(n_particles) + sigma_sq = kernel_lengthscale * kernel_lengthscale + quad_pts_correct_var = quad_pts.mul(kernel_lengthscale) + particle_quad_pts = quad_pts_correct_var.add( particles.unsqueeze(1) ) # (N_part, N_quad, dim) + log_dens_grads, log_dens_evals = self.log_dens_grad_val( - particle_quad_pts.reshape(-1, particles.shape[1]), target_args + particle_quad_pts.reshape(-1, dim), target_args ) log_dens_grads = log_dens_grads.reshape_as(particle_quad_pts) - log_dens_evals = log_dens_evals.reshape(particle_quad_pts.shape[:-1]) + log_dens_evals = log_dens_evals.reshape(n_particles, -1) + + v1_integrand_gf = (1 - self.gradient_decay) * quad_pts_correct_var + v1_integrand_gi = self.gradient_decay * (sigma_sq * log_dens_grads) + v1_integrand = v1_integrand_gf + v1_integrand_gi - v1_integrand = quad_pts.mul_(1 - self.gradient_decay).add_( - log_dens_grads.mul_(self.gradient_decay * kernel_length_scale) - ) # Note that previously multiplied particle_quad_pts by kernel_length_scale sigma_sq_score_v0, log_v0 = vmap_recursive_weighted_average_alpha_v( - v1_integrand, quad_wts, log_dens_evals + v1_integrand, quad_wts, log_v=log_dens_evals ) return log_v0, sigma_sq_score_v0 From ae35e0c46d33833215259ddcdcb824f67dc211cc Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 18:43:30 -0400 Subject: [PATCH 23/60] Fix MSIP_GS --- src/nak_torch/algorithms/msip/msip_gs.py | 26 +++++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/nak_torch/algorithms/msip/msip_gs.py b/src/nak_torch/algorithms/msip/msip_gs.py index e0a7794..7e04262 100644 --- a/src/nak_torch/algorithms/msip/msip_gs.py +++ b/src/nak_torch/algorithms/msip/msip_gs.py @@ -2,6 +2,7 @@ import torch + from .msip_map import msip_map, get_msip_wts from .msip_tools import GeneralMSIPAlgorithm, MSIPGSAlgorithmArgs @@ -17,27 +18,32 @@ def initialize(self, init_particles, target, target_args): return wts, MSIPGSAlgorithmArgs(kernel_lengthscale, estimator_output) def step(self, lr, particles, target, algorithm_args, target_args): - kernel_lengthscale, _, estimator_output = astuple(algorithm_args) + kernel_lengthscale, estimator_output = astuple(algorithm_args) est_out_0, est_out_1 = estimator_output new_particles = particles.clone() - for i in range(particles.shape[0]): - km_i = self.get_infl_kernel_matrix(particles, kernel_lengthscale) - km_inv_i = torch.linalg.pinv(km_i) + kernel_matrix = self.get_infl_kernel_matrix(new_particles, kernel_lengthscale) + for i in range(new_particles.shape[0]): + km_inv_i = torch.linalg.pinv(kernel_matrix) est_out_i_0, est_out_i_1 = target( new_particles[i].unsqueeze(0), kernel_lengthscale, target_args ) est_out_0[i].copy_(est_out_i_0.squeeze()) est_out_1[i].copy_(est_out_i_1.squeeze()) - - target_i = msip_map(estimator_output, particles, km_inv_i, output_idx=i) - - new_particles[i].mul_(1.0 - lr).add_(target_i.mul_(lr)) + target_i = msip_map(estimator_output, new_particles, km_inv_i, output_idx=i) + new_particles[i] = new_particles[i].mul(1.0 - lr).add_(target_i.mul_(lr)) + kernel_matrix = self.get_infl_kernel_matrix( + new_particles, kernel_lengthscale + ) # Update the parameters new_kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) - kernel_matrix = self.get_infl_kernel_matrix(new_particles, kernel_lengthscale) if new_kernel_lengthscale != kernel_lengthscale: - estimator_output = target(particles, new_kernel_lengthscale, target_args) + estimator_output = target( + new_particles, new_kernel_lengthscale, target_args + ) + kernel_matrix = self.get_infl_kernel_matrix( + new_particles, new_kernel_lengthscale + ) kernel_lengthscale = new_kernel_lengthscale algorithm_args = MSIPGSAlgorithmArgs(kernel_lengthscale, estimator_output) new_weights = get_msip_wts(new_particles, estimator_output, kernel_matrix) From 85e19e932554f5e92d0abe95eaeda44fc5c726d3 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Fri, 24 Apr 2026 18:43:43 -0400 Subject: [PATCH 24/60] Add every algorithm so far to Gaussian example --- examples/gaussian.py | 312 +++++++++++++++++++++++-------------------- 1 file changed, 164 insertions(+), 148 deletions(-) diff --git a/examples/gaussian.py b/examples/gaussian.py index ebd9c31..e44b759 100644 --- a/examples/gaussian.py +++ b/examples/gaussian.py @@ -1,3 +1,4 @@ +# Gaussian example with all algorithms. # %% from functools import partial import math @@ -8,12 +9,13 @@ from torch import Tensor import nak_torch -from nak_torch.algorithms import kfrflow from nak_torch.algorithms import SVGD, MSIP, MSIPGS, GradALDI, CBS, GradFreeALDI, EKS -from tqdm import tqdm -from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientInformed, MSIPQuadGradientFree +from nak_torch.algorithms.msip import ( + MSIPFredholm, + MSIPQuadGradientInformed, + MSIPQuadGradientFree, +) from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre -from nak_torch.tools.kernel import sqexp_kernel_elem as kernel_elem, sqexp_kernel_matrix from nak_torch.tools.types import BatchGradLogDensityEvaluator, BatchLogDensityEvaluator if torch.cuda.is_available(): @@ -23,23 +25,23 @@ torch.set_default_dtype(torch.float64) + # %% def make_gaussian_post( forward_op: Float[Tensor, "obs dim"], mean_pr: Float[Tensor, " dim"], cov_pr: Float[Tensor, "dim dim"], mean_li: Float[Tensor, " obs"], - cov_li: Float[Tensor, "obs obs"] + cov_li: Float[Tensor, "obs obs"], ): forward_op = forward_op.T cov_post = torch.linalg.inv( - forward_op.T @ torch.linalg.solve( - cov_li, forward_op - ) + torch.linalg.inv(cov_pr) + forward_op.T @ torch.linalg.solve(cov_li, forward_op) + torch.linalg.inv(cov_pr) + ) + mean_post = cov_post @ ( + forward_op.T @ torch.linalg.solve(cov_li, mean_li) + + torch.linalg.solve(cov_pr, mean_pr) ) - mean_post = cov_post @ (forward_op.T @ torch.linalg.solve( - cov_li, mean_li - ) + torch.linalg.solve(cov_pr, mean_pr)) return mean_post, cov_post @@ -49,7 +51,6 @@ def weighted_cov(pts: Tensor, wts: Tensor): return second_moment - mean.outer(mean) - # %% torch.manual_seed(1023921) obs_op = torch.randn(2, 5) @@ -58,43 +59,31 @@ def weighted_cov(pts: Tensor, wts: Tensor): true_obs = torch.tensor([1.0, 2.0, 3.0, 2.0, 1.0]) + 20 model = nak_torch.GaussianModel( - forward_model, likelihood_precision=10.0, - prior_precision=0.9, true_obs=true_obs, - is_vectorized=True + forward_model, + likelihood_precision=10.0, + prior_precision=0.9, + true_obs=true_obs, + is_vectorized=True, ) +post_log_dens = model.to_log_dens() -@torch.compile -def like_log_dens(pt, model): - ll_term = model.likelihood_precision * \ - torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1)**2 - return -0.5 * ll_term.squeeze() - - -# @torch.compile -# def post_log_dens(pt, model): -# ll_term = model.likelihood_precision * \ -# torch.linalg.norm(pt @ obs_op - model.true_obs, dim=-1)**2 -# prior_term = model.prior_precision * torch.linalg.norm(pt, dim=-1)**2 -# return -0.5 * (ll_term + prior_term).squeeze() -# post_log_dens_batch = torch.vmap(post_log_dens, in_dims=(0,None)) -# post_log_dens_grad_val = torch.func.grad_and_value(post_log_dens) -post_log_dens = model.to_log_dens() def _tmp_post_log_dens(pts, args): - out = post_log_dens(pts,args) + out = post_log_dens(pts, args) return out.sum(), out + + post_log_dens_grad_val_batch = torch.func.grad(_tmp_post_log_dens, has_aux=True) # %% mean_pr, cov_pr = torch.zeros(2), torch.eye(2) / model.prior_precision -mean_li, cov_li = model.true_obs, torch.eye( - len(model.true_obs) -) / model.likelihood_precision - -mean_post, cov_post = make_gaussian_post( - obs_op, mean_pr, cov_pr, mean_li, cov_li +mean_li, cov_li = ( + model.true_obs, + torch.eye(len(model.true_obs)) / model.likelihood_precision, ) + +mean_post, cov_post = make_gaussian_post(obs_op, mean_pr, cov_pr, mean_li, cov_li) vals, vecs = torch.linalg.eigh(cov_post) cov_post_sqrt = vecs @ torch.diag(torch.sqrt(vals)) @ vecs.T samps = torch.randn(10000, 2) @ cov_post_sqrt + mean_post @@ -102,126 +91,145 @@ def _tmp_post_log_dens(pts, args): # %% n_steps, n_particles = 1000, 500 lr = 1e-2 -bounds = (-100., 100.) +bounds = (-100.0, 100.0) rng = torch.Generator(device=torch.get_default_device()) rng.manual_seed(0) -init_particles = torch.randn((n_particles, 2), generator=rng) / \ - model.prior_precision + model.prior_mean +init_particles = ( + torch.randn((n_particles, 2), generator=rng) / model.prior_precision + + model.prior_mean +) # %% eks = EKS(dim=2, n_particles=n_particles, rng_or_seed=rng) trajectories_eks = nak_torch.nak( - model, eks, n_steps=n_steps, lr=lr, - rng_or_seed=rng, target_args=None, bounds=bounds, - init_particles=init_particles + model, + eks, + n_steps=n_steps, + lr=lr, + rng_or_seed=rng, + target_args=None, + bounds=bounds, + init_particles=init_particles, ) # %% -# init_particles = torch.randn((n_particles, 2)) + torch.tensor([3, -3]) -# delta_ts = torch.ones(1000)/1000 -# n_particles_kfr = 100 -# init_kfr = init_particles[:n_particles_kfr] #torch.randn((n_particles_kfr,2)) + torch.tensor([3,-5]) -# def imq(pt1,pt2,h): -# return 1/torch.sqrt(1 + (torch.linalg.norm(pt1-pt2) / h)**2) -# trajectories_kfr = kfrflow( -# like_log_dens, -# n_particles_kfr, -# 10000, 2, -# init_particles=init_kfr, -# kernel_length_scale = 1e-2, -# kernel_diag_infl=1e-5, -# # bounds=(-10,10), -# # kernel_elem=imq, -# keep_all=False -# ) - -# %% -# kernel_vec = torch.compile(torch.vmap(kernel_elem, in_dims=(None,0,None))) -# jac_kernel_vec = torch.vmap(torch.func.grad(kernel_elem), in_dims = (None, 0, None)) -# kernel_mat = sqexp_kernel_matrix -# n_steps_kfr = 100 -# delta_t = 1 / n_steps_kfr -# particles = init_particles.clone() -# kernel_length_scale = 1e-2 -# grad_ks = torch.empty((n_particles, n_particles, 2)) -# M_t = torch.empty((n_particles, n_particles)) -# for n in tqdm(range(n_steps_kfr)): -# log_likely_evals = like_log_dens(particles, model) -# M_t.zero_() -# for i in range(n_particles): -# grad_K = jac_kernel_vec(particles[i], particles, kernel_length_scale) -# grad_ks[i].copy_(grad_K) -# M_t.add_(grad_K @ grad_K.T) -# M_t = M_t.div_(n_particles) -# M_t[torch.arange(n_particles), torch.arange(n_particles)] += 1e-4 -# wts_shift = log_likely_evals.mean() -# wts = log_likely_evals.sub_(wts_shift).div_(n_particles) -# K_mat = kernel_mat(particles, kernel_length_scale) -# kernelized_wts = K_mat @ wts -# particles += torch.einsum("jid,i->jd", grad_ks, torch.linalg.solve(M_t, kernelized_wts)).mul_(delta_t) - - -# %% -grad_aldi = GradALDI(dim=2, n_particles=n_particles, rng = rng) -grad_aldi_target = BatchGradLogDensityEvaluator(post_log_dens, is_grad=False, is_batched=True) -trajectories_galdi = nak_torch.nak(grad_aldi_target, grad_aldi, - n_steps=n_steps, lr=lr, - init_particles=init_particles, keep_all=False, - rng_or_seed=rng, target_args=None, bounds=bounds +grad_aldi = GradALDI(dim=2, n_particles=n_particles, rng=rng) +grad_aldi_target = BatchGradLogDensityEvaluator( + post_log_dens, is_grad=False, is_batched=True +) +trajectories_galdi = nak_torch.nak( + grad_aldi_target, + grad_aldi, + n_steps=n_steps, + lr=lr, + init_particles=init_particles, + keep_all=False, + rng_or_seed=rng, + target_args=None, + bounds=bounds, ) # %% gf_aldi = GradFreeALDI(dim=2, n_particles=n_particles) -# trajectories_gfaldi = gradfree_aldi( -# model, n_particles, n_steps, dim=2, -# lr=lr, init_particles=init_particles, -# keep_all=True -# ) -trajectories_gfaldi = nak_torch.nak(model, gf_aldi, - n_steps=n_steps, lr=1e-2, - init_particles=init_particles, keep_all=True, - rng_or_seed=rng, target_args=None, bounds=bounds +trajectories_gfaldi = nak_torch.nak( + model, + gf_aldi, + n_steps=n_steps, + lr=1e-2, + init_particles=init_particles, + keep_all=True, + rng_or_seed=rng, + target_args=None, + bounds=bounds, ) # %% cbs_target = BatchLogDensityEvaluator(post_log_dens, is_batched=True) cbs = CBS(dim=2, n_particles=n_particles, default_inverse_temp=0.5, rng=rng) trajectories_cbs = nak_torch.nak( - cbs_target, cbs, 5000, lr=lr, - rng_or_seed=rng, init_particles=init_particles, - target_args=None, bounds=bounds + cbs_target, + cbs, + 5000, + lr=lr, + rng_or_seed=rng, + init_particles=init_particles, + target_args=None, + bounds=bounds, ) # %% -kernel_lengthscale = 0.15 +target_svgd = BatchGradLogDensityEvaluator( + post_log_dens, is_grad=False, is_batched=True +) +svgd = SVGD( + dim=2, + n_particles=n_particles, + kernel_lengthscale_quantile=0.5, # Median heuristic +) +trajectories_pts_svgd = nak_torch.nak( + target_svgd, + svgd, + n_steps=n_steps, + lr=lr, + init_particles=init_particles, + bounds=bounds, + target_args=None, +) + +# %% +kernel_lengthscale = 0.1 gradient_decay = 0.95 n_particles_msip = 25 n_steps_msip = 1000 -lr_msip = 5e-3 -kernel_diag_infl = 1e-5 +lr_msip = 0.1 +kernel_diag_infl = 1e-6 msip = MSIP( dim=2, n_particles=n_particles_msip, kernel_diag_infl=kernel_diag_infl, kernel_lengthscale=kernel_lengthscale, - # kernel_lengthscale_quantile=0.25 + # kernel_lengthscale_quantile=0.25 # If you want adaptive bandwidth. ) -msip_fredholm_target = MSIPFredholm( - gradient_decay, - post_log_dens_grad_val_batch +# %% +msip_fredholm_target = MSIPFredholm(gradient_decay, post_log_dens_grad_val_batch) + +trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( + msip_fredholm_target, + msip, + n_steps_msip, + lr_msip, + rng_or_seed=rng, + init_particles=init_particles[: msip.n_particles], + target_args=None, + keep_all=True, + bounds=bounds, ) -msip.kernel_lengthscale_quantile = 0.01 +# %% +msipgs = MSIPGS( + dim=2, + n_particles=n_particles_msip, + kernel_diag_infl=kernel_diag_infl, + kernel_lengthscale=kernel_lengthscale, + # kernel_lengthscale_quantile=0.25 # If you want adaptive bandwidth. +) -trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( - msip_fredholm_target, msip, n_steps_msip, lr_msip, - rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], - target_args=model, keep_all=True, bounds=bounds +trajectories_pts_msipgs_fr, trajectories_wts_msipgs_fr = nak_torch.nak( + msip_fredholm_target, + msipgs, + n_steps=500, + lr=1e-1, + rng_or_seed=rng, + init_particles=init_particles[: msipgs.n_particles], + target_args=None, + keep_all=True, + bounds=bounds, ) + # %% def mc_quad_rule(batch_size: int, N_quad: int = 5, dim: int = 2): pts = torch.randn((batch_size, N_quad, dim), generator=rng) @@ -230,57 +238,66 @@ def mc_quad_rule(batch_size: int, N_quad: int = 5, dim: int = 2): def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): - pts, wts = spherical_MC_radial_Laguerre( - batch_size, N_spherical, 2, N_radial - ) + pts, wts = spherical_MC_radial_Laguerre(batch_size, N_spherical, 2, N_radial) return pts, wts # %% msip_quadgrad_target = MSIPQuadGradientInformed( - post_log_dens_grad_val_batch, mc_quad_rule, - gradient_decay + post_log_dens_grad_val_batch, + # partial(spherical_quad, N_spherical=10, N_radial=4), + mc_quad_rule, + 1.0 ) -msip.kernel_lengthscale_quantile = 0.5 -msip.kernel_diag_infl = 1e-1 - trajectories_pts_msip_qg, trajectories_wts_msip_qg = nak_torch.nak( - msip_quadgrad_target, msip, n_steps_msip, 1e-1, - rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, - keep_all=False, bounds=bounds + msip_quadgrad_target, + msip, + 2000, + lr=5e-2, + rng_or_seed=rng, + init_particles=init_particles[: msip.n_particles], + target_args=None, + keep_all=False, + bounds=bounds, ) # %% msip_quadgf_target = MSIPQuadGradientFree( - post_log_dens, partial(spherical_quad, N_spherical=5, N_radial=4) + post_log_dens, partial(spherical_quad, N_spherical=10, N_radial=3) ) -msip.kernel_lengthscale_quantile = 0.05 trajectories_pts_msip_qgf, trajectories_wts_msip_qgf = nak_torch.nak( - msip_quadgf_target, msip, n_steps=n_steps_msip, lr=1e-1, - rng_or_seed=rng, init_particles=init_particles[:msip.n_particles], target_args=model, - keep_all=False, bounds=bounds + msip_quadgf_target, + msip, + n_steps=n_steps_msip, + lr=5e-2, + rng_or_seed=rng, + init_particles=init_particles[: msip.n_particles], + target_args=None, + keep_all=False, + bounds=bounds, ) # %% pts_eks = trajectories_eks[-1] -# pts_kfr = particles pts_galdi = trajectories_galdi[-1] pts_gfaldi = trajectories_gfaldi[-1] pts_cbs = trajectories_cbs[-1] idx_msip = -1 -alpha_msip = 2/math.sqrt(n_particles_msip) -pts_msip = trajectories_pts_msip_fr[idx_msip] -wts_msip = trajectories_wts_msip_fr[idx_msip] -# wts_msip /= wts_msip.sum() +alpha_msip = 2 / math.sqrt(n_particles_msip) +pts_msip_fr = trajectories_pts_msip_fr[idx_msip] +wts_msip_fr = trajectories_wts_msip_fr[idx_msip] +idx_msip_gs = -1 +pts_msipgs_fr = trajectories_pts_msipgs_fr[idx_msip_gs] +wts_msipgs_fr = trajectories_wts_msipgs_fr[idx_msip_gs] pts_msip_qg = trajectories_pts_msip_qg[-1] wts_msip_qg = trajectories_wts_msip_qg[-1] -# wts_msip_qg = wts_msip_qg/wts_msip_qg.sum() pts_msip_qgf = trajectories_pts_msip_qgf[-1] wts_msip_qgf = trajectories_wts_msip_qgf[-1] -# wts_msip_qgf = wts_msip_qgf/wts_msip_qgf.sum() +pts_svgd = trajectories_pts_svgd[-1] + Ngrid = 100 xgrid = torch.linspace(-1, 1, Ngrid) @@ -292,16 +309,15 @@ def spherical_quad(batch_size: int, N_spherical: int = 5, N_radial: int = 3): fig, ax = plt.subplots() ax.contour(X, Y, post_log_dens(grid_pts, model).reshape(Ngrid, Ngrid), levels=10) -handles = [] # ax.scatter(samps[:, 0], samps[:, 1], alpha=0.025, label="Truth") # ax.scatter(pts_galdi[:, 0], pts_galdi[:, 1], alpha=0.2, label="Grad-ALDI") -# ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], -# alpha=0.2, label="GradFree-ALDI") -# ax.scatter(pts_kfr[:,0], pts_kfr[:,1], label="KFR") +# ax.scatter(pts_gfaldi[:, 0], pts_gfaldi[:, 1], alpha=0.2, label="GradFree-ALDI") # ax.scatter(pts_eks[:, 0], pts_eks[:, 1], alpha=0.1, label="EKS") # ax.scatter(pts_cbs[:, 0], pts_cbs[:, 1], alpha=0.1, label="CBS") -ax.scatter(pts_msip[:, 0], pts_msip[:, 1], alpha=alpha_msip, label="MSIP") -# ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad") +# ax.scatter(pts_svgd[:, 0], pts_svgd[:, 1], alpha=0.1, label="SVGD") +# ax.scatter(pts_msip_fr[:, 0], pts_msip_fr[:, 1], alpha=alpha_msip, label="MSIP") +# ax.scatter(pts_msipgs_fr[:, 0], pts_msipgs_fr[:, 1], alpha=alpha_msip, label="MSIP-GS") +ax.scatter(pts_msip_qg[:, 0], pts_msip_qg[:, 1], alpha=alpha_msip, label="MSIP-QuadGrad") # ax.scatter(pts_msip_qgf[:, 0], pts_msip_qgf[:, 1], # s = 50*wts_msip_qgf.abs()/wts_msip_qgf.max(), alpha=alpha_msip, label="MSIP-QuadGradFree") # plt.colorbar(s) From edc62fd20e8625121d59f5606dd658d1b432fb99 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 25 Apr 2026 15:43:03 -0400 Subject: [PATCH 25/60] Work on AB example --- examples/aristoff_bangerth.py | 178 ++++++++++++-------- examples/functions/aristoff_bangerth.py | 18 +- src/nak_torch/algorithms/msip/estimators.py | 2 +- src/nak_torch/tools/util.py | 2 +- 4 files changed, 121 insertions(+), 79 deletions(-) diff --git a/examples/aristoff_bangerth.py b/examples/aristoff_bangerth.py index ef8f8ab..3459488 100644 --- a/examples/aristoff_bangerth.py +++ b/examples/aristoff_bangerth.py @@ -2,13 +2,16 @@ import math import torch from functions import aristoff_bangerth as ab, build_aristoff_bangerth -from nak_torch.algorithms import msip, svgd +import nak_torch +from nak_torch.algorithms import MSIP, SVGD +from nak_torch.algorithms.msip import MSIPFredholm from matplotlib import ticker import gc import matplotlib.pyplot as plt from nak_torch.tools.kernel import sqexp_kernel_matrix from tqdm import tqdm import pandas as pd +from nak_torch.tools.types import BatchGradLogDensityEvaluator import pyro_tools from pyro.infer import mcmc @@ -18,25 +21,73 @@ torch.set_default_device("cpu") torch.set_default_dtype(torch.float64) +# %% +def plot_samples(pts, max_side_len = 6): + n_particles = pts.shape[0] + side_len = min(max_side_len, int(math.floor(math.sqrt(n_particles)))) + pts = pts[:side_len**2] + fig = plt.figure(figsize=(9, 6), layout='constrained') + gs = fig.add_gridspec(side_len, side_len + 2) + vabs = max(pts.min().abs(), pts.max().abs()) + plt_kwargs = {'vmin': -vabs, 'vmax': vabs, 'extent': (0, 8, 0, 8)} + + for i in range(side_len): + for j in range(side_len): + ax = fig.add_subplot(gs[i, j]) + # ax.set_axis_off() + ax.set_aspect('equal') + t = ax.matshow(pts[i*side_len + j].reshape(8, 8), **plt_kwargs) + # ax.vlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75) + # ax.hlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75) + ax.minorticks_on() + ax.set_xticks([]) + ax.set_yticks([]) + ax.xaxis.set_minor_locator(ticker.MultipleLocator()) + ax.yaxis.set_minor_locator(ticker.MultipleLocator()) + ax.grid(which="both", linewidth=1.5, color="w") + ax.tick_params(which="minor", length=0) + ax_cb = fig.add_subplot(gs[:-2, -2:]) + ax_cb.set_title(r"Scale of $\log\theta$", y=0.6) + cax_cb = ax_cb.inset_axes((0.1, 0.45, 0.8, 0.1)) + ax_cb.axis('off') + fig.colorbar(t, cax=cax_cb, orientation='horizontal') # type: ignore + ax_true = fig.add_subplot(gs[-2:, -2:]) + ax_true.set_aspect('equal') + ax_true.matshow(ab.theta_true.log().reshape(8, 8), **plt_kwargs) + ax_true.minorticks_on() + ax_true.set_xticks([]) + ax_true.set_yticks([]) + ax_true.xaxis.set_minor_locator(ticker.MultipleLocator()) + ax_true.yaxis.set_minor_locator(ticker.MultipleLocator()) + ax_true.set_title(r"True $\theta$") + ax_true.grid(which="both", linewidth=1.5, color="w") + ax_true.tick_params(which="minor", length=0) + return fig + + # %% use_compiled = True model = build_aristoff_bangerth(use_compiled=use_compiled, dtype=torch.float64) log_p = model.to_log_dens(use_compiled=use_compiled) -log_th = torch.randn(500, 64, dtype=torch.float64) -test_out = log_p(log_th) +log_th = torch.randn(25, 64, dtype=torch.float64) +test_out = log_p(log_th, None) # %% -grad_log_p = torch.func.grad(lambda t: log_p(t).sum()) -test_eval = grad_log_p(log_th) +def _tmp_log_p(log_theta, arg: None): + ret = log_p(log_theta, arg) + return ret.sum(), ret + +grad_log_p = torch.func.grad(lambda t,a: log_p(t, a).sum()) +grad_val_log_p = torch.func.grad(_tmp_log_p, has_aux=True) +test_grad = grad_log_p(log_th, None) +test_grad_2, test_out_2 = grad_val_log_p(log_th, None) # %% -del log_th -del test_out -# del test_eval +del log_th, test_out, test_grad, test_grad_2, test_out_2 gc.collect() # %% -n_particles, n_steps, dim = 500, 25, 64 +n_particles, n_steps, dim = 25, 25, 64 kernel_bandwidth = 0.75 torch.manual_seed(1) @@ -45,84 +96,73 @@ dtype=torch.float64, ) # Sample from prior -msip_args = { +default_kwargs = { + "dim": dim, + "bounds": (-8, 8), + "n_steps": n_steps, "n_particles": n_particles, - "n_steps": n_steps, # "epochs" (passes over all particles) - "dim": 64, - "bounds": (-8, 8), # [a,b]^d - "gradient_informed": True, + "keep_all": False, "lr": 1e-1, - "noise": 0.05, # currently unused + "kernel_lengthscale": 0.1, "init_particles": init_particles, - "kernel_bandwidth": kernel_bandwidth, - "bandwidth_factor": 0.25, - "seed": 0, - "kernel_diag_infl": 1e-10, - "keep_all": False, - "device": None + "gradient_decay": 0.95, + "kernel_diag_infl": 1e-6, } # %% -trajectories_msip = msip( - log_p, - **msip_args -) +msip_kwargs = default_kwargs.copy() +msip_kwargs["lr"] = 1e-2 +msip_kwargs["kernel_lengthscale_quantile"] = 0.05 +msip = MSIP(**msip_kwargs) +target_msip_fr = MSIPFredholm(log_dens_grad_val=grad_val_log_p, **msip_kwargs) + +# %% +trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak(target_msip_fr, msip, **msip_kwargs) # %% -n_steps_hmc = 1000 +n_steps_hmc = 100 pyro_model = pyro_tools.PyroModel(model, dim) hmc_kernel = mcmc.NUTS(pyro_model) -mcmc_setup = mcmc.MCMC(hmc_kernel, num_samples=n_steps_hmc, warmup_steps=100) +mcmc_setup = mcmc.MCMC(hmc_kernel, num_samples=n_steps_hmc, warmup_steps=10) mcmc_setup.run(model.true_obs) hmc_samples = mcmc_setup.get_samples()["theta"] # %% -trajectories_svgd = svgd( - log_p, - is_log_density_batched=True, - **msip_args +target_svgd = BatchGradLogDensityEvaluator( + log_p, is_grad=False, is_batched=True ) +svgd = SVGD( + kernel_lengthscale_quantile=0.5, # Median heuristic + **msip_kwargs +) +svgd_kwargs = msip_kwargs.copy() +svgd_kwargs["lr"] = 1e-1 +svgd_kwargs["n_steps"] = 100 + +trajectories_pts_svgd = nak_torch.nak( + target_svgd, + svgd, + **svgd_kwargs +) + +# %% +pts_msip = trajectories_pts_msip_fr[-1] - init_particles +fig = plot_samples(pts_msip) +fig.suptitle("MSIP Samples") +plt.show() + +# %% +pts_hmc = hmc_samples[10::3] +fig = plot_samples(pts_hmc) +fig.suptitle("HMC Samples") +plt.show() + # %% -side_len = min(6, int(math.floor(math.sqrt(n_particles)))) -pts = trajectories_msip[-1][:side_len**2].detach().cpu()# - init_particles[:side_len**2] -fig = plt.figure(figsize=(9, 6), layout='constrained') -gs = fig.add_gridspec(side_len, side_len + 2) -vabs = max(pts.min().abs(), pts.max().abs()) -plt_kwargs = {'vmin': -vabs, 'vmax': vabs, 'extent': (0, 8, 0, 8)} - -for i in range(side_len): - for j in range(side_len): - ax = fig.add_subplot(gs[i, j]) - # ax.set_axis_off() - ax.set_aspect('equal') - t = ax.matshow(pts[i*side_len + j].reshape(8, 8), **plt_kwargs) - # ax.vlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75) - # ax.hlines(jnp.arange(1,8), -0.1, 8.1, color='w', lw=0.75) - ax.minorticks_on() - ax.set_xticks([]) - ax.set_yticks([]) - ax.xaxis.set_minor_locator(ticker.MultipleLocator()) - ax.yaxis.set_minor_locator(ticker.MultipleLocator()) - ax.grid(which="both", linewidth=1.5, color="w") - ax.tick_params(which="minor", length=0) -ax_cb = fig.add_subplot(gs[:-2, -2:]) -ax_cb.set_title(r"Scale of $\log\theta$", y=0.6) -cax_cb = ax_cb.inset_axes((0.1, 0.45, 0.8, 0.1)) -ax_cb.axis('off') -fig.colorbar(t, cax=cax_cb, orientation='horizontal') -ax_true = fig.add_subplot(gs[-2:, -2:]) -ax_true.set_aspect('equal') -ax_true.matshow(ab.theta_true.log().reshape(8, 8), **plt_kwargs) -ax_true.minorticks_on() -ax_true.set_xticks([]) -ax_true.set_yticks([]) -ax_true.xaxis.set_minor_locator(ticker.MultipleLocator()) -ax_true.yaxis.set_minor_locator(ticker.MultipleLocator()) -ax_true.set_title(r"True $\theta$") -ax_true.grid(which="both", linewidth=1.5, color="w") -ax_true.tick_params(which="minor", length=0) +pts_svgd = trajectories_pts_svgd[-1] +fig = plot_samples(pts_svgd) +fig.suptitle("SVGD Samples") plt.show() # %% diff --git a/examples/functions/aristoff_bangerth.py b/examples/functions/aristoff_bangerth.py index 48f5507..afb4ecd 100644 --- a/examples/functions/aristoff_bangerth.py +++ b/examples/functions/aristoff_bangerth.py @@ -16,7 +16,9 @@ import torch from torch import Tensor from typing import Optional +from jaxtyping import Float from nak_torch import GaussianModel +from nak_torch.tools.types import BatchForwardModel, BatchPtType if __name__ == '__main__': torch.set_default_dtype(torch.float64) @@ -269,13 +271,13 @@ def build_forward_solver_args(N, N_obs, device=None, dtype: Optional[torch.dtype ###################### forward solver function ############################ ########################################################################### def forward_solver( - theta: Tensor, # (64,) + theta: BatchPtType, # (N_batch, 64) N: int, - M: Tensor, # (N_obs, N) - boundaries: Tensor, # (4*N,), - A_loc: Tensor, # (4, 4), - b: Tensor, # ((N+1)**2, ) -) -> Tensor: # (N+1, ) + M: Float[Tensor, "obs grid"], # (N_obs, N) + boundaries: Float[Tensor, " 4*grid"], # (4*N,), + A_loc: Float[Tensor, "p p"], # (4, 4) + b: Float[Tensor, " (grid+1)**2"], # ((N+1)**2, ) +) -> Float[Tensor, "batch obs"]: # (N+1, ) """ Solve Poisson PDE for Aristoff-Bangerth example. @@ -359,8 +361,8 @@ def log_prior(log_theta: Tensor, sig_pr_sq: float): norm_sq = log_theta.square().sum(-1) return -norm_sq / (2 * sig_pr_sq) -def build_forward_solver(N: int, H_obs: Tensor, *solve_args): - def prefill_forward_solver(log_theta: Tensor): +def build_forward_solver(N: int, H_obs: Tensor, *solve_args) -> BatchForwardModel: + def prefill_forward_solver(log_theta: Tensor, _ = None): return forward_solver(log_theta.exp(), N, H_obs, *solve_args) @ H_obs.T return prefill_forward_solver diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index 49ab25b..2c3c446 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -33,7 +33,7 @@ class MSIPFredholm(MSIPEstimator): log_dens_grad_val: BatchLogDensityGradVal def __init__( - self, gradient_decay: float, log_dens_grad_val: BatchLogDensityGradVal + self, gradient_decay: float, log_dens_grad_val: BatchLogDensityGradVal, **kwargs ): self.gradient_decay = gradient_decay self.log_dens_grad_val = log_dens_grad_val diff --git a/src/nak_torch/tools/util.py b/src/nak_torch/tools/util.py index 00cd3d3..c29ec71 100644 --- a/src/nak_torch/tools/util.py +++ b/src/nak_torch/tools/util.py @@ -82,7 +82,7 @@ def batched_grad_log_density_factory( def quantile_distance(pts: BatchPtType, quantile: float = 0.5) -> Float: """If quantile <= 0, get minimum. If quantile >= 1, get maximum""" assert pts.ndim == 2 - diffs = torch.sum(torch.square(pts.unsqueeze(0) - pts.unsqueeze(1)), -1).sqrt_() + diffs = torch.sum(torch.square(pts.unsqueeze(0) - pts.unsqueeze(1)), dim=-1).sqrt_() diffs_idxs = torch.triu_indices( pts.shape[0], pts.shape[0], offset=1, device=pts.device ) From bc5a158439c13897e6f923affbfe8cd75dcd9193 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 25 Apr 2026 18:44:44 -0400 Subject: [PATCH 26/60] Add DeepEnsembles --- src/nak_torch/algorithms/__init__.py | 5 +- src/nak_torch/algorithms/deepensembles.py | 78 +++-------------------- 2 files changed, 12 insertions(+), 71 deletions(-) diff --git a/src/nak_torch/algorithms/__init__.py b/src/nak_torch/algorithms/__init__.py index ca00ac3..8a42a04 100644 --- a/src/nak_torch/algorithms/__init__.py +++ b/src/nak_torch/algorithms/__init__.py @@ -10,7 +10,7 @@ from .eks import EKS from .msip import MSIP, MSIPGS from .svgd import SVGD -from .deepensembles import deepensembles +from .deepensembles import DeepEnsembles from .grad_aldi import GradALDI from .gradfree_aldi import GradFreeALDI from .cbs import CBS @@ -23,9 +23,8 @@ "MSIP", "MSIPGS", "SVGD", - "deepensembles", + "DeepEnsembles", "GradALDI", - # "gradfree_aldi", "GradFreeALDI", "EKS", "CBS", diff --git a/src/nak_torch/algorithms/deepensembles.py b/src/nak_torch/algorithms/deepensembles.py index 0fa87f7..003e7f5 100644 --- a/src/nak_torch/algorithms/deepensembles.py +++ b/src/nak_torch/algorithms/deepensembles.py @@ -6,75 +6,17 @@ # Ayoub Belhadji # 05/12/2025 -import warnings -import numpy as np -import torch -from typing import Optional, Callable -from tqdm import tqdm -from nak_torch.tools.kernel import sqexp_kernel_elem -from nak_torch.tools.types import KernelFunction, BatchGradLogDensity, BatchPtType -from nak_torch.tools.util import batched_grad_log_density_factory, initialize_particles +from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm +from nak_torch.tools.types import BatchGradLogDensityEvaluator +__all__ = ["DeepEnsembles"] -def create_deepensembles_step( - grad_log_p: BatchGradLogDensity, -) -> Callable[[BatchPtType], BatchPtType]: - def deepensembles_step_dir(points: BatchPtType): - log_p_grad_ev = grad_log_p(points) - return log_p_grad_ev +class DeepEnsembles(UnweightedAdaptiveNAKAlgorithm[BatchGradLogDensityEvaluator, None]): + def initialize(self, init_particles, target, target_args): + return None, None - return deepensembles_step_dir - - -def deepensembles( - log_density, - n_particles: int, - n_steps: int, - dim: int, - lr: float, - seed: Optional[int] = None, - device: Optional[torch.device] = None, - init_particles: Optional[torch.Tensor | np.ndarray] = None, - kernel_length_scale: float = 1.0, - kernel_elem: KernelFunction = sqexp_kernel_elem, - bounds: Optional[tuple[float, float]] = None, - keep_all: bool = True, - is_log_density_batched: bool = False, - grad_log_density: Optional[BatchGradLogDensity] = None, - verbose: bool = False, - **unused_kwargs, -): - if verbose and len(unused_kwargs) > 0: - warnings.warn("Unused kwargs:\n{}".format(unused_kwargs)) - - if seed is not None: - torch.manual_seed(seed) - - particles = initialize_particles(n_particles, dim, init_particles, device, bounds) - - if keep_all: - trajectories = torch.empty( - (n_steps, *particles.shape), device=device, dtype=particles.dtype - ) - trajectories[0].copy_(particles) - else: - trajectories = torch.empty(()) - - grad_log_p = batched_grad_log_density_factory( - log_density, is_log_density_batched, grad_log_density - ) - step_fcn = create_deepensembles_step(grad_log_p) - - trajectories[0].copy_(particles) - - for idx in tqdm(range(n_steps - 1), disable=not verbose): - particles_diff = step_fcn(particles) - with torch.no_grad(): - particles = particles + lr * particles_diff - if bounds is not None: - particles.clamp_(bounds[0], bounds[1]) - if keep_all: - trajectories[idx + 1].copy_(particles) - - return trajectories.detach() if keep_all else particles.unsqueeze_(0) + def step(self, lr, particles, target, algorithm_args, target_args): + grad_log_dens_eval = target(particles, target_args) + new_particles = particles.add(grad_log_dens_eval.mul_(lr)) + return new_particles, None, algorithm_args From fe304ee445397888cba13b3bd43599bc125f73fa Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 25 Apr 2026 20:57:19 -0400 Subject: [PATCH 27/60] Add scipy as dev dep --- pyproject.toml | 1 + uv.lock | 2 ++ 2 files changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 22bcd09..1033393 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dev = [ "pytest>=9.0.2", "pytest-cov>=7.1.0", "ruff>=0.15.8", + "scipy>=1.17.1", ] [tool.pytest.ini_options] diff --git a/uv.lock b/uv.lock index 08c4634..28a48ca 100644 --- a/uv.lock +++ b/uv.lock @@ -765,6 +765,7 @@ dev = [ { name = "pytest" }, { name = "pytest-cov" }, { name = "ruff" }, + { name = "scipy" }, ] [package.metadata] @@ -788,6 +789,7 @@ dev = [ { name = "pytest", specifier = ">=9.0.2" }, { name = "pytest-cov", specifier = ">=7.1.0" }, { name = "ruff", specifier = ">=0.15.8" }, + { name = "scipy", specifier = ">=1.17.1" }, ] [[package]] From b3b1cb37774b6edb856656fd295a952965ebb858 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 25 Apr 2026 20:57:40 -0400 Subject: [PATCH 28/60] Work on batching --- src/nak_torch/algorithms/loop.py | 6 ++++- src/nak_torch/tools/types.py | 38 +++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 54049f6..2794b11 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Iterator, Optional import warnings from tqdm import tqdm @@ -28,6 +28,7 @@ def nak( bounds: Optional[tuple[float, float]] = None, keep_all: bool = True, target_args: Any = None, + get_target_args: Optional[Iterator] = None, **kwargs, ) -> Tensor | tuple[Tensor, Tensor]: r""" @@ -77,6 +78,9 @@ def nak( if algorithm.is_weighted(): traj_wts[idx + 1].copy_(particle_wts) + if get_target_args is not None: + target_args = next(get_target_args) + particles, particle_wts, algorithm_args = algorithm.step( lr, particles, target, algorithm_args, target_args ) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index e2f1e65..b0dd709 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -5,6 +5,7 @@ import torch import numpy as np from torch import Tensor +from torch.utils import data as torch_data from jaxtyping import Float, Bool from typing import Self @@ -192,7 +193,7 @@ class LogisticRegressionModel(AbstractModel): def __init__( self, - data_or_fname: Float[Tensor, "dim-1 labels"] | str, + data_or_fname: Float[Tensor, "labels dim-1"] | str, labels: Optional[Float[Tensor, " labels"]], prior_mean: float | Float[Tensor, " dim"] | None = None, dtype=None, @@ -225,18 +226,18 @@ def as_tensor(t): if labels is None or labels.shape[0] != N_pts: raise ValueError("Unexpected type or size of argument `labels`.") constant = as_tensor(torch.ones(N_pts)) - data = torch.column_stack((constant, data)).T + data = torch.column_stack((constant, data)) if train_proportion >= 1.0: self.train_data, self.test_data = data, None self.train_labels, self.test_labels = labels, None else: ridx = torch.randperm(N_pts) num_train = int(np.floor(N_pts * train_proportion)) - self.train_data = data[:, ridx[:num_train]] + self.train_data = data[ridx[:num_train]] self.train_labels = labels[ridx[:num_train]] - self.test_data = data[:, ridx[num_train:]] + self.test_data = data[ridx[num_train:]] self.test_labels = labels[ridx[num_train:]] - self.dim = data.shape[0] + 1 + self.dim = data.shape[1] + 1 self.prior_mean = prior_mean self.sum_bernoulli = sum_bernoulli self.hyperprior = torch.distributions.Gamma( @@ -271,7 +272,7 @@ def log_dens( prior_term = prior_diff.square().sum(dim=-1).mul_(0.5 * precision).neg_() # log-normalization constant of prior w.r.t. alpha = precision prior_term += 0.5 * self.dim * log_precision - logits = coeffs @ data + logits = coeffs @ data.T likelihood = bernoulli_loglikelihood_logit_v(logits, labels) if not self.sum_bernoulli: likelihood /= labels.numel() @@ -279,3 +280,28 @@ def log_dens( return post if is_batch else post[0] return torch.compile(log_dens) if use_compiled else log_dens + + def get_data_loader( + self, + use_test_data: bool, + batch_size: int = 1, + shuffle: bool = False, + num_workers: int = 0, + *data_loader_args, + **data_loader_kwargs, + ): + data: torch_data.TensorDataset + if use_test_data: + if self.test_data is None or self.test_labels is None: + raise ValueError("Cannot use test data as None") + data = torch_data.TensorDataset(self.test_data, self.test_labels) + else: + data = torch_data.TensorDataset(self.train_data, self.train_labels) + return torch_data.DataLoader( + data, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + *data_loader_args, + **data_loader_kwargs, + ) From 6e965afb5d33dbcde5f62c87208a6e09d23669ae Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 25 Apr 2026 20:57:46 -0400 Subject: [PATCH 29/60] Add batching to covtype --- examples/logistic_regression/covertype.py | 63 ++++++++++++----------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/examples/logistic_regression/covertype.py b/examples/logistic_regression/covertype.py index 7e394b3..e870548 100644 --- a/examples/logistic_regression/covertype.py +++ b/examples/logistic_regression/covertype.py @@ -2,7 +2,8 @@ import os from urllib.request import urlretrieve import torch -from nak_torch.algorithms import msip, msip_gs, svgd +import nak_torch +from nak_torch.algorithms import MSIP, MSIPGS, SVGD import matplotlib.pyplot as plt from nak_torch import LogisticRegressionModel from nak_torch.tools import pyro_tools @@ -11,7 +12,9 @@ from nak_torch.algorithms.msip import ( MSIPFredholm, MSIPQuadGradientInformed, + MSIPQuadGradientFree, ) + from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre import scipy.io import numpy as np @@ -43,11 +46,12 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): # %% data_path = DATA_PATH regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01, train_proportion=0.8, sum_bernoulli=False) -log_dens = regression_model.to_log_dens(use_compiled=False) +log_dens = regression_model.to_log_dens(use_compiled=True) +train_data_loader = regression_model.get_data_loader(False, batch_size=64) # %% N_plot = 10000 -plt.scatter(regression_model.train_data[2,:N_plot], regression_model.train_data[3, :N_plot], c=regression_model.train_labels[:N_plot], alpha=0.2) +plt.scatter(regression_model.train_data[:N_plot,2], regression_model.train_data[:N_plot,3], c=regression_model.train_labels[:N_plot], alpha=0.2) plt.show() # %% @@ -65,7 +69,7 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): lr_msip = 0.05 kernel_diag_infl = 1e-5 n_steps = 1000 -grad_val_log_p = torch.vmap(torch.func.grad_and_value(log_dens)) +grad_val_log_p = torch.vmap(torch.func.grad_and_value(log_dens), in_dims=(0, None)) @torch.compile(dynamic=False) def mc_quad_rule(batch_size: int, N_quad: int = 500, dim: int = 56): @@ -78,30 +82,29 @@ def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, di pts, wts = spherical_MC_radial_Laguerre(batch_size, N_spherical, dim, N_radial) return pts, wts +# %% +msip = MSIP( + dim = regression_model.dim, + n_particles = n_particles, + kernel_diag_infl = 1e-6, + kernel_lengthscale=1e-1, +) -msip_f = MSIPFredholm(gradient_decay, grad_val_log_p) -msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, gradient_decay) +target_msip_f = MSIPFredholm(gradient_decay, grad_val_log_p) +target_msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, gradient_decay) # %% -trajectories_msip, traj_wts_msip = msip( - msip_f, - n_particles, - n_steps, - dim=state_dim, - lr=lr_msip, - init_particles=init_particles[:n_particles], - kernel_length_scale=kernel_length_scale, - is_log_density_batched=True, - kernel_diag_infl=kernel_diag_infl, - bounds=bounds, - keep_all=True, - compile_step=True, - verbose=True, +trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( + target_msip_f, + msip, + n_steps=n_steps, + lr=1e-2, + init_particles=init_particles, + get_target_args=iter(train_data_loader), + bounds=(-100, 100) ) -trajectories_msip[-1] - # %% -msip_end = trajectories_msip[-1] +msip_end = trajectories_pts_msip_fr[-1] dist_end = torch.sqrt(torch.sum(torch.square_(msip_end[None,:] - msip_end[:,None]), -1)) lower_tri_idx = torch.tril_indices(*dist_end.shape, -1) lower_tri_dist = dist_end[*lower_tri_idx] @@ -113,12 +116,12 @@ def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, di # @torch.compile def bce_logit_t(traj_t): - logits_t = traj_t[:,:-1] @ regression_model.test_data + logits_t = traj_t[:,:-1] @ regression_model.test_data.T return bce_logit_v(logits_t, regression_model.test_labels) bce_logit_traj = torch.vmap(bce_logit_t) bse_traj_list = [] -for j in tqdm(range(trajectories_msip.shape[0])): - bse_traj_list.append(bce_logit_t(trajectories_msip[j])) +for j in tqdm(range(trajectories_pts_msip_fr.shape[0])): + bse_traj_list.append(bce_logit_t(trajectories_pts_msip_fr[j])) bce_traj = torch.stack(bse_traj_list) # logits_t = trajectories_msip[:,:,:-1].reshape(-1, trajectories_msip.shape[-1] - 1) @ regression_model.data # bce_traj = bce_logit_v(logits_t, regression_model.labels).reshape(*trajectories_msip.shape[:2], -1) @@ -133,17 +136,17 @@ def bce_logit_t(traj_t): # %% def accuracy(coeffs): data, labels = regression_model.test_data, regression_model.test_labels - prob = torch.sigmoid(coeffs[:-1] @ data) + prob = torch.sigmoid(coeffs[:-1] @ data.T) pred_labels = prob > 0.5 print(pred_labels.sum()) N_true = torch.sum(pred_labels == labels) - return N_true / data.shape[1] + return N_true / data.shape[0] accuracy_v = torch.vmap(accuracy) -accuracy_v(trajectories_msip[-1]) +accuracy_v(trajectories_pts_msip_fr[-1]) # %% -trajectories_msip, traj_wts_msip = svgd( +trajectories_svgd, traj_wts_svgd = svgd( msip_f, n_particles, n_steps, From e18229b6b63f7fbb9cf657bfefb0a2a35c535ccb Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 26 Apr 2026 14:49:21 -0400 Subject: [PATCH 30/60] Add adagrad for svgd --- src/nak_torch/algorithms/loop.py | 1 + src/nak_torch/algorithms/svgd.py | 29 +++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index 2794b11..a3e1a8a 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -33,6 +33,7 @@ def nak( ) -> Tensor | tuple[Tensor, Tensor]: r""" TODO: Document + target_args: If `get_target_args` is not None, nak uses this for initializing the algorithm's parameters. """ verbose, n_particles = algorithm.verbose, algorithm.n_particles if verbose and len(kwargs) > 0: diff --git a/src/nak_torch/algorithms/svgd.py b/src/nak_torch/algorithms/svgd.py index bde255d..6d7cf32 100644 --- a/src/nak_torch/algorithms/svgd.py +++ b/src/nak_torch/algorithms/svgd.py @@ -53,6 +53,7 @@ def svgd_step( @dataclass class SVGDAlgorithmArgs: kernel_lengthscale: float + historical_grad: BatchPtType class SVGD( @@ -61,6 +62,8 @@ class SVGD( default_kernel_lengthscale: float kernel_lengthscale_quantile: Optional[float] kernel_grad_val: BatchKernelGradValFunction + gradient_decay_factor: float + fudge_factor: float def get_adaptive_lengthscale(self, particles: BatchPtType) -> float: q = self.kernel_lengthscale_quantile @@ -78,6 +81,8 @@ def __init__( kernel_lengthscale: Optional[float] = None, kernel_lengthscale_quantile: Optional[float] = None, kernel_elem: Optional[KernelFunction] = None, + gradient_decay_factor: float = 0.9, + fudge_factor: float = 1e-6, **kwargs, ): super().__init__(dim, n_particles, device, dtype, **kwargs) @@ -98,17 +103,33 @@ def __init__( ) self.kernel_lengthscale_quantile = kernel_lengthscale_quantile self.kernel_grad_val = create_svgd_kernel_grad_val(kernel_elem) + self.gradient_decay_factor = gradient_decay_factor + self.fudge_factor = fudge_factor def initialize(self, init_particles, target, target_args): kernel_lengthscale = self.get_adaptive_lengthscale(init_particles) - return None, SVGDAlgorithmArgs(kernel_lengthscale) + grad_log_dens_eval = target(init_particles, target_args) + particles_diff = svgd_step( + self.kernel_grad_val, init_particles, grad_log_dens_eval, kernel_lengthscale + ) + historical_grad = particles_diff.square() + return None, SVGDAlgorithmArgs(kernel_lengthscale, historical_grad) def step(self, lr, particles, target, algorithm_args, target_args): - (kernel_lengthscale,) = astuple(algorithm_args) + kernel_lengthscale: float + historical_grad: BatchPtType + kernel_lengthscale, historical_grad = astuple(algorithm_args) grad_log_dens_eval = target(particles, target_args) particles_diff = svgd_step( self.kernel_grad_val, particles, grad_log_dens_eval, kernel_lengthscale ) - new_particles = particles_diff.mul_(lr).add_(particles) + alpha = self.gradient_decay_factor + historical_grad = alpha * historical_grad + (1 - alpha) * (particles_diff**2) + adj_grad = particles_diff.divide_(self.fudge_factor + historical_grad.sqrt()) + new_particles = adj_grad.mul_(lr).add_(particles) new_kernel_lengthscale = self.get_adaptive_lengthscale(new_particles) - return new_particles, None, SVGDAlgorithmArgs(new_kernel_lengthscale) + return ( + new_particles, + None, + SVGDAlgorithmArgs(new_kernel_lengthscale, historical_grad), + ) From 42c70641fabad4a66ae9c09208ca209a527b57c0 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 26 Apr 2026 14:49:46 -0400 Subject: [PATCH 31/60] Adjust logistic regression interface --- src/nak_torch/tools/types.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index b0dd709..6f5535c 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -188,7 +188,7 @@ class LogisticRegressionModel(AbstractModel): test_data: Optional[Float | Float[Tensor, "dim labels"]] train_labels: Float | Float[Tensor, " labels"] test_labels: Optional[Float | Float[Tensor, " labels"]] - sum_bernoulli: bool + use_mean_reduction: bool hyperprior: torch.distributions.Gamma def __init__( @@ -201,7 +201,7 @@ def __init__( hyperprior_a=1.0, hyperprior_b=0.1, train_proportion=1.0, - sum_bernoulli=True, + reduction="mean", ): data: torch.Tensor dtype = torch.get_default_dtype() if dtype is None else dtype @@ -210,6 +210,15 @@ def __init__( def as_tensor(t): return torch.as_tensor(t, dtype=dtype, device=device) + match reduction: + case "mean": + self.use_mean_reduction = True + case "sum": + self.use_mean_reduction = False + case _: + raise ValueError( + f"Expected reduction to be sum or mean, got {reduction}" + ) self.prior_mean = prior_mean if prior_mean is None else as_tensor(prior_mean) if isinstance(data_or_fname, str): data = as_tensor(np.load(data_or_fname)) @@ -239,7 +248,6 @@ def as_tensor(t): self.test_labels = labels[ridx[num_train:]] self.dim = data.shape[1] + 1 self.prior_mean = prior_mean - self.sum_bernoulli = sum_bernoulli self.hyperprior = torch.distributions.Gamma( as_tensor(hyperprior_a), as_tensor(hyperprior_b) ) @@ -274,7 +282,7 @@ def log_dens( prior_term += 0.5 * self.dim * log_precision logits = coeffs @ data.T likelihood = bernoulli_loglikelihood_logit_v(logits, labels) - if not self.sum_bernoulli: + if self.use_mean_reduction: likelihood /= labels.numel() post = likelihood + prior_term + hyperprior_term return post if is_batch else post[0] From 411bb706d96aac61b4daccaebf8bbb94de34807a Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 26 Apr 2026 14:49:56 -0400 Subject: [PATCH 32/60] Add svgd to covertype --- examples/logistic_regression/covertype.py | 149 ++++++++++++++-------- 1 file changed, 94 insertions(+), 55 deletions(-) diff --git a/examples/logistic_regression/covertype.py b/examples/logistic_regression/covertype.py index e870548..924ac6b 100644 --- a/examples/logistic_regression/covertype.py +++ b/examples/logistic_regression/covertype.py @@ -8,6 +8,7 @@ from nak_torch import LogisticRegressionModel from nak_torch.tools import pyro_tools from pyro.infer import mcmc +from tqdm import tqdm from nak_torch.algorithms.msip import ( MSIPFredholm, @@ -19,6 +20,8 @@ import scipy.io import numpy as np +from nak_torch.tools.types import BatchGradLogDensityEvaluator + if torch.cuda.is_available(): torch.set_default_device("cuda") else: @@ -29,10 +32,11 @@ DATA_URL = "https://raw.githubusercontent.com/DartML/Stein-Variational-Gradient-Descent/refs/heads/master/data/covertype.mat" DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "covertype.npy") + def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): urlretrieve(data_url, data_path) data_mat = scipy.io.loadmat(data_path) - data_arr = data_mat['covtype'] + data_arr = data_mat["covtype"] # Flip first col to be (0,1) instead of (2,1) (where 2 is false) covariates = data_arr[:, 1:] labels = -1 * (data_arr[:, 0] - 2) @@ -40,84 +44,110 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): # Save np.save(data_path, data_arr) + if not os.path.isfile(DATA_PATH): download_file() # %% data_path = DATA_PATH -regression_model = LogisticRegressionModel(data_path, None, hyperprior_b=0.01, train_proportion=0.8, sum_bernoulli=False) +regression_model = LogisticRegressionModel( + data_path, None, hyperprior_b=0.01, train_proportion=0.8 +) log_dens = regression_model.to_log_dens(use_compiled=True) -train_data_loader = regression_model.get_data_loader(False, batch_size=64) # %% N_plot = 10000 -plt.scatter(regression_model.train_data[:N_plot,2], regression_model.train_data[:N_plot,3], c=regression_model.train_labels[:N_plot], alpha=0.2) +plt.scatter( + regression_model.train_data[:N_plot, 2], + regression_model.train_data[:N_plot, 3], + c=regression_model.train_labels[:N_plot], + alpha=0.2, +) plt.show() # %% -n_particles, state_dim = 20, regression_model.dim -alpha_init = regression_model.hyperprior.sample((n_particles,1)) +N_PARTICLES = 100 +STATE_DIM = regression_model.dim +alpha_init = regression_model.hyperprior.sample((N_PARTICLES, 1)) log_alpha_init = alpha_init.log() -coeff_init = torch.randn((n_particles, regression_model.dim - 1)) / alpha_init.sqrt() +coeff_init = torch.randn((N_PARTICLES, STATE_DIM - 1)) / alpha_init.sqrt() init_particles = torch.column_stack((coeff_init, log_alpha_init)) log_dens(init_particles) # test eval # %% -kernel_length_scale = 0.05 -bounds = (-100.0, 100.0) -gradient_decay = 0.9 -lr_msip = 0.05 -kernel_diag_infl = 1e-5 -n_steps = 1000 +BATCH_SIZE = 50 +train_data_loader = regression_model.get_data_loader(False, batch_size=BATCH_SIZE) + +# %% grad_val_log_p = torch.vmap(torch.func.grad_and_value(log_dens), in_dims=(0, None)) + @torch.compile(dynamic=False) def mc_quad_rule(batch_size: int, N_quad: int = 500, dim: int = 56): pts = torch.randn((batch_size, N_quad, dim), dtype=torch.get_default_dtype()) wts = torch.ones((batch_size, N_quad), dtype=torch.get_default_dtype()).div_(N_quad) return pts, wts + @torch.compile(dynamic=False) -def spherical_quad(batch_size: int, N_spherical: int = 10, N_radial: int = 3, dim: int = 56): +def spherical_quad( + batch_size: int, N_spherical: int = 10, N_radial: int = 3, dim: int = 56 +): pts, wts = spherical_MC_radial_Laguerre(batch_size, N_spherical, dim, N_radial) return pts, wts + # %% +KERNEL_LENGTHSCALE = 0.1 +GRADIENT_DECAY = 0.9 +KERNEL_DIAG_INFL = 1e-5 + msip = MSIP( - dim = regression_model.dim, - n_particles = n_particles, - kernel_diag_infl = 1e-6, - kernel_lengthscale=1e-1, + dim=STATE_DIM, + n_particles=N_PARTICLES, + kernel_diag_infl=KERNEL_DIAG_INFL, + kernel_lengthscale=KERNEL_LENGTHSCALE, + kernel_lengthscale_quantile=0.01, ) -target_msip_f = MSIPFredholm(gradient_decay, grad_val_log_p) -target_msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, gradient_decay) +target_msip_f = MSIPFredholm(GRADIENT_DECAY, grad_val_log_p) +target_msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, GRADIENT_DECAY) # %% -trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( - target_msip_f, - msip, - n_steps=n_steps, - lr=1e-2, - init_particles=init_particles, - get_target_args=iter(train_data_loader), - bounds=(-100, 100) -) +BOUNDS = (-100.0, 100.0) +N_STEPS = 6000 +LR_MSIP = 0.05 +# trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( +# target_msip_f, +# msip, +# n_steps=N_STEPS, +# lr=LR_MSIP, +# init_particles=init_particles, +# get_target_args=iter(train_data_loader), +# bounds=BOUNDS, +# ) + # %% msip_end = trajectories_pts_msip_fr[-1] -dist_end = torch.sqrt(torch.sum(torch.square_(msip_end[None,:] - msip_end[:,None]), -1)) +dist_end = torch.sqrt( + torch.sum(torch.square_(msip_end[None, :] - msip_end[:, None]), -1) +) lower_tri_idx = torch.tril_indices(*dist_end.shape, -1) lower_tri_dist = dist_end[*lower_tri_idx] plt.hist(lower_tri_dist) # %% -from tqdm import tqdm -bce_logit_v = torch.vmap(torch.nn.functional.binary_cross_entropy_with_logits, in_dims=(0,None)) +bce_logit_v = torch.vmap( + torch.nn.functional.binary_cross_entropy_with_logits, in_dims=(0, None) +) + # @torch.compile def bce_logit_t(traj_t): - logits_t = traj_t[:,:-1] @ regression_model.test_data.T + logits_t = traj_t[:, :-1] @ regression_model.test_data.T return bce_logit_v(logits_t, regression_model.test_labels) + + bce_logit_traj = torch.vmap(bce_logit_t) bse_traj_list = [] for j in tqdm(range(trajectories_pts_msip_fr.shape[0])): @@ -129,35 +159,44 @@ def bce_logit_t(traj_t): # %% fig, ax = plt.subplots() -for particle_idx in range(n_particles): - ax.loglog(bce_traj[:,particle_idx], alpha= 0.4) +for particle_idx in range(N_PARTICLES): + ax.loglog(bce_traj[:, particle_idx], alpha=0.4) plt.show() + # %% def accuracy(coeffs): data, labels = regression_model.test_data, regression_model.test_labels prob = torch.sigmoid(coeffs[:-1] @ data.T) pred_labels = prob > 0.5 - print(pred_labels.sum()) - N_true = torch.sum(pred_labels == labels) - return N_true / data.shape[0] + return torch.mean((pred_labels == labels).to(torch.float64)) accuracy_v = torch.vmap(accuracy) -accuracy_v(trajectories_pts_msip_fr[-1]) - -# %% -trajectories_svgd, traj_wts_svgd = svgd( - msip_f, - n_particles, - n_steps, - dim=state_dim, - lr=lr_msip, - init_particles=init_particles[:n_particles], - kernel_length_scale=kernel_length_scale, - is_log_density_batched=True, - kernel_diag_infl=kernel_diag_infl, - bounds=bounds, - keep_all=True, - compile_step=True, - verbose=True, +# accuracy_v(trajectories_pts_msip_fr[-1]) + +# %% +svgd = SVGD( + STATE_DIM, + N_PARTICLES, + kernel_lengthscale_quantile=0.5 +) + +target_svgd = BatchGradLogDensityEvaluator( + log_dens, is_grad=False, is_batched=True ) + +# %% +trajectories_pts_svgd = nak_torch.nak( + target_svgd, + svgd, + n_steps=N_STEPS, + lr=LR_MSIP, + init_particles=init_particles, + get_target_args=iter(train_data_loader), + bounds=BOUNDS, +) + +# %% +accuracy(trajectories_pts_svgd[-1].mean(dim=0)) + +# %% From c7699e56fadbcecab07f1d76fe120ca5800e142c Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 26 Apr 2026 16:22:41 -0400 Subject: [PATCH 33/60] Add stanpy and posteriordb to deps for examples --- pyproject.toml | 2 + uv.lock | 139 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 141 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1033393..8e9d441 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,10 @@ nak-torch = "nak_torch:main" examples = [ "ipykernel>=7.2.0", "matplotlib>=3.10.8", + "posteriordb>=0.2.0", "pyro-ppl>=1.9.1", "scipy>=1.17.1", + "stanpy>=0.2.11", ] [build-system] diff --git a/uv.lock b/uv.lock index 28a48ca..2f99b57 100644 --- a/uv.lock +++ b/uv.lock @@ -28,6 +28,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, ] +[[package]] +name = "certifi" +version = "2026.4.22" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/25/ee/6caf7a40c36a1220410afe15a1cc64993a1f864871f698c0f93acb72842a/certifi-2026.4.22.tar.gz", hash = "sha256:8d455352a37b71bf76a79caa83a3d6c25afee4a385d632127b6afb3963f1c580", size = 137077, upload-time = "2026-04-22T11:26:11.191Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/30/7cd8fdcdfbc5b869528b079bfb76dcdf6056b1a2097a662e5e8c04f42965/certifi-2026.4.22-py3-none-any.whl", hash = "sha256:3cb2210c8f88ba2318d29b0388d1023c8492ff72ecdde4ebdaddbb13a31b1c4a", size = 135707, upload-time = "2026-04-22T11:26:09.372Z" }, +] + [[package]] name = "cffi" version = "2.0.0" @@ -85,6 +94,79 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ae/3a/dbeec9d1ee0844c679f6bb5d6ad4e9f198b1224f4e7a32825f47f6192b0c/cffi-2.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:0a1527a803f0a659de1af2e1fd700213caba79377e27e4693648c2923da066f9", size = 184195, upload-time = "2025-09-08T23:23:43.004Z" }, ] +[[package]] +name = "charset-normalizer" +version = "3.4.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/a1/67fe25fac3c7642725500a3f6cfe5821ad557c3abb11c9d20d12c7008d3e/charset_normalizer-3.4.7.tar.gz", hash = "sha256:ae89db9e5f98a11a4bf50407d4363e7b09b31e55bc117b4f7d80aab97ba009e5", size = 144271, upload-time = "2026-04-02T09:28:39.342Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/eb/4fc8d0a7110eb5fc9cc161723a34a8a6c200ce3b4fbf681bc86feee22308/charset_normalizer-3.4.7-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:eca9705049ad3c7345d574e3510665cb2cf844c2f2dcfe675332677f081cbd46", size = 311328, upload-time = "2026-04-02T09:26:24.331Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e3/0fadc706008ac9d7b9b5be6dc767c05f9d3e5df51744ce4cc9605de7b9f4/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6178f72c5508bfc5fd446a5905e698c6212932f25bcdd4b47a757a50605a90e2", size = 208061, upload-time = "2026-04-02T09:26:25.568Z" }, + { url = "https://files.pythonhosted.org/packages/42/f0/3dd1045c47f4a4604df85ec18ad093912ae1344ac706993aff91d38773a2/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e1421b502d83040e6d7fb2fb18dff63957f720da3d77b2fbd3187ceb63755d7b", size = 229031, upload-time = "2026-04-02T09:26:26.865Z" }, + { url = "https://files.pythonhosted.org/packages/dc/67/675a46eb016118a2fbde5a277a5d15f4f69d5f3f5f338e5ee2f8948fcf43/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:edac0f1ab77644605be2cbba52e6b7f630731fc42b34cb0f634be1a6eface56a", size = 225239, upload-time = "2026-04-02T09:26:28.044Z" }, + { url = "https://files.pythonhosted.org/packages/4b/f8/d0118a2f5f23b02cd166fa385c60f9b0d4f9194f574e2b31cef350ad7223/charset_normalizer-3.4.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5649fd1c7bade02f320a462fdefd0b4bd3ce036065836d4f42e0de958038e116", size = 216589, upload-time = "2026-04-02T09:26:29.239Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f1/6d2b0b261b6c4ceef0fcb0d17a01cc5bc53586c2d4796fa04b5c540bc13d/charset_normalizer-3.4.7-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:203104ed3e428044fd943bc4bf45fa73c0730391f9621e37fe39ecf477b128cb", size = 202733, upload-time = "2026-04-02T09:26:30.5Z" }, + { url = "https://files.pythonhosted.org/packages/6f/c0/7b1f943f7e87cc3db9626ba17807d042c38645f0a1d4415c7a14afb5591f/charset_normalizer-3.4.7-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:298930cec56029e05497a76988377cbd7457ba864beeea92ad7e844fe74cd1f1", size = 212652, upload-time = "2026-04-02T09:26:31.709Z" }, + { url = "https://files.pythonhosted.org/packages/38/dd/5a9ab159fe45c6e72079398f277b7d2b523e7f716acc489726115a910097/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:708838739abf24b2ceb208d0e22403dd018faeef86ddac04319a62ae884c4f15", size = 211229, upload-time = "2026-04-02T09:26:33.282Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ff/531a1cad5ca855d1c1a8b69cb71abfd6d85c0291580146fda7c82857caa1/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:0f7eb884681e3938906ed0434f20c63046eacd0111c4ba96f27b76084cd679f5", size = 203552, upload-time = "2026-04-02T09:26:34.845Z" }, + { url = "https://files.pythonhosted.org/packages/c1/4c/a5fb52d528a8ca41f7598cb619409ece30a169fbdf9cdce592e53b46c3a6/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4dc1e73c36828f982bfe79fadf5919923f8a6f4df2860804db9a98c48824ce8d", size = 230806, upload-time = "2026-04-02T09:26:36.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/7a/071feed8124111a32b316b33ae4de83d36923039ef8cf48120266844285b/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:aed52fea0513bac0ccde438c188c8a471c4e0f457c2dd20cdbf6ea7a450046c7", size = 212316, upload-time = "2026-04-02T09:26:37.672Z" }, + { url = "https://files.pythonhosted.org/packages/fd/35/f7dba3994312d7ba508e041eaac39a36b120f32d4c8662b8814dab876431/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:fea24543955a6a729c45a73fe90e08c743f0b3334bbf3201e6c4bc1b0c7fa464", size = 227274, upload-time = "2026-04-02T09:26:38.93Z" }, + { url = "https://files.pythonhosted.org/packages/8a/2d/a572df5c9204ab7688ec1edc895a73ebded3b023bb07364710b05dd1c9be/charset_normalizer-3.4.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bb6d88045545b26da47aa879dd4a89a71d1dce0f0e549b1abcb31dfe4a8eac49", size = 218468, upload-time = "2026-04-02T09:26:40.17Z" }, + { url = "https://files.pythonhosted.org/packages/86/eb/890922a8b03a568ca2f336c36585a4713c55d4d67bf0f0c78924be6315ca/charset_normalizer-3.4.7-cp312-cp312-win32.whl", hash = "sha256:2257141f39fe65a3fdf38aeccae4b953e5f3b3324f4ff0daf9f15b8518666a2c", size = 148460, upload-time = "2026-04-02T09:26:41.416Z" }, + { url = "https://files.pythonhosted.org/packages/35/d9/0e7dffa06c5ab081f75b1b786f0aefc88365825dfcd0ac544bdb7b2b6853/charset_normalizer-3.4.7-cp312-cp312-win_amd64.whl", hash = "sha256:5ed6ab538499c8644b8a3e18debabcd7ce684f3fa91cf867521a7a0279cab2d6", size = 159330, upload-time = "2026-04-02T09:26:42.554Z" }, + { url = "https://files.pythonhosted.org/packages/9e/5d/481bcc2a7c88ea6b0878c299547843b2521ccbc40980cb406267088bc701/charset_normalizer-3.4.7-cp312-cp312-win_arm64.whl", hash = "sha256:56be790f86bfb2c98fb742ce566dfb4816e5a83384616ab59c49e0604d49c51d", size = 147828, upload-time = "2026-04-02T09:26:44.075Z" }, + { url = "https://files.pythonhosted.org/packages/c1/3b/66777e39d3ae1ddc77ee606be4ec6d8cbd4c801f65e5a1b6f2b11b8346dd/charset_normalizer-3.4.7-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:f496c9c3cc02230093d8330875c4c3cdfc3b73612a5fd921c65d39cbcef08063", size = 309627, upload-time = "2026-04-02T09:26:45.198Z" }, + { url = "https://files.pythonhosted.org/packages/2e/4e/b7f84e617b4854ade48a1b7915c8ccfadeba444d2a18c291f696e37f0d3b/charset_normalizer-3.4.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0ea948db76d31190bf08bd371623927ee1339d5f2a0b4b1b4a4439a65298703c", size = 207008, upload-time = "2026-04-02T09:26:46.824Z" }, + { url = "https://files.pythonhosted.org/packages/c4/bb/ec73c0257c9e11b268f018f068f5d00aa0ef8c8b09f7753ebd5f2880e248/charset_normalizer-3.4.7-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a277ab8928b9f299723bc1a2dabb1265911b1a76341f90a510368ca44ad9ab66", size = 228303, upload-time = "2026-04-02T09:26:48.397Z" }, + { url = "https://files.pythonhosted.org/packages/85/fb/32d1f5033484494619f701e719429c69b766bfc4dbc61aa9e9c8c166528b/charset_normalizer-3.4.7-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3bec022aec2c514d9cf199522a802bd007cd588ab17ab2525f20f9c34d067c18", size = 224282, upload-time = "2026-04-02T09:26:49.684Z" }, + { url = "https://files.pythonhosted.org/packages/fa/07/330e3a0dda4c404d6da83b327270906e9654a24f6c546dc886a0eb0ffb23/charset_normalizer-3.4.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e044c39e41b92c845bc815e5ae4230804e8e7bc29e399b0437d64222d92809dd", size = 215595, upload-time = "2026-04-02T09:26:50.915Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7c/fc890655786e423f02556e0216d4b8c6bcb6bdfa890160dc66bf52dee468/charset_normalizer-3.4.7-cp313-cp313-manylinux_2_31_armv7l.whl", hash = "sha256:f495a1652cf3fbab2eb0639776dad966c2fb874d79d87ca07f9d5f059b8bd215", size = 201986, upload-time = "2026-04-02T09:26:52.197Z" }, + { url = "https://files.pythonhosted.org/packages/d8/97/bfb18b3db2aed3b90cf54dc292ad79fdd5ad65c4eae454099475cbeadd0d/charset_normalizer-3.4.7-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e712b419df8ba5e42b226c510472b37bd57b38e897d3eca5e8cfd410a29fa859", size = 211711, upload-time = "2026-04-02T09:26:53.49Z" }, + { url = "https://files.pythonhosted.org/packages/6f/a5/a581c13798546a7fd557c82614a5c65a13df2157e9ad6373166d2a3e645d/charset_normalizer-3.4.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7804338df6fcc08105c7745f1502ba68d900f45fd770d5bdd5288ddccb8a42d8", size = 210036, upload-time = "2026-04-02T09:26:54.975Z" }, + { url = "https://files.pythonhosted.org/packages/8c/bf/b3ab5bcb478e4193d517644b0fb2bf5497fbceeaa7a1bc0f4d5b50953861/charset_normalizer-3.4.7-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:481551899c856c704d58119b5025793fa6730adda3571971af568f66d2424bb5", size = 202998, upload-time = "2026-04-02T09:26:56.303Z" }, + { url = "https://files.pythonhosted.org/packages/e7/4e/23efd79b65d314fa320ec6017b4b5834d5c12a58ba4610aa353af2e2f577/charset_normalizer-3.4.7-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f59099f9b66f0d7145115e6f80dd8b1d847176df89b234a5a6b3f00437aa0832", size = 230056, upload-time = "2026-04-02T09:26:57.554Z" }, + { url = "https://files.pythonhosted.org/packages/b9/9f/1e1941bc3f0e01df116e68dc37a55c4d249df5e6fa77f008841aef68264f/charset_normalizer-3.4.7-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:f59ad4c0e8f6bba240a9bb85504faa1ab438237199d4cce5f622761507b8f6a6", size = 211537, upload-time = "2026-04-02T09:26:58.843Z" }, + { url = "https://files.pythonhosted.org/packages/80/0f/088cbb3020d44428964a6c97fe1edfb1b9550396bf6d278330281e8b709c/charset_normalizer-3.4.7-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:3dedcc22d73ec993f42055eff4fcfed9318d1eeb9a6606c55892a26964964e48", size = 226176, upload-time = "2026-04-02T09:27:00.437Z" }, + { url = "https://files.pythonhosted.org/packages/6a/9f/130394f9bbe06f4f63e22641d32fc9b202b7e251c9aef4db044324dac493/charset_normalizer-3.4.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:64f02c6841d7d83f832cd97ccf8eb8a906d06eb95d5276069175c696b024b60a", size = 217723, upload-time = "2026-04-02T09:27:02.021Z" }, + { url = "https://files.pythonhosted.org/packages/73/55/c469897448a06e49f8fa03f6caae97074fde823f432a98f979cc42b90e69/charset_normalizer-3.4.7-cp313-cp313-win32.whl", hash = "sha256:4042d5c8f957e15221d423ba781e85d553722fc4113f523f2feb7b188cc34c5e", size = 148085, upload-time = "2026-04-02T09:27:03.192Z" }, + { url = "https://files.pythonhosted.org/packages/5d/78/1b74c5bbb3f99b77a1715c91b3e0b5bdb6fe302d95ace4f5b1bec37b0167/charset_normalizer-3.4.7-cp313-cp313-win_amd64.whl", hash = "sha256:3946fa46a0cf3e4c8cb1cc52f56bb536310d34f25f01ca9b6c16afa767dab110", size = 158819, upload-time = "2026-04-02T09:27:04.454Z" }, + { url = "https://files.pythonhosted.org/packages/68/86/46bd42279d323deb8687c4a5a811fd548cb7d1de10cf6535d099877a9a9f/charset_normalizer-3.4.7-cp313-cp313-win_arm64.whl", hash = "sha256:80d04837f55fc81da168b98de4f4b797ef007fc8a79ab71c6ec9bc4dd662b15b", size = 147915, upload-time = "2026-04-02T09:27:05.971Z" }, + { url = "https://files.pythonhosted.org/packages/97/c8/c67cb8c70e19ef1960b97b22ed2a1567711de46c4ddf19799923adc836c2/charset_normalizer-3.4.7-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:c36c333c39be2dbca264d7803333c896ab8fa7d4d6f0ab7edb7dfd7aea6e98c0", size = 309234, upload-time = "2026-04-02T09:27:07.194Z" }, + { url = "https://files.pythonhosted.org/packages/99/85/c091fdee33f20de70d6c8b522743b6f831a2f1cd3ff86de4c6a827c48a76/charset_normalizer-3.4.7-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1c2aed2e5e41f24ea8ef1590b8e848a79b56f3a5564a65ceec43c9d692dc7d8a", size = 208042, upload-time = "2026-04-02T09:27:08.749Z" }, + { url = "https://files.pythonhosted.org/packages/87/1c/ab2ce611b984d2fd5d86a5a8a19c1ae26acac6bad967da4967562c75114d/charset_normalizer-3.4.7-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:54523e136b8948060c0fa0bc7b1b50c32c186f2fceee897a495406bb6e311d2b", size = 228706, upload-time = "2026-04-02T09:27:09.951Z" }, + { url = "https://files.pythonhosted.org/packages/a8/29/2b1d2cb00bf085f59d29eb773ce58ec2d325430f8c216804a0a5cd83cbca/charset_normalizer-3.4.7-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:715479b9a2802ecac752a3b0efa2b0b60285cf962ee38414211abdfccc233b41", size = 224727, upload-time = "2026-04-02T09:27:11.175Z" }, + { url = "https://files.pythonhosted.org/packages/47/5c/032c2d5a07fe4d4855fea851209cca2b6f03ebeb6d4e3afdb3358386a684/charset_normalizer-3.4.7-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bd6c2a1c7573c64738d716488d2cdd3c00e340e4835707d8fdb8dc1a66ef164e", size = 215882, upload-time = "2026-04-02T09:27:12.446Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c2/356065d5a8b78ed04499cae5f339f091946a6a74f91e03476c33f0ab7100/charset_normalizer-3.4.7-cp314-cp314-manylinux_2_31_armv7l.whl", hash = "sha256:c45e9440fb78f8ddabcf714b68f936737a121355bf59f3907f4e17721b9d1aae", size = 200860, upload-time = "2026-04-02T09:27:13.721Z" }, + { url = "https://files.pythonhosted.org/packages/0c/cd/a32a84217ced5039f53b29f460962abb2d4420def55afabe45b1c3c7483d/charset_normalizer-3.4.7-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3534e7dcbdcf757da6b85a0bbf5b6868786d5982dd959b065e65481644817a18", size = 211564, upload-time = "2026-04-02T09:27:15.272Z" }, + { url = "https://files.pythonhosted.org/packages/44/86/58e6f13ce26cc3b8f4a36b94a0f22ae2f00a72534520f4ae6857c4b81f89/charset_normalizer-3.4.7-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:e8ac484bf18ce6975760921bb6148041faa8fef0547200386ea0b52b5d27bf7b", size = 211276, upload-time = "2026-04-02T09:27:16.834Z" }, + { url = "https://files.pythonhosted.org/packages/8f/fe/d17c32dc72e17e155e06883efa84514ca375f8a528ba2546bee73fc4df81/charset_normalizer-3.4.7-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a5fe03b42827c13cdccd08e6c0247b6a6d4b5e3cdc53fd1749f5896adcdc2356", size = 201238, upload-time = "2026-04-02T09:27:18.229Z" }, + { url = "https://files.pythonhosted.org/packages/6a/29/f33daa50b06525a237451cdb6c69da366c381a3dadcd833fa5676bc468b3/charset_normalizer-3.4.7-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:2d6eb928e13016cea4f1f21d1e10c1cebd5a421bc57ddf5b1142ae3f86824fab", size = 230189, upload-time = "2026-04-02T09:27:19.445Z" }, + { url = "https://files.pythonhosted.org/packages/b6/6e/52c84015394a6a0bdcd435210a7e944c5f94ea1055f5cc5d56c5fe368e7b/charset_normalizer-3.4.7-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:e74327fb75de8986940def6e8dee4f127cc9752bee7355bb323cc5b2659b6d46", size = 211352, upload-time = "2026-04-02T09:27:20.79Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d7/4353be581b373033fb9198bf1da3cf8f09c1082561e8e922aa7b39bf9fe8/charset_normalizer-3.4.7-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:d6038d37043bced98a66e68d3aa2b6a35505dc01328cd65217cefe82f25def44", size = 227024, upload-time = "2026-04-02T09:27:22.063Z" }, + { url = "https://files.pythonhosted.org/packages/30/45/99d18aa925bd1740098ccd3060e238e21115fffbfdcb8f3ece837d0ace6c/charset_normalizer-3.4.7-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:7579e913a5339fb8fa133f6bbcfd8e6749696206cf05acdbdca71a1b436d8e72", size = 217869, upload-time = "2026-04-02T09:27:23.486Z" }, + { url = "https://files.pythonhosted.org/packages/5c/05/5ee478aa53f4bb7996482153d4bfe1b89e0f087f0ab6b294fcf92d595873/charset_normalizer-3.4.7-cp314-cp314-win32.whl", hash = "sha256:5b77459df20e08151cd6f8b9ef8ef1f961ef73d85c21a555c7eed5b79410ec10", size = 148541, upload-time = "2026-04-02T09:27:25.146Z" }, + { url = "https://files.pythonhosted.org/packages/48/77/72dcb0921b2ce86420b2d79d454c7022bf5be40202a2a07906b9f2a35c97/charset_normalizer-3.4.7-cp314-cp314-win_amd64.whl", hash = "sha256:92a0a01ead5e668468e952e4238cccd7c537364eb7d851ab144ab6627dbbe12f", size = 159634, upload-time = "2026-04-02T09:27:26.642Z" }, + { url = "https://files.pythonhosted.org/packages/c6/a3/c2369911cd72f02386e4e340770f6e158c7980267da16af8f668217abaa0/charset_normalizer-3.4.7-cp314-cp314-win_arm64.whl", hash = "sha256:67f6279d125ca0046a7fd386d01b311c6363844deac3e5b069b514ba3e63c246", size = 148384, upload-time = "2026-04-02T09:27:28.271Z" }, + { url = "https://files.pythonhosted.org/packages/94/09/7e8a7f73d24dba1f0035fbbf014d2c36828fc1bf9c88f84093e57d315935/charset_normalizer-3.4.7-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:effc3f449787117233702311a1b7d8f59cba9ced946ba727bdc329ec69028e24", size = 330133, upload-time = "2026-04-02T09:27:29.474Z" }, + { url = "https://files.pythonhosted.org/packages/8d/da/96975ddb11f8e977f706f45cddd8540fd8242f71ecdb5d18a80723dcf62c/charset_normalizer-3.4.7-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fbccdc05410c9ee21bbf16a35f4c1d16123dcdeb8a1d38f33654fa21d0234f79", size = 216257, upload-time = "2026-04-02T09:27:30.793Z" }, + { url = "https://files.pythonhosted.org/packages/e5/e8/1d63bf8ef2d388e95c64b2098f45f84758f6d102a087552da1485912637b/charset_normalizer-3.4.7-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:733784b6d6def852c814bce5f318d25da2ee65dd4839a0718641c696e09a2960", size = 234851, upload-time = "2026-04-02T09:27:32.44Z" }, + { url = "https://files.pythonhosted.org/packages/9b/40/e5ff04233e70da2681fa43969ad6f66ca5611d7e669be0246c4c7aaf6dc8/charset_normalizer-3.4.7-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a89c23ef8d2c6b27fd200a42aa4ac72786e7c60d40efdc76e6011260b6e949c4", size = 233393, upload-time = "2026-04-02T09:27:34.03Z" }, + { url = "https://files.pythonhosted.org/packages/be/c1/06c6c49d5a5450f76899992f1ee40b41d076aee9279b49cf9974d2f313d5/charset_normalizer-3.4.7-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6c114670c45346afedc0d947faf3c7f701051d2518b943679c8ff88befe14f8e", size = 223251, upload-time = "2026-04-02T09:27:35.369Z" }, + { url = "https://files.pythonhosted.org/packages/2b/9f/f2ff16fb050946169e3e1f82134d107e5d4ae72647ec8a1b1446c148480f/charset_normalizer-3.4.7-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:a180c5e59792af262bf263b21a3c49353f25945d8d9f70628e73de370d55e1e1", size = 206609, upload-time = "2026-04-02T09:27:36.661Z" }, + { url = "https://files.pythonhosted.org/packages/69/d5/a527c0cd8d64d2eab7459784fb4169a0ac76e5a6fc5237337982fd61347e/charset_normalizer-3.4.7-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3c9a494bc5ec77d43cea229c4f6db1e4d8fe7e1bbffa8b6f0f0032430ff8ab44", size = 220014, upload-time = "2026-04-02T09:27:38.019Z" }, + { url = "https://files.pythonhosted.org/packages/7e/80/8a7b8104a3e203074dc9aa2c613d4b726c0e136bad1cc734594b02867972/charset_normalizer-3.4.7-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8d828b6667a32a728a1ad1d93957cdf37489c57b97ae6c4de2860fa749b8fc1e", size = 218979, upload-time = "2026-04-02T09:27:39.37Z" }, + { url = "https://files.pythonhosted.org/packages/02/9a/b759b503d507f375b2b5c153e4d2ee0a75aa215b7f2489cf314f4541f2c0/charset_normalizer-3.4.7-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:cf1493cd8607bec4d8a7b9b004e699fcf8f9103a9284cc94962cb73d20f9d4a3", size = 209238, upload-time = "2026-04-02T09:27:40.722Z" }, + { url = "https://files.pythonhosted.org/packages/c2/4e/0f3f5d47b86bdb79256e7290b26ac847a2832d9a4033f7eb2cd4bcf4bb5b/charset_normalizer-3.4.7-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:0c96c3b819b5c3e9e165495db84d41914d6894d55181d2d108cc1a69bfc9cce0", size = 236110, upload-time = "2026-04-02T09:27:42.33Z" }, + { url = "https://files.pythonhosted.org/packages/96/23/bce28734eb3ed2c91dcf93abeb8a5cf393a7b2749725030bb630e554fdd8/charset_normalizer-3.4.7-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:752a45dc4a6934060b3b0dab47e04edc3326575f82be64bc4fc293914566503e", size = 219824, upload-time = "2026-04-02T09:27:43.924Z" }, + { url = "https://files.pythonhosted.org/packages/2c/6f/6e897c6984cc4d41af319b077f2f600fc8214eb2fe2d6bcb79141b882400/charset_normalizer-3.4.7-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:8778f0c7a52e56f75d12dae53ae320fae900a8b9b4164b981b9c5ce059cd1fcb", size = 233103, upload-time = "2026-04-02T09:27:45.348Z" }, + { url = "https://files.pythonhosted.org/packages/76/22/ef7bd0fe480a0ae9b656189ec00744b60933f68b4f42a7bb06589f6f576a/charset_normalizer-3.4.7-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ce3412fbe1e31eb81ea42f4169ed94861c56e643189e1e75f0041f3fe7020abe", size = 225194, upload-time = "2026-04-02T09:27:46.706Z" }, + { url = "https://files.pythonhosted.org/packages/c5/a7/0e0ab3e0b5bc1219bd80a6a0d4d72ca74d9250cb2382b7c699c147e06017/charset_normalizer-3.4.7-cp314-cp314t-win32.whl", hash = "sha256:c03a41a8784091e67a39648f70c5f97b5b6a37f216896d44d2cdcb82615339a0", size = 159827, upload-time = "2026-04-02T09:27:48.053Z" }, + { url = "https://files.pythonhosted.org/packages/7a/1d/29d32e0fb40864b1f878c7f5a0b343ae676c6e2b271a2d55cc3a152391da/charset_normalizer-3.4.7-cp314-cp314t-win_amd64.whl", hash = "sha256:03853ed82eeebbce3c2abfdbc98c96dc205f32a79627688ac9a27370ea61a49c", size = 174168, upload-time = "2026-04-02T09:27:49.795Z" }, + { url = "https://files.pythonhosted.org/packages/de/32/d92444ad05c7a6e41fb2036749777c163baf7a0301a040cb672d6b2b1ae9/charset_normalizer-3.4.7-cp314-cp314t-win_arm64.whl", hash = "sha256:c35abb8bfff0185efac5878da64c45dafd2b37fb0383add1be155a763c1f083d", size = 153018, upload-time = "2026-04-02T09:27:51.116Z" }, + { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958, upload-time = "2026-04-02T09:28:37.794Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -383,6 +465,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ab/fb21f4c939bb440104cc2b396d3be1d9b7a9fd3c6c2a53d98c45b3d7c954/fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437", size = 202505, upload-time = "2026-02-05T21:50:51.819Z" }, ] +[[package]] +name = "idna" +version = "3.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ce/cc/762dfb036166873f0059f3b7de4565e1b5bc3d6f28a414c13da27e442f99/idna-3.13.tar.gz", hash = "sha256:585ea8fe5d69b9181ec1afba340451fba6ba764af97026f92a91d4eef164a242", size = 194210, upload-time = "2026-04-22T16:42:42.314Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/13/ad7d7ca3808a898b4612b6fe93cde56b53f3034dcde235acb1f0e1df24c6/idna-3.13-py3-none-any.whl", hash = "sha256:892ea0cde124a99ce773decba204c5552b69c3c67ffd5f232eb7696135bc8bb3", size = 68629, upload-time = "2026-04-22T16:42:40.909Z" }, +] + [[package]] name = "iniconfig" version = "2.3.0" @@ -753,8 +844,10 @@ dependencies = [ examples = [ { name = "ipykernel" }, { name = "matplotlib" }, + { name = "posteriordb" }, { name = "pyro-ppl" }, { name = "scipy" }, + { name = "stanpy" }, ] [package.dev-dependencies] @@ -774,8 +867,10 @@ requires-dist = [ { name = "jaxtyping", specifier = ">=0.3.5" }, { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.10.8" }, { name = "numpy", specifier = ">=2.4.1" }, + { name = "posteriordb", marker = "extra == 'examples'", specifier = ">=0.2.0" }, { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, { name = "scipy", marker = "extra == 'examples'", specifier = ">=1.17.1" }, + { name = "stanpy", marker = "extra == 'examples'", specifier = ">=0.2.11" }, { name = "torch", specifier = ">=2.10" }, { name = "tqdm", specifier = ">=4.67.1" }, ] @@ -1183,6 +1278,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "posteriordb" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/4d/b72e0782abec07f3d8dabf24cf12673d26b173af2046eb4e67365c776ccf/posteriordb-0.2.0-py3-none-any.whl", hash = "sha256:b6d6f3a349d34db6d4a68da899c818a95e5824c5e23824fc0ebe422f4bd6bac1", size = 24059, upload-time = "2020-11-25T12:04:47.729Z" }, +] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -1378,6 +1484,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/d6/4bfbb40c9a0b42fc53c7cf442f6385db70b40f74a783130c5d0a5aa62228/pyzmq-27.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dc5dbf68a7857b59473f7df42650c621d7e8923fb03fa74a526890f4d33cc4d7", size = 575170, upload-time = "2025-09-08T23:09:01.418Z" }, ] +[[package]] +name = "requests" +version = "2.33.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/a4/98b9c7c6428a668bf7e42ebb7c79d576a1c3c1e3ae2d47e674b468388871/requests-2.33.1.tar.gz", hash = "sha256:18817f8c57c6263968bc123d237e3b8b08ac046f5456bd1e307ee8f4250d3517", size = 134120, upload-time = "2026-03-30T16:09:15.531Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d7/8e/7540e8a2036f79a125c1d2ebadf69ed7901608859186c856fa0388ef4197/requests-2.33.1-py3-none-any.whl", hash = "sha256:4e6d1ef462f3626a1f0a0a9c42dd93c63bad33f9f1c1937509b8c5c8718ab56a", size = 64947, upload-time = "2026-03-30T16:09:13.83Z" }, +] + [[package]] name = "ruff" version = "0.15.8" @@ -1496,6 +1617,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] +[[package]] +name = "stanpy" +version = "0.2.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/cc/74da52b7de8ee281a2fdffcfba578a8b6b317a7801631d3800cb6a21ec80/stanpy-0.2.11.tar.gz", hash = "sha256:6b6354d042a705b9657392a1cee8c17ebacc2e43de8ed5dfd44e6cab52822530", size = 7809464, upload-time = "2022-04-09T10:49:48.038Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fe/06/19356ea9f4f20d997d5c2dc5f32d72cd220038f3d5dd2f33957d56cb6ba3/stanpy-0.2.11-py2.py3-none-any.whl", hash = "sha256:64fec89761e56a520d124f9487c365f78545145f0d1fee64c1e085d1f6c4adff", size = 28068, upload-time = "2022-04-09T10:49:34.812Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -1636,6 +1766,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/70/d460bd685a170790ec89317e9bd33047988e4bce507b831f5db771e142de/tzdata-2026.1-py2.py3-none-any.whl", hash = "sha256:4b1d2be7ac37ceafd7327b961aa3a54e467efbdb563a23655fbfe0d39cfc42a9", size = 348952, upload-time = "2026-04-03T11:25:20.313Z" }, ] +[[package]] +name = "urllib3" +version = "2.6.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, +] + [[package]] name = "wadler-lindig" version = "0.1.7" From 4341df762c0302005915565805e00edef5a2c018 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 26 Apr 2026 17:02:47 -0400 Subject: [PATCH 34/60] Fix stan versioning --- pyproject.toml | 2 +- uv.lock | 656 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 646 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8e9d441..3bdd0bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,8 @@ examples = [ "matplotlib>=3.10.8", "posteriordb>=0.2.0", "pyro-ppl>=1.9.1", + "pystan>=3.10.1", "scipy>=1.17.1", - "stanpy>=0.2.11", ] [build-system] diff --git a/uv.lock b/uv.lock index 2f99b57..cd9ce29 100644 --- a/uv.lock +++ b/uv.lock @@ -10,6 +10,122 @@ resolution-markers = [ "python_full_version < '3.14' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] +[[package]] +name = "aiohappyeyeballs" +version = "2.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/26/30/f84a107a9c4331c14b2b586036f40965c128aa4fee4dda5d3d51cb14ad54/aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558", size = 22760, upload-time = "2025-03-12T01:42:48.764Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/15/5bf3b99495fb160b63f95972b81750f18f7f4e02ad051373b669d17d44f2/aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8", size = 15265, upload-time = "2025-03-12T01:42:47.083Z" }, +] + +[[package]] +name = "aiohttp" +version = "3.13.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohappyeyeballs" }, + { name = "aiosignal" }, + { name = "attrs" }, + { name = "frozenlist" }, + { name = "multidict" }, + { name = "propcache" }, + { name = "yarl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/9a/152096d4808df8e4268befa55fba462f440f14beab85e8ad9bf990516918/aiohttp-3.13.5.tar.gz", hash = "sha256:9d98cc980ecc96be6eb4c1994ce35d28d8b1f5e5208a23b421187d1209dbb7d1", size = 7858271, upload-time = "2026-03-31T22:01:03.343Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/6f/353954c29e7dcce7cf00280a02c75f30e133c00793c7a2ed3776d7b2f426/aiohttp-3.13.5-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:023ecba036ddd840b0b19bf195bfae970083fd7024ce1ac22e9bba90464620e9", size = 748876, upload-time = "2026-03-31T21:57:36.319Z" }, + { url = "https://files.pythonhosted.org/packages/f5/1b/428a7c64687b3b2e9cd293186695affc0e1e54a445d0361743b231f11066/aiohttp-3.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:15c933ad7920b7d9a20de151efcd05a6e38302cbf0e10c9b2acb9a42210a2416", size = 499557, upload-time = "2026-03-31T21:57:38.236Z" }, + { url = "https://files.pythonhosted.org/packages/29/47/7be41556bfbb6917069d6a6634bb7dd5e163ba445b783a90d40f5ac7e3a7/aiohttp-3.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ab2899f9fa2f9f741896ebb6fa07c4c883bfa5c7f2ddd8cf2aafa86fa981b2d2", size = 500258, upload-time = "2026-03-31T21:57:39.923Z" }, + { url = "https://files.pythonhosted.org/packages/67/84/c9ecc5828cb0b3695856c07c0a6817a99d51e2473400f705275a2b3d9239/aiohttp-3.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60eaa2d440cd4707696b52e40ed3e2b0f73f65be07fd0ef23b6b539c9c0b0b4", size = 1749199, upload-time = "2026-03-31T21:57:41.938Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d3/3c6d610e66b495657622edb6ae7c7fd31b2e9086b4ec50b47897ad6042a9/aiohttp-3.13.5-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55b3bdd3292283295774ab585160c4004f4f2f203946997f49aac032c84649e9", size = 1721013, upload-time = "2026-03-31T21:57:43.904Z" }, + { url = "https://files.pythonhosted.org/packages/49/a0/24409c12217456df0bae7babe3b014e460b0b38a8e60753d6cb339f6556d/aiohttp-3.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c2b2355dc094e5f7d45a7bb262fe7207aa0460b37a0d87027dcf21b5d890e7d5", size = 1781501, upload-time = "2026-03-31T21:57:46.285Z" }, + { url = "https://files.pythonhosted.org/packages/98/9d/b65ec649adc5bccc008b0957a9a9c691070aeac4e41cea18559fef49958b/aiohttp-3.13.5-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b38765950832f7d728297689ad78f5f2cf79ff82487131c4d26fe6ceecdc5f8e", size = 1878981, upload-time = "2026-03-31T21:57:48.734Z" }, + { url = "https://files.pythonhosted.org/packages/57/d8/8d44036d7eb7b6a8ec4c5494ea0c8c8b94fbc0ed3991c1a7adf230df03bf/aiohttp-3.13.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b18f31b80d5a33661e08c89e202edabf1986e9b49c42b4504371daeaa11b47c1", size = 1767934, upload-time = "2026-03-31T21:57:51.171Z" }, + { url = "https://files.pythonhosted.org/packages/31/04/d3f8211f273356f158e3464e9e45484d3fb8c4ce5eb2f6fe9405c3273983/aiohttp-3.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:33add2463dde55c4f2d9635c6ab33ce154e5ecf322bd26d09af95c5f81cfa286", size = 1566671, upload-time = "2026-03-31T21:57:53.326Z" }, + { url = "https://files.pythonhosted.org/packages/41/db/073e4ebe00b78e2dfcacff734291651729a62953b48933d765dc513bf798/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:327cc432fdf1356fb4fbc6fe833ad4e9f6aacb71a8acaa5f1855e4b25910e4a9", size = 1705219, upload-time = "2026-03-31T21:57:55.385Z" }, + { url = "https://files.pythonhosted.org/packages/48/45/7dfba71a2f9fd97b15c95c06819de7eb38113d2cdb6319669195a7d64270/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:7c35b0bf0b48a70b4cb4fc5d7bed9b932532728e124874355de1a0af8ec4bc88", size = 1743049, upload-time = "2026-03-31T21:57:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/18/71/901db0061e0f717d226386a7f471bb59b19566f2cae5f0d93874b017271f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:df23d57718f24badef8656c49743e11a89fd6f5358fa8a7b96e728fda2abf7d3", size = 1749557, upload-time = "2026-03-31T21:57:59.626Z" }, + { url = "https://files.pythonhosted.org/packages/08/d5/41eebd16066e59cd43728fe74bce953d7402f2b4ddfdfef2c0e9f17ca274/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:02e048037a6501a5ec1f6fc9736135aec6eb8a004ce48838cb951c515f32c80b", size = 1558931, upload-time = "2026-03-31T21:58:01.972Z" }, + { url = "https://files.pythonhosted.org/packages/30/e6/4a799798bf05740e66c3a1161079bda7a3dd8e22ca392481d7a7f9af82a6/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:31cebae8b26f8a615d2b546fee45d5ffb76852ae6450e2a03f42c9102260d6fe", size = 1774125, upload-time = "2026-03-31T21:58:04.007Z" }, + { url = "https://files.pythonhosted.org/packages/84/63/7749337c90f92bc2cb18f9560d67aa6258c7060d1397d21529b8004fcf6f/aiohttp-3.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:888e78eb5ca55a615d285c3c09a7a91b42e9dd6fc699b166ebd5dee87c9ccf14", size = 1732427, upload-time = "2026-03-31T21:58:06.337Z" }, + { url = "https://files.pythonhosted.org/packages/98/de/cf2f44ff98d307e72fb97d5f5bbae3bfcb442f0ea9790c0bf5c5c2331404/aiohttp-3.13.5-cp312-cp312-win32.whl", hash = "sha256:8bd3ec6376e68a41f9f95f5ed170e2fcf22d4eb27a1f8cb361d0508f6e0557f3", size = 433534, upload-time = "2026-03-31T21:58:08.712Z" }, + { url = "https://files.pythonhosted.org/packages/aa/ca/eadf6f9c8fa5e31d40993e3db153fb5ed0b11008ad5d9de98a95045bed84/aiohttp-3.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:110e448e02c729bcebb18c60b9214a87ba33bac4a9fa5e9a5f139938b56c6cb1", size = 460446, upload-time = "2026-03-31T21:58:10.945Z" }, + { url = "https://files.pythonhosted.org/packages/78/e9/d76bf503005709e390122d34e15256b88f7008e246c4bdbe915cd4f1adce/aiohttp-3.13.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5029cc80718bbd545123cd8fe5d15025eccaaaace5d0eeec6bd556ad6163d61", size = 742930, upload-time = "2026-03-31T21:58:13.155Z" }, + { url = "https://files.pythonhosted.org/packages/57/00/4b7b70223deaebd9bb85984d01a764b0d7bd6526fcdc73cca83bcbe7243e/aiohttp-3.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bb6bf5811620003614076bdc807ef3b5e38244f9d25ca5fe888eaccea2a9832", size = 496927, upload-time = "2026-03-31T21:58:15.073Z" }, + { url = "https://files.pythonhosted.org/packages/9c/f5/0fb20fb49f8efdcdce6cd8127604ad2c503e754a8f139f5e02b01626523f/aiohttp-3.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a84792f8631bf5a94e52d9cc881c0b824ab42717165a5579c760b830d9392ac9", size = 497141, upload-time = "2026-03-31T21:58:17.009Z" }, + { url = "https://files.pythonhosted.org/packages/3b/86/b7c870053e36a94e8951b803cb5b909bfbc9b90ca941527f5fcafbf6b0fa/aiohttp-3.13.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:57653eac22c6a4c13eb22ecf4d673d64a12f266e72785ab1c8b8e5940d0e8090", size = 1732476, upload-time = "2026-03-31T21:58:18.925Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e5/4e161f84f98d80c03a238671b4136e6530453d65262867d989bbe78244d0/aiohttp-3.13.5-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5e5f7debc7a57af53fdf5c5009f9391d9f4c12867049d509bf7bb164a6e295b", size = 1706507, upload-time = "2026-03-31T21:58:21.094Z" }, + { url = "https://files.pythonhosted.org/packages/d4/56/ea11a9f01518bd5a2a2fcee869d248c4b8a0cfa0bb13401574fa31adf4d4/aiohttp-3.13.5-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c719f65bebcdf6716f10e9eff80d27567f7892d8988c06de12bbbd39307c6e3a", size = 1773465, upload-time = "2026-03-31T21:58:23.159Z" }, + { url = "https://files.pythonhosted.org/packages/eb/40/333ca27fb74b0383f17c90570c748f7582501507307350a79d9f9f3c6eb1/aiohttp-3.13.5-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d97f93fdae594d886c5a866636397e2bcab146fd7a132fd6bb9ce182224452f8", size = 1873523, upload-time = "2026-03-31T21:58:25.59Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d2/e2f77eef1acb7111405433c707dc735e63f67a56e176e72e9e7a2cd3f493/aiohttp-3.13.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3df334e39d4c2f899a914f1dba283c1aadc311790733f705182998c6f7cae665", size = 1754113, upload-time = "2026-03-31T21:58:27.624Z" }, + { url = "https://files.pythonhosted.org/packages/fb/56/3f653d7f53c89669301ec9e42c95233e2a0c0a6dd051269e6e678db4fdb0/aiohttp-3.13.5-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe6970addfea9e5e081401bcbadf865d2b6da045472f58af08427e108d618540", size = 1562351, upload-time = "2026-03-31T21:58:29.918Z" }, + { url = "https://files.pythonhosted.org/packages/ec/a6/9b3e91eb8ae791cce4ee736da02211c85c6f835f1bdfac0594a8a3b7018c/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7becdf835feff2f4f335d7477f121af787e3504b48b449ff737afb35869ba7bb", size = 1693205, upload-time = "2026-03-31T21:58:32.214Z" }, + { url = "https://files.pythonhosted.org/packages/98/fc/bfb437a99a2fcebd6b6eaec609571954de2ed424f01c352f4b5504371dd3/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:676e5651705ad5d8a70aeb8eb6936c436d8ebbd56e63436cb7dd9bb36d2a9a46", size = 1730618, upload-time = "2026-03-31T21:58:34.728Z" }, + { url = "https://files.pythonhosted.org/packages/e4/b6/c8534862126191a034f68153194c389addc285a0f1347d85096d349bbc15/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:9b16c653d38eb1a611cc898c41e76859ca27f119d25b53c12875fd0474ae31a8", size = 1745185, upload-time = "2026-03-31T21:58:36.909Z" }, + { url = "https://files.pythonhosted.org/packages/0b/93/4ca8ee2ef5236e2707e0fd5fecb10ce214aee1ff4ab307af9c558bda3b37/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:999802d5fa0389f58decd24b537c54aa63c01c3219ce17d1214cbda3c2b22d2d", size = 1557311, upload-time = "2026-03-31T21:58:39.38Z" }, + { url = "https://files.pythonhosted.org/packages/57/ae/76177b15f18c5f5d094f19901d284025db28eccc5ae374d1d254181d33f4/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:ec707059ee75732b1ba130ed5f9580fe10ff75180c812bc267ded039db5128c6", size = 1773147, upload-time = "2026-03-31T21:58:41.476Z" }, + { url = "https://files.pythonhosted.org/packages/01/a4/62f05a0a98d88af59d93b7fcac564e5f18f513cb7471696ac286db970d6a/aiohttp-3.13.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2d6d44a5b48132053c2f6cd5c8cb14bc67e99a63594e336b0f2af81e94d5530c", size = 1730356, upload-time = "2026-03-31T21:58:44.049Z" }, + { url = "https://files.pythonhosted.org/packages/e4/85/fc8601f59dfa8c9523808281f2da571f8b4699685f9809a228adcc90838d/aiohttp-3.13.5-cp313-cp313-win32.whl", hash = "sha256:329f292ed14d38a6c4c435e465f48bebb47479fd676a0411936cc371643225cc", size = 432637, upload-time = "2026-03-31T21:58:46.167Z" }, + { url = "https://files.pythonhosted.org/packages/c0/1b/ac685a8882896acf0f6b31d689e3792199cfe7aba37969fa91da63a7fa27/aiohttp-3.13.5-cp313-cp313-win_amd64.whl", hash = "sha256:69f571de7500e0557801c0b51f4780482c0ec5fe2ac851af5a92cfce1af1cb83", size = 458896, upload-time = "2026-03-31T21:58:48.119Z" }, + { url = "https://files.pythonhosted.org/packages/5d/ce/46572759afc859e867a5bc8ec3487315869013f59281ce61764f76d879de/aiohttp-3.13.5-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:eb4639f32fd4a9904ab8fb45bf3383ba71137f3d9d4ba25b3b3f3109977c5b8c", size = 745721, upload-time = "2026-03-31T21:58:50.229Z" }, + { url = "https://files.pythonhosted.org/packages/13/fe/8a2efd7626dbe6049b2ef8ace18ffda8a4dfcbe1bcff3ac30c0c7575c20b/aiohttp-3.13.5-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:7e5dc4311bd5ac493886c63cbf76ab579dbe4641268e7c74e48e774c74b6f2be", size = 497663, upload-time = "2026-03-31T21:58:52.232Z" }, + { url = "https://files.pythonhosted.org/packages/9b/91/cc8cc78a111826c54743d88651e1687008133c37e5ee615fee9b57990fac/aiohttp-3.13.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:756c3c304d394977519824449600adaf2be0ccee76d206ee339c5e76b70ded25", size = 499094, upload-time = "2026-03-31T21:58:54.566Z" }, + { url = "https://files.pythonhosted.org/packages/0a/33/a8362cb15cf16a3af7e86ed11962d5cd7d59b449202dc576cdc731310bde/aiohttp-3.13.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecc26751323224cf8186efcf7fbcbc30f4e1d8c7970659daf25ad995e4032a56", size = 1726701, upload-time = "2026-03-31T21:58:56.864Z" }, + { url = "https://files.pythonhosted.org/packages/45/0c/c091ac5c3a17114bd76cbf85d674650969ddf93387876cf67f754204bd77/aiohttp-3.13.5-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10a75acfcf794edf9d8db50e5a7ec5fc818b2a8d3f591ce93bc7b1210df016d2", size = 1683360, upload-time = "2026-03-31T21:58:59.072Z" }, + { url = "https://files.pythonhosted.org/packages/23/73/bcee1c2b79bc275e964d1446c55c54441a461938e70267c86afaae6fba27/aiohttp-3.13.5-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:0f7a18f258d124cd678c5fe072fe4432a4d5232b0657fca7c1847f599233c83a", size = 1773023, upload-time = "2026-03-31T21:59:01.776Z" }, + { url = "https://files.pythonhosted.org/packages/c7/ef/720e639df03004fee2d869f771799d8c23046dec47d5b81e396c7cda583a/aiohttp-3.13.5-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:df6104c009713d3a89621096f3e3e88cc323fd269dbd7c20afe18535094320be", size = 1853795, upload-time = "2026-03-31T21:59:04.568Z" }, + { url = "https://files.pythonhosted.org/packages/bd/c9/989f4034fb46841208de7aeeac2c6d8300745ab4f28c42f629ba77c2d916/aiohttp-3.13.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:241a94f7de7c0c3b616627aaad530fe2cb620084a8b144d3be7b6ecfe95bae3b", size = 1730405, upload-time = "2026-03-31T21:59:07.221Z" }, + { url = "https://files.pythonhosted.org/packages/ce/75/ee1fd286ca7dc599d824b5651dad7b3be7ff8d9a7e7b3fe9820d9180f7db/aiohttp-3.13.5-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c974fb66180e58709b6fc402846f13791240d180b74de81d23913abe48e96d94", size = 1558082, upload-time = "2026-03-31T21:59:09.484Z" }, + { url = "https://files.pythonhosted.org/packages/c3/20/1e9e6650dfc436340116b7aa89ff8cb2bbdf0abc11dfaceaad8f74273a10/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:6e27ea05d184afac78aabbac667450c75e54e35f62238d44463131bd3f96753d", size = 1692346, upload-time = "2026-03-31T21:59:12.068Z" }, + { url = "https://files.pythonhosted.org/packages/d8/40/8ebc6658d48ea630ac7903912fe0dd4e262f0e16825aa4c833c56c9f1f56/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:a79a6d399cef33a11b6f004c67bb07741d91f2be01b8d712d52c75711b1e07c7", size = 1698891, upload-time = "2026-03-31T21:59:14.552Z" }, + { url = "https://files.pythonhosted.org/packages/d8/78/ea0ae5ec8ba7a5c10bdd6e318f1ba5e76fcde17db8275188772afc7917a4/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:c632ce9c0b534fbe25b52c974515ed674937c5b99f549a92127c85f771a78772", size = 1742113, upload-time = "2026-03-31T21:59:17.068Z" }, + { url = "https://files.pythonhosted.org/packages/8a/66/9d308ed71e3f2491be1acb8769d96c6f0c47d92099f3bc9119cada27b357/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:fceedde51fbd67ee2bcc8c0b33d0126cc8b51ef3bbde2f86662bd6d5a6f10ec5", size = 1553088, upload-time = "2026-03-31T21:59:19.541Z" }, + { url = "https://files.pythonhosted.org/packages/da/a6/6cc25ed8dfc6e00c90f5c6d126a98e2cf28957ad06fa1036bd34b6f24a2c/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f92995dfec9420bb69ae629abf422e516923ba79ba4403bc750d94fb4a6c68c1", size = 1757976, upload-time = "2026-03-31T21:59:22.311Z" }, + { url = "https://files.pythonhosted.org/packages/c1/2b/cce5b0ffe0de99c83e5e36d8f828e4161e415660a9f3e58339d07cce3006/aiohttp-3.13.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:20ae0ff08b1f2c8788d6fb85afcb798654ae6ba0b747575f8562de738078457b", size = 1712444, upload-time = "2026-03-31T21:59:24.635Z" }, + { url = "https://files.pythonhosted.org/packages/6c/cf/9e1795b4160c58d29421eafd1a69c6ce351e2f7c8d3c6b7e4ca44aea1a5b/aiohttp-3.13.5-cp314-cp314-win32.whl", hash = "sha256:b20df693de16f42b2472a9c485e1c948ee55524786a0a34345511afdd22246f3", size = 438128, upload-time = "2026-03-31T21:59:27.291Z" }, + { url = "https://files.pythonhosted.org/packages/22/4d/eaedff67fc805aeba4ba746aec891b4b24cebb1a7d078084b6300f79d063/aiohttp-3.13.5-cp314-cp314-win_amd64.whl", hash = "sha256:f85c6f327bf0b8c29da7d93b1cabb6363fb5e4e160a32fa241ed2dce21b73162", size = 464029, upload-time = "2026-03-31T21:59:29.429Z" }, + { url = "https://files.pythonhosted.org/packages/79/11/c27d9332ee20d68dd164dc12a6ecdef2e2e35ecc97ed6cf0d2442844624b/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:1efb06900858bb618ff5cee184ae2de5828896c448403d51fb633f09e109be0a", size = 778758, upload-time = "2026-03-31T21:59:31.547Z" }, + { url = "https://files.pythonhosted.org/packages/04/fb/377aead2e0a3ba5f09b7624f702a964bdf4f08b5b6728a9799830c80041e/aiohttp-3.13.5-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fee86b7c4bd29bdaf0d53d14739b08a106fdda809ca5fe032a15f52fae5fe254", size = 512883, upload-time = "2026-03-31T21:59:34.098Z" }, + { url = "https://files.pythonhosted.org/packages/bb/a6/aa109a33671f7a5d3bd78b46da9d852797c5e665bfda7d6b373f56bff2ec/aiohttp-3.13.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:20058e23909b9e65f9da62b396b77dfa95965cbe840f8def6e572538b1d32e36", size = 516668, upload-time = "2026-03-31T21:59:36.497Z" }, + { url = "https://files.pythonhosted.org/packages/79/b3/ca078f9f2fa9563c36fb8ef89053ea2bb146d6f792c5104574d49d8acb63/aiohttp-3.13.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cf20a8d6868cb15a73cab329ffc07291ba8c22b1b88176026106ae39aa6df0f", size = 1883461, upload-time = "2026-03-31T21:59:38.723Z" }, + { url = "https://files.pythonhosted.org/packages/b7/e3/a7ad633ca1ca497b852233a3cce6906a56c3225fb6d9217b5e5e60b7419d/aiohttp-3.13.5-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:330f5da04c987f1d5bdb8ae189137c77139f36bd1cb23779ca1a354a4b027800", size = 1747661, upload-time = "2026-03-31T21:59:41.187Z" }, + { url = "https://files.pythonhosted.org/packages/33/b9/cd6fe579bed34a906d3d783fe60f2fa297ef55b27bb4538438ee49d4dc41/aiohttp-3.13.5-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6f1cbf0c7926d315c3c26c2da41fd2b5d2fe01ac0e157b78caefc51a782196cf", size = 1863800, upload-time = "2026-03-31T21:59:43.84Z" }, + { url = "https://files.pythonhosted.org/packages/c0/3f/2c1e2f5144cefa889c8afd5cf431994c32f3b29da9961698ff4e3811b79a/aiohttp-3.13.5-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53fc049ed6390d05423ba33103ded7281fe897cf97878f369a527070bd95795b", size = 1958382, upload-time = "2026-03-31T21:59:46.187Z" }, + { url = "https://files.pythonhosted.org/packages/66/1d/f31ec3f1013723b3babe3609e7f119c2c2fb6ef33da90061a705ef3e1bc8/aiohttp-3.13.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:898703aa2667e3c5ca4c54ca36cd73f58b7a38ef87a5606414799ebce4d3fd3a", size = 1803724, upload-time = "2026-03-31T21:59:48.656Z" }, + { url = "https://files.pythonhosted.org/packages/0e/b4/57712dfc6f1542f067daa81eb61da282fab3e6f1966fca25db06c4fc62d5/aiohttp-3.13.5-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:0494a01ca9584eea1e5fbd6d748e61ecff218c51b576ee1999c23db7066417d8", size = 1640027, upload-time = "2026-03-31T21:59:51.284Z" }, + { url = "https://files.pythonhosted.org/packages/25/3c/734c878fb43ec083d8e31bf029daae1beafeae582d1b35da234739e82ee7/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:6cf81fe010b8c17b09495cbd15c1d35afbc8fb405c0c9cf4738e5ae3af1d65be", size = 1806644, upload-time = "2026-03-31T21:59:53.753Z" }, + { url = "https://files.pythonhosted.org/packages/20/a5/f671e5cbec1c21d044ff3078223f949748f3a7f86b14e34a365d74a5d21f/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:c564dd5f09ddc9d8f2c2d0a301cd30a79a2cc1b46dd1a73bef8f0038863d016b", size = 1791630, upload-time = "2026-03-31T21:59:56.239Z" }, + { url = "https://files.pythonhosted.org/packages/0b/63/fb8d0ad63a0b8a99be97deac8c04dacf0785721c158bdf23d679a87aa99e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:2994be9f6e51046c4f864598fd9abeb4fba6e88f0b2152422c9666dcd4aea9c6", size = 1809403, upload-time = "2026-03-31T21:59:59.103Z" }, + { url = "https://files.pythonhosted.org/packages/59/0c/bfed7f30662fcf12206481c2aac57dedee43fe1c49275e85b3a1e1742294/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:157826e2fa245d2ef46c83ea8a5faf77ca19355d278d425c29fda0beb3318037", size = 1634924, upload-time = "2026-03-31T22:00:02.116Z" }, + { url = "https://files.pythonhosted.org/packages/17/d6/fd518d668a09fd5a3319ae5e984d4d80b9a4b3df4e21c52f02251ef5a32e/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:a8aca50daa9493e9e13c0f566201a9006f080e7c50e5e90d0b06f53146a54500", size = 1836119, upload-time = "2026-03-31T22:00:04.756Z" }, + { url = "https://files.pythonhosted.org/packages/78/b7/15fb7a9d52e112a25b621c67b69c167805cb1f2ab8f1708a5c490d1b52fe/aiohttp-3.13.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:3b13560160d07e047a93f23aaa30718606493036253d5430887514715b67c9d9", size = 1772072, upload-time = "2026-03-31T22:00:07.494Z" }, + { url = "https://files.pythonhosted.org/packages/7e/df/57ba7f0c4a553fc2bd8b6321df236870ec6fd64a2a473a8a13d4f733214e/aiohttp-3.13.5-cp314-cp314t-win32.whl", hash = "sha256:9a0f4474b6ea6818b41f82172d799e4b3d29e22c2c520ce4357856fced9af2f8", size = 471819, upload-time = "2026-03-31T22:00:10.277Z" }, + { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" }, +] + +[[package]] +name = "aiosignal" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "frozenlist" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, +] + +[[package]] +name = "appdirs" +version = "1.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/d8/05696357e0311f5b5c316d7b95f46c669dd9c15aaeecbb48c7d0aeb88c40/appdirs-1.4.4.tar.gz", hash = "sha256:7d5d0167b2b1ba821647616af46a749d1c653740dd0d2415100fe26e27afdf41", size = 13470, upload-time = "2020-05-11T07:59:51.037Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/00/2344469e2084fb287c2e0b57b72910309874c3245463acd6cf5e3db69324/appdirs-1.4.4-py2.py3-none-any.whl", hash = "sha256:a841dacd6b99318a741b166adb07e19ee71a274450e68237b4650ca1055ab128", size = 9566, upload-time = "2020-05-11T07:59:49.499Z" }, +] + [[package]] name = "appnope" version = "0.1.4" @@ -28,6 +144,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl", hash = "sha256:15a3ebc0f43c2d0a50eeafea25e19046c68398e487b9f1f5b517f7c0f40f976a", size = 27047, upload-time = "2025-11-15T16:43:16.109Z" }, ] +[[package]] +name = "attrs" +version = "26.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/8e/82a0fe20a541c03148528be8cac2408564a6c9a0cc7e9171802bc1d26985/attrs-26.1.0.tar.gz", hash = "sha256:d03ceb89cb322a8fd706d4fb91940737b6642aa36998fe130a9bc96c985eff32", size = 952055, upload-time = "2026-03-19T14:22:25.026Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, +] + [[package]] name = "certifi" version = "2026.4.22" @@ -167,6 +292,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958, upload-time = "2026-04-02T09:28:37.794Z" }, ] +[[package]] +name = "clikit" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "crashtest", marker = "python_full_version < '4'" }, + { name = "pastel" }, + { name = "pylev" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/07/27d700f8447c0ca81454a4acdb7eb200229a6d06fe0b1439acc3da49a53f/clikit-0.6.2.tar.gz", hash = "sha256:442ee5db9a14120635c5990bcdbfe7c03ada5898291f0c802f77be71569ded59", size = 56214, upload-time = "2020-06-09T20:17:18.298Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f2/3d/4394c710b9195b83382dc67bdd1040e5ebfc3fc8df90e20fe74341298c57/clikit-0.6.2-py2.py3-none-any.whl", hash = "sha256:71268e074e68082306e23d7369a7b99f824a0ef926e55ba2665e911f7208489e", size = 91825, upload-time = "2020-06-09T20:17:17.178Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -335,6 +474,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, ] +[[package]] +name = "crashtest" +version = "0.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/08/3c/5ec13020a4693fab34e1f438fe6e96aed6551740e1f4a5cc66e8b84491ea/crashtest-0.3.1.tar.gz", hash = "sha256:42ca7b6ce88b6c7433e2ce47ea884e91ec93104a4b754998be498a8e6c3d37dd", size = 4333, upload-time = "2020-07-31T13:32:29.862Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/97/2a99f020be5e4a5a97ba10bc480e2e6a889b5087103a2c6b952b5f819d27/crashtest-0.3.1-py3-none-any.whl", hash = "sha256:300f4b0825f57688b47b6d70c6a31de33512eb2fa1ac614f780939aa0cf91680", size = 6966, upload-time = "2020-07-31T13:32:28.18Z" }, +] + [[package]] name = "cuda-bindings" version = "12.9.4" @@ -456,6 +604,95 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fd/ba/56147c165442cc5ba7e82ecf301c9a68353cede498185869e6e02b4c264f/fonttools-4.62.1-py3-none-any.whl", hash = "sha256:7487782e2113861f4ddcc07c3436450659e3caa5e470b27dc2177cade2d8e7fd", size = 1152647, upload-time = "2026-03-13T13:54:22.735Z" }, ] +[[package]] +name = "frozenlist" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/f5/c831fac6cc817d26fd54c7eaccd04ef7e0288806943f7cc5bbf69f3ac1f0/frozenlist-1.8.0.tar.gz", hash = "sha256:3ede829ed8d842f6cd48fc7081d7a41001a56f1f38603f9d49bf3020d59a31ad", size = 45875, upload-time = "2025-10-06T05:38:17.865Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/29/948b9aa87e75820a38650af445d2ef2b6b8a6fab1a23b6bb9e4ef0be2d59/frozenlist-1.8.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:78f7b9e5d6f2fdb88cdde9440dc147259b62b9d3b019924def9f6478be254ac1", size = 87782, upload-time = "2025-10-06T05:36:06.649Z" }, + { url = "https://files.pythonhosted.org/packages/64/80/4f6e318ee2a7c0750ed724fa33a4bdf1eacdc5a39a7a24e818a773cd91af/frozenlist-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:229bf37d2e4acdaf808fd3f06e854a4a7a3661e871b10dc1f8f1896a3b05f18b", size = 50594, upload-time = "2025-10-06T05:36:07.69Z" }, + { url = "https://files.pythonhosted.org/packages/2b/94/5c8a2b50a496b11dd519f4a24cb5496cf125681dd99e94c604ccdea9419a/frozenlist-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f833670942247a14eafbb675458b4e61c82e002a148f49e68257b79296e865c4", size = 50448, upload-time = "2025-10-06T05:36:08.78Z" }, + { url = "https://files.pythonhosted.org/packages/6a/bd/d91c5e39f490a49df14320f4e8c80161cfcce09f1e2cde1edd16a551abb3/frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:494a5952b1c597ba44e0e78113a7266e656b9794eec897b19ead706bd7074383", size = 242411, upload-time = "2025-10-06T05:36:09.801Z" }, + { url = "https://files.pythonhosted.org/packages/8f/83/f61505a05109ef3293dfb1ff594d13d64a2324ac3482be2cedc2be818256/frozenlist-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96f423a119f4777a4a056b66ce11527366a8bb92f54e541ade21f2374433f6d4", size = 243014, upload-time = "2025-10-06T05:36:11.394Z" }, + { url = "https://files.pythonhosted.org/packages/d8/cb/cb6c7b0f7d4023ddda30cf56b8b17494eb3a79e3fda666bf735f63118b35/frozenlist-1.8.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3462dd9475af2025c31cc61be6652dfa25cbfb56cbbf52f4ccfe029f38decaf8", size = 234909, upload-time = "2025-10-06T05:36:12.598Z" }, + { url = "https://files.pythonhosted.org/packages/31/c5/cd7a1f3b8b34af009fb17d4123c5a778b44ae2804e3ad6b86204255f9ec5/frozenlist-1.8.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c4c800524c9cd9bac5166cd6f55285957fcfc907db323e193f2afcd4d9abd69b", size = 250049, upload-time = "2025-10-06T05:36:14.065Z" }, + { url = "https://files.pythonhosted.org/packages/c0/01/2f95d3b416c584a1e7f0e1d6d31998c4a795f7544069ee2e0962a4b60740/frozenlist-1.8.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d6a5df73acd3399d893dafc71663ad22534b5aa4f94e8a2fabfe856c3c1b6a52", size = 256485, upload-time = "2025-10-06T05:36:15.39Z" }, + { url = "https://files.pythonhosted.org/packages/ce/03/024bf7720b3abaebcff6d0793d73c154237b85bdf67b7ed55e5e9596dc9a/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:405e8fe955c2280ce66428b3ca55e12b3c4e9c336fb2103a4937e891c69a4a29", size = 237619, upload-time = "2025-10-06T05:36:16.558Z" }, + { url = "https://files.pythonhosted.org/packages/69/fa/f8abdfe7d76b731f5d8bd217827cf6764d4f1d9763407e42717b4bed50a0/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:908bd3f6439f2fef9e85031b59fd4f1297af54415fb60e4254a95f75b3cab3f3", size = 250320, upload-time = "2025-10-06T05:36:17.821Z" }, + { url = "https://files.pythonhosted.org/packages/f5/3c/b051329f718b463b22613e269ad72138cc256c540f78a6de89452803a47d/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:294e487f9ec720bd8ffcebc99d575f7eff3568a08a253d1ee1a0378754b74143", size = 246820, upload-time = "2025-10-06T05:36:19.046Z" }, + { url = "https://files.pythonhosted.org/packages/0f/ae/58282e8f98e444b3f4dd42448ff36fa38bef29e40d40f330b22e7108f565/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:74c51543498289c0c43656701be6b077f4b265868fa7f8a8859c197006efb608", size = 250518, upload-time = "2025-10-06T05:36:20.763Z" }, + { url = "https://files.pythonhosted.org/packages/8f/96/007e5944694d66123183845a106547a15944fbbb7154788cbf7272789536/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:776f352e8329135506a1d6bf16ac3f87bc25b28e765949282dcc627af36123aa", size = 239096, upload-time = "2025-10-06T05:36:22.129Z" }, + { url = "https://files.pythonhosted.org/packages/66/bb/852b9d6db2fa40be96f29c0d1205c306288f0684df8fd26ca1951d461a56/frozenlist-1.8.0-cp312-cp312-win32.whl", hash = "sha256:433403ae80709741ce34038da08511d4a77062aa924baf411ef73d1146e74faf", size = 39985, upload-time = "2025-10-06T05:36:23.661Z" }, + { url = "https://files.pythonhosted.org/packages/b8/af/38e51a553dd66eb064cdf193841f16f077585d4d28394c2fa6235cb41765/frozenlist-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:34187385b08f866104f0c0617404c8eb08165ab1272e884abc89c112e9c00746", size = 44591, upload-time = "2025-10-06T05:36:24.958Z" }, + { url = "https://files.pythonhosted.org/packages/a7/06/1dc65480ab147339fecc70797e9c2f69d9cea9cf38934ce08df070fdb9cb/frozenlist-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:fe3c58d2f5db5fbd18c2987cba06d51b0529f52bc3a6cdc33d3f4eab725104bd", size = 40102, upload-time = "2025-10-06T05:36:26.333Z" }, + { url = "https://files.pythonhosted.org/packages/2d/40/0832c31a37d60f60ed79e9dfb5a92e1e2af4f40a16a29abcc7992af9edff/frozenlist-1.8.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8d92f1a84bb12d9e56f818b3a746f3efba93c1b63c8387a73dde655e1e42282a", size = 85717, upload-time = "2025-10-06T05:36:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/30/ba/b0b3de23f40bc55a7057bd38434e25c34fa48e17f20ee273bbde5e0650f3/frozenlist-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96153e77a591c8adc2ee805756c61f59fef4cf4073a9275ee86fe8cba41241f7", size = 49651, upload-time = "2025-10-06T05:36:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/0c/ab/6e5080ee374f875296c4243c381bbdef97a9ac39c6e3ce1d5f7d42cb78d6/frozenlist-1.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f21f00a91358803399890ab167098c131ec2ddd5f8f5fd5fe9c9f2c6fcd91e40", size = 49417, upload-time = "2025-10-06T05:36:29.877Z" }, + { url = "https://files.pythonhosted.org/packages/d5/4e/e4691508f9477ce67da2015d8c00acd751e6287739123113a9fca6f1604e/frozenlist-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:fb30f9626572a76dfe4293c7194a09fb1fe93ba94c7d4f720dfae3b646b45027", size = 234391, upload-time = "2025-10-06T05:36:31.301Z" }, + { url = "https://files.pythonhosted.org/packages/40/76/c202df58e3acdf12969a7895fd6f3bc016c642e6726aa63bd3025e0fc71c/frozenlist-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaa352d7047a31d87dafcacbabe89df0aa506abb5b1b85a2fb91bc3faa02d822", size = 233048, upload-time = "2025-10-06T05:36:32.531Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c0/8746afb90f17b73ca5979c7a3958116e105ff796e718575175319b5bb4ce/frozenlist-1.8.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:03ae967b4e297f58f8c774c7eabcce57fe3c2434817d4385c50661845a058121", size = 226549, upload-time = "2025-10-06T05:36:33.706Z" }, + { url = "https://files.pythonhosted.org/packages/7e/eb/4c7eefc718ff72f9b6c4893291abaae5fbc0c82226a32dcd8ef4f7a5dbef/frozenlist-1.8.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f6292f1de555ffcc675941d65fffffb0a5bcd992905015f85d0592201793e0e5", size = 239833, upload-time = "2025-10-06T05:36:34.947Z" }, + { url = "https://files.pythonhosted.org/packages/c2/4e/e5c02187cf704224f8b21bee886f3d713ca379535f16893233b9d672ea71/frozenlist-1.8.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:29548f9b5b5e3460ce7378144c3010363d8035cea44bc0bf02d57f5a685e084e", size = 245363, upload-time = "2025-10-06T05:36:36.534Z" }, + { url = "https://files.pythonhosted.org/packages/1f/96/cb85ec608464472e82ad37a17f844889c36100eed57bea094518bf270692/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ec3cc8c5d4084591b4237c0a272cc4f50a5b03396a47d9caaf76f5d7b38a4f11", size = 229314, upload-time = "2025-10-06T05:36:38.582Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6f/4ae69c550e4cee66b57887daeebe006fe985917c01d0fff9caab9883f6d0/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:517279f58009d0b1f2e7c1b130b377a349405da3f7621ed6bfae50b10adf20c1", size = 243365, upload-time = "2025-10-06T05:36:40.152Z" }, + { url = "https://files.pythonhosted.org/packages/7a/58/afd56de246cf11780a40a2c28dc7cbabbf06337cc8ddb1c780a2d97e88d8/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:db1e72ede2d0d7ccb213f218df6a078a9c09a7de257c2fe8fcef16d5925230b1", size = 237763, upload-time = "2025-10-06T05:36:41.355Z" }, + { url = "https://files.pythonhosted.org/packages/cb/36/cdfaf6ed42e2644740d4a10452d8e97fa1c062e2a8006e4b09f1b5fd7d63/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b4dec9482a65c54a5044486847b8a66bf10c9cb4926d42927ec4e8fd5db7fed8", size = 240110, upload-time = "2025-10-06T05:36:42.716Z" }, + { url = "https://files.pythonhosted.org/packages/03/a8/9ea226fbefad669f11b52e864c55f0bd57d3c8d7eb07e9f2e9a0b39502e1/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:21900c48ae04d13d416f0e1e0c4d81f7931f73a9dfa0b7a8746fb2fe7dd970ed", size = 233717, upload-time = "2025-10-06T05:36:44.251Z" }, + { url = "https://files.pythonhosted.org/packages/1e/0b/1b5531611e83ba7d13ccc9988967ea1b51186af64c42b7a7af465dcc9568/frozenlist-1.8.0-cp313-cp313-win32.whl", hash = "sha256:8b7b94a067d1c504ee0b16def57ad5738701e4ba10cec90529f13fa03c833496", size = 39628, upload-time = "2025-10-06T05:36:45.423Z" }, + { url = "https://files.pythonhosted.org/packages/d8/cf/174c91dbc9cc49bc7b7aab74d8b734e974d1faa8f191c74af9b7e80848e6/frozenlist-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:878be833caa6a3821caf85eb39c5ba92d28e85df26d57afb06b35b2efd937231", size = 43882, upload-time = "2025-10-06T05:36:46.796Z" }, + { url = "https://files.pythonhosted.org/packages/c1/17/502cd212cbfa96eb1388614fe39a3fc9ab87dbbe042b66f97acb57474834/frozenlist-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:44389d135b3ff43ba8cc89ff7f51f5a0bb6b63d829c8300f79a2fe4fe61bcc62", size = 39676, upload-time = "2025-10-06T05:36:47.8Z" }, + { url = "https://files.pythonhosted.org/packages/d2/5c/3bbfaa920dfab09e76946a5d2833a7cbdf7b9b4a91c714666ac4855b88b4/frozenlist-1.8.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:e25ac20a2ef37e91c1b39938b591457666a0fa835c7783c3a8f33ea42870db94", size = 89235, upload-time = "2025-10-06T05:36:48.78Z" }, + { url = "https://files.pythonhosted.org/packages/d2/d6/f03961ef72166cec1687e84e8925838442b615bd0b8854b54923ce5b7b8a/frozenlist-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:07cdca25a91a4386d2e76ad992916a85038a9b97561bf7a3fd12d5d9ce31870c", size = 50742, upload-time = "2025-10-06T05:36:49.837Z" }, + { url = "https://files.pythonhosted.org/packages/1e/bb/a6d12b7ba4c3337667d0e421f7181c82dda448ce4e7ad7ecd249a16fa806/frozenlist-1.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4e0c11f2cc6717e0a741f84a527c52616140741cd812a50422f83dc31749fb52", size = 51725, upload-time = "2025-10-06T05:36:50.851Z" }, + { url = "https://files.pythonhosted.org/packages/bc/71/d1fed0ffe2c2ccd70b43714c6cab0f4188f09f8a67a7914a6b46ee30f274/frozenlist-1.8.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b3210649ee28062ea6099cfda39e147fa1bc039583c8ee4481cb7811e2448c51", size = 284533, upload-time = "2025-10-06T05:36:51.898Z" }, + { url = "https://files.pythonhosted.org/packages/c9/1f/fb1685a7b009d89f9bf78a42d94461bc06581f6e718c39344754a5d9bada/frozenlist-1.8.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:581ef5194c48035a7de2aefc72ac6539823bb71508189e5de01d60c9dcd5fa65", size = 292506, upload-time = "2025-10-06T05:36:53.101Z" }, + { url = "https://files.pythonhosted.org/packages/e6/3b/b991fe1612703f7e0d05c0cf734c1b77aaf7c7d321df4572e8d36e7048c8/frozenlist-1.8.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3ef2d026f16a2b1866e1d86fc4e1291e1ed8a387b2c333809419a2f8b3a77b82", size = 274161, upload-time = "2025-10-06T05:36:54.309Z" }, + { url = "https://files.pythonhosted.org/packages/ca/ec/c5c618767bcdf66e88945ec0157d7f6c4a1322f1473392319b7a2501ded7/frozenlist-1.8.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:5500ef82073f599ac84d888e3a8c1f77ac831183244bfd7f11eaa0289fb30714", size = 294676, upload-time = "2025-10-06T05:36:55.566Z" }, + { url = "https://files.pythonhosted.org/packages/7c/ce/3934758637d8f8a88d11f0585d6495ef54b2044ed6ec84492a91fa3b27aa/frozenlist-1.8.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:50066c3997d0091c411a66e710f4e11752251e6d2d73d70d8d5d4c76442a199d", size = 300638, upload-time = "2025-10-06T05:36:56.758Z" }, + { url = "https://files.pythonhosted.org/packages/fc/4f/a7e4d0d467298f42de4b41cbc7ddaf19d3cfeabaf9ff97c20c6c7ee409f9/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:5c1c8e78426e59b3f8005e9b19f6ff46e5845895adbde20ece9218319eca6506", size = 283067, upload-time = "2025-10-06T05:36:57.965Z" }, + { url = "https://files.pythonhosted.org/packages/dc/48/c7b163063d55a83772b268e6d1affb960771b0e203b632cfe09522d67ea5/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:eefdba20de0d938cec6a89bd4d70f346a03108a19b9df4248d3cf0d88f1b0f51", size = 292101, upload-time = "2025-10-06T05:36:59.237Z" }, + { url = "https://files.pythonhosted.org/packages/9f/d0/2366d3c4ecdc2fd391e0afa6e11500bfba0ea772764d631bbf82f0136c9d/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:cf253e0e1c3ceb4aaff6df637ce033ff6535fb8c70a764a8f46aafd3d6ab798e", size = 289901, upload-time = "2025-10-06T05:37:00.811Z" }, + { url = "https://files.pythonhosted.org/packages/b8/94/daff920e82c1b70e3618a2ac39fbc01ae3e2ff6124e80739ce5d71c9b920/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:032efa2674356903cd0261c4317a561a6850f3ac864a63fc1583147fb05a79b0", size = 289395, upload-time = "2025-10-06T05:37:02.115Z" }, + { url = "https://files.pythonhosted.org/packages/e3/20/bba307ab4235a09fdcd3cc5508dbabd17c4634a1af4b96e0f69bfe551ebd/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6da155091429aeba16851ecb10a9104a108bcd32f6c1642867eadaee401c1c41", size = 283659, upload-time = "2025-10-06T05:37:03.711Z" }, + { url = "https://files.pythonhosted.org/packages/fd/00/04ca1c3a7a124b6de4f8a9a17cc2fcad138b4608e7a3fc5877804b8715d7/frozenlist-1.8.0-cp313-cp313t-win32.whl", hash = "sha256:0f96534f8bfebc1a394209427d0f8a63d343c9779cda6fc25e8e121b5fd8555b", size = 43492, upload-time = "2025-10-06T05:37:04.915Z" }, + { url = "https://files.pythonhosted.org/packages/59/5e/c69f733a86a94ab10f68e496dc6b7e8bc078ebb415281d5698313e3af3a1/frozenlist-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5d63a068f978fc69421fb0e6eb91a9603187527c86b7cd3f534a5b77a592b888", size = 48034, upload-time = "2025-10-06T05:37:06.343Z" }, + { url = "https://files.pythonhosted.org/packages/16/6c/be9d79775d8abe79b05fa6d23da99ad6e7763a1d080fbae7290b286093fd/frozenlist-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf0a7e10b077bf5fb9380ad3ae8ce20ef919a6ad93b4552896419ac7e1d8e042", size = 41749, upload-time = "2025-10-06T05:37:07.431Z" }, + { url = "https://files.pythonhosted.org/packages/f1/c8/85da824b7e7b9b6e7f7705b2ecaf9591ba6f79c1177f324c2735e41d36a2/frozenlist-1.8.0-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:cee686f1f4cadeb2136007ddedd0aaf928ab95216e7691c63e50a8ec066336d0", size = 86127, upload-time = "2025-10-06T05:37:08.438Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e8/a1185e236ec66c20afd72399522f142c3724c785789255202d27ae992818/frozenlist-1.8.0-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:119fb2a1bd47307e899c2fac7f28e85b9a543864df47aa7ec9d3c1b4545f096f", size = 49698, upload-time = "2025-10-06T05:37:09.48Z" }, + { url = "https://files.pythonhosted.org/packages/a1/93/72b1736d68f03fda5fdf0f2180fb6caaae3894f1b854d006ac61ecc727ee/frozenlist-1.8.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:4970ece02dbc8c3a92fcc5228e36a3e933a01a999f7094ff7c23fbd2beeaa67c", size = 49749, upload-time = "2025-10-06T05:37:10.569Z" }, + { url = "https://files.pythonhosted.org/packages/a7/b2/fabede9fafd976b991e9f1b9c8c873ed86f202889b864756f240ce6dd855/frozenlist-1.8.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:cba69cb73723c3f329622e34bdbf5ce1f80c21c290ff04256cff1cd3c2036ed2", size = 231298, upload-time = "2025-10-06T05:37:11.993Z" }, + { url = "https://files.pythonhosted.org/packages/3a/3b/d9b1e0b0eed36e70477ffb8360c49c85c8ca8ef9700a4e6711f39a6e8b45/frozenlist-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:778a11b15673f6f1df23d9586f83c4846c471a8af693a22e066508b77d201ec8", size = 232015, upload-time = "2025-10-06T05:37:13.194Z" }, + { url = "https://files.pythonhosted.org/packages/dc/94/be719d2766c1138148564a3960fc2c06eb688da592bdc25adcf856101be7/frozenlist-1.8.0-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0325024fe97f94c41c08872db482cf8ac4800d80e79222c6b0b7b162d5b13686", size = 225038, upload-time = "2025-10-06T05:37:14.577Z" }, + { url = "https://files.pythonhosted.org/packages/e4/09/6712b6c5465f083f52f50cf74167b92d4ea2f50e46a9eea0523d658454ae/frozenlist-1.8.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:97260ff46b207a82a7567b581ab4190bd4dfa09f4db8a8b49d1a958f6aa4940e", size = 240130, upload-time = "2025-10-06T05:37:15.781Z" }, + { url = "https://files.pythonhosted.org/packages/f8/d4/cd065cdcf21550b54f3ce6a22e143ac9e4836ca42a0de1022da8498eac89/frozenlist-1.8.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:54b2077180eb7f83dd52c40b2750d0a9f175e06a42e3213ce047219de902717a", size = 242845, upload-time = "2025-10-06T05:37:17.037Z" }, + { url = "https://files.pythonhosted.org/packages/62/c3/f57a5c8c70cd1ead3d5d5f776f89d33110b1addae0ab010ad774d9a44fb9/frozenlist-1.8.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:2f05983daecab868a31e1da44462873306d3cbfd76d1f0b5b69c473d21dbb128", size = 229131, upload-time = "2025-10-06T05:37:18.221Z" }, + { url = "https://files.pythonhosted.org/packages/6c/52/232476fe9cb64f0742f3fde2b7d26c1dac18b6d62071c74d4ded55e0ef94/frozenlist-1.8.0-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:33f48f51a446114bc5d251fb2954ab0164d5be02ad3382abcbfe07e2531d650f", size = 240542, upload-time = "2025-10-06T05:37:19.771Z" }, + { url = "https://files.pythonhosted.org/packages/5f/85/07bf3f5d0fb5414aee5f47d33c6f5c77bfe49aac680bfece33d4fdf6a246/frozenlist-1.8.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:154e55ec0655291b5dd1b8731c637ecdb50975a2ae70c606d100750a540082f7", size = 237308, upload-time = "2025-10-06T05:37:20.969Z" }, + { url = "https://files.pythonhosted.org/packages/11/99/ae3a33d5befd41ac0ca2cc7fd3aa707c9c324de2e89db0e0f45db9a64c26/frozenlist-1.8.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:4314debad13beb564b708b4a496020e5306c7333fa9a3ab90374169a20ffab30", size = 238210, upload-time = "2025-10-06T05:37:22.252Z" }, + { url = "https://files.pythonhosted.org/packages/b2/60/b1d2da22f4970e7a155f0adde9b1435712ece01b3cd45ba63702aea33938/frozenlist-1.8.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:073f8bf8becba60aa931eb3bc420b217bb7d5b8f4750e6f8b3be7f3da85d38b7", size = 231972, upload-time = "2025-10-06T05:37:23.5Z" }, + { url = "https://files.pythonhosted.org/packages/3f/ab/945b2f32de889993b9c9133216c068b7fcf257d8595a0ac420ac8677cab0/frozenlist-1.8.0-cp314-cp314-win32.whl", hash = "sha256:bac9c42ba2ac65ddc115d930c78d24ab8d4f465fd3fc473cdedfccadb9429806", size = 40536, upload-time = "2025-10-06T05:37:25.581Z" }, + { url = "https://files.pythonhosted.org/packages/59/ad/9caa9b9c836d9ad6f067157a531ac48b7d36499f5036d4141ce78c230b1b/frozenlist-1.8.0-cp314-cp314-win_amd64.whl", hash = "sha256:3e0761f4d1a44f1d1a47996511752cf3dcec5bbdd9cc2b4fe595caf97754b7a0", size = 44330, upload-time = "2025-10-06T05:37:26.928Z" }, + { url = "https://files.pythonhosted.org/packages/82/13/e6950121764f2676f43534c555249f57030150260aee9dcf7d64efda11dd/frozenlist-1.8.0-cp314-cp314-win_arm64.whl", hash = "sha256:d1eaff1d00c7751b7c6662e9c5ba6eb2c17a2306ba5e2a37f24ddf3cc953402b", size = 40627, upload-time = "2025-10-06T05:37:28.075Z" }, + { url = "https://files.pythonhosted.org/packages/c0/c7/43200656ecc4e02d3f8bc248df68256cd9572b3f0017f0a0c4e93440ae23/frozenlist-1.8.0-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:d3bb933317c52d7ea5004a1c442eef86f426886fba134ef8cf4226ea6ee1821d", size = 89238, upload-time = "2025-10-06T05:37:29.373Z" }, + { url = "https://files.pythonhosted.org/packages/d1/29/55c5f0689b9c0fb765055629f472c0de484dcaf0acee2f7707266ae3583c/frozenlist-1.8.0-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:8009897cdef112072f93a0efdce29cd819e717fd2f649ee3016efd3cd885a7ed", size = 50738, upload-time = "2025-10-06T05:37:30.792Z" }, + { url = "https://files.pythonhosted.org/packages/ba/7d/b7282a445956506fa11da8c2db7d276adcbf2b17d8bb8407a47685263f90/frozenlist-1.8.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:2c5dcbbc55383e5883246d11fd179782a9d07a986c40f49abe89ddf865913930", size = 51739, upload-time = "2025-10-06T05:37:32.127Z" }, + { url = "https://files.pythonhosted.org/packages/62/1c/3d8622e60d0b767a5510d1d3cf21065b9db874696a51ea6d7a43180a259c/frozenlist-1.8.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:39ecbc32f1390387d2aa4f5a995e465e9e2f79ba3adcac92d68e3e0afae6657c", size = 284186, upload-time = "2025-10-06T05:37:33.21Z" }, + { url = "https://files.pythonhosted.org/packages/2d/14/aa36d5f85a89679a85a1d44cd7a6657e0b1c75f61e7cad987b203d2daca8/frozenlist-1.8.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92db2bf818d5cc8d9c1f1fc56b897662e24ea5adb36ad1f1d82875bd64e03c24", size = 292196, upload-time = "2025-10-06T05:37:36.107Z" }, + { url = "https://files.pythonhosted.org/packages/05/23/6bde59eb55abd407d34f77d39a5126fb7b4f109a3f611d3929f14b700c66/frozenlist-1.8.0-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2dc43a022e555de94c3b68a4ef0b11c4f747d12c024a520c7101709a2144fb37", size = 273830, upload-time = "2025-10-06T05:37:37.663Z" }, + { url = "https://files.pythonhosted.org/packages/d2/3f/22cff331bfad7a8afa616289000ba793347fcd7bc275f3b28ecea2a27909/frozenlist-1.8.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cb89a7f2de3602cfed448095bab3f178399646ab7c61454315089787df07733a", size = 294289, upload-time = "2025-10-06T05:37:39.261Z" }, + { url = "https://files.pythonhosted.org/packages/a4/89/5b057c799de4838b6c69aa82b79705f2027615e01be996d2486a69ca99c4/frozenlist-1.8.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:33139dc858c580ea50e7e60a1b0ea003efa1fd42e6ec7fdbad78fff65fad2fd2", size = 300318, upload-time = "2025-10-06T05:37:43.213Z" }, + { url = "https://files.pythonhosted.org/packages/30/de/2c22ab3eb2a8af6d69dc799e48455813bab3690c760de58e1bf43b36da3e/frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:168c0969a329b416119507ba30b9ea13688fafffac1b7822802537569a1cb0ef", size = 282814, upload-time = "2025-10-06T05:37:45.337Z" }, + { url = "https://files.pythonhosted.org/packages/59/f7/970141a6a8dbd7f556d94977858cfb36fa9b66e0892c6dd780d2219d8cd8/frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:28bd570e8e189d7f7b001966435f9dac6718324b5be2990ac496cf1ea9ddb7fe", size = 291762, upload-time = "2025-10-06T05:37:46.657Z" }, + { url = "https://files.pythonhosted.org/packages/c1/15/ca1adae83a719f82df9116d66f5bb28bb95557b3951903d39135620ef157/frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:b2a095d45c5d46e5e79ba1e5b9cb787f541a8dee0433836cea4b96a2c439dcd8", size = 289470, upload-time = "2025-10-06T05:37:47.946Z" }, + { url = "https://files.pythonhosted.org/packages/ac/83/dca6dc53bf657d371fbc88ddeb21b79891e747189c5de990b9dfff2ccba1/frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:eab8145831a0d56ec9c4139b6c3e594c7a83c2c8be25d5bcf2d86136a532287a", size = 289042, upload-time = "2025-10-06T05:37:49.499Z" }, + { url = "https://files.pythonhosted.org/packages/96/52/abddd34ca99be142f354398700536c5bd315880ed0a213812bc491cff5e4/frozenlist-1.8.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:974b28cf63cc99dfb2188d8d222bc6843656188164848c4f679e63dae4b0708e", size = 283148, upload-time = "2025-10-06T05:37:50.745Z" }, + { url = "https://files.pythonhosted.org/packages/af/d3/76bd4ed4317e7119c2b7f57c3f6934aba26d277acc6309f873341640e21f/frozenlist-1.8.0-cp314-cp314t-win32.whl", hash = "sha256:342c97bf697ac5480c0a7ec73cd700ecfa5a8a40ac923bd035484616efecc2df", size = 44676, upload-time = "2025-10-06T05:37:52.222Z" }, + { url = "https://files.pythonhosted.org/packages/89/76/c615883b7b521ead2944bb3480398cbb07e12b7b4e4d073d3752eb721558/frozenlist-1.8.0-cp314-cp314t-win_amd64.whl", hash = "sha256:06be8f67f39c8b1dc671f5d83aaefd3358ae5cdcf8314552c57e7ed3e6475bdd", size = 49451, upload-time = "2025-10-06T05:37:53.425Z" }, + { url = "https://files.pythonhosted.org/packages/e0/a3/5982da14e113d07b325230f95060e2169f5311b1017ea8af2a29b374c289/frozenlist-1.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:102e6314ca4da683dca92e3b1355490fed5f313b768500084fbe6371fddfdb79", size = 42507, upload-time = "2025-10-06T05:37:54.513Z" }, + { url = "https://files.pythonhosted.org/packages/9a/9a/e35b4a917281c0b8419d4207f4334c8e8c5dbf4f3f5f9ada73958d937dcc/frozenlist-1.8.0-py3-none-any.whl", hash = "sha256:0c18a16eab41e82c295618a77502e17b195883241c563b00f0aa5106fc4eaa0d", size = 13409, upload-time = "2025-10-06T05:38:16.721Z" }, +] + [[package]] name = "fsspec" version = "2026.2.0" @@ -465,6 +702,33 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ab/fb21f4c939bb440104cc2b396d3be1d9b7a9fd3c6c2a53d98c45b3d7c954/fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437", size = 202505, upload-time = "2026-02-05T21:50:51.819Z" }, ] +[[package]] +name = "httpstan" +version = "4.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "appdirs" }, + { name = "marshmallow" }, + { name = "numpy" }, + { name = "setuptools" }, + { name = "webargs" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/d5/e7998ed6558debc5029f7d6f1dc60bee494ff41d2a77dc20df464a885abf/httpstan-4.13.0-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:046cc66f0adbf7e149361ccaa76a0e1478ed865490d1288eb91ef5007d4de590", size = 39632875, upload-time = "2024-07-03T16:26:15.175Z" }, + { url = "https://files.pythonhosted.org/packages/ae/59/995c89214322d4705043e4beb8e05461c490f04be213c3f6f0d76caa99a6/httpstan-4.13.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:f597aea27ea4026d8a4e175dea074d316af1ccc72b61c18bda87b53b59ddb29d", size = 39557020, upload-time = "2025-11-14T15:57:20.974Z" }, + { url = "https://files.pythonhosted.org/packages/13/d4/a51e9f5bbf7ab50d6e975431da79cd39fd99136231ed326e2bd10720fc78/httpstan-4.13.0-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:f225765ccc22851ce51be2300a90d70091137d559344825a25950532b26a89c7", size = 39720683, upload-time = "2025-11-13T10:59:59.947Z" }, + { url = "https://files.pythonhosted.org/packages/1a/d8/b97e532306a8d7e0aa5d5d9b236e42851c662eb19bfba055bc880fd5b0f4/httpstan-4.13.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:be64a0bfc0ea436197da07e143630980a294cc55318250e2797655e68ffdc03f", size = 45548033, upload-time = "2025-11-13T11:00:53.961Z" }, + { url = "https://files.pythonhosted.org/packages/d0/6e/069e28610b85d93341ce830b1470aa103d19ba0eded27ab718264041e2c1/httpstan-4.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a452f600e314f107f6bc17e43905ea2f3aaeebf8b8b98350e861aacf84476382", size = 45553384, upload-time = "2024-07-03T16:08:59.31Z" }, + { url = "https://files.pythonhosted.org/packages/1a/78/390017e85b26a491df9ac0b442bb33c82e422d280c308a61e057cdd56efc/httpstan-4.13.0-cp313-cp313-macosx_13_0_x86_64.whl", hash = "sha256:dfb5bc019969659d46ad3e4db93bf4c1cebb787f9b7fd32ca66fc38beca6c861", size = 39632876, upload-time = "2025-02-06T01:14:38.119Z" }, + { url = "https://files.pythonhosted.org/packages/41/21/8bcd1318ce59725b3c133f2c7de2c4806fbfb6e9ca03b9f125ecd69e763f/httpstan-4.13.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7db6e08a44f89b20491f3fc46fa7cfe2bab01c7519c318648687cce68bf968da", size = 39557012, upload-time = "2025-11-14T15:56:26.77Z" }, + { url = "https://files.pythonhosted.org/packages/a5/0e/845fb932cbe7011bc8471cfd6b207a0e653e7fd712415e433fa9b4107dfb/httpstan-4.13.0-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:ffa4923273585d4e91ba5168c4308b0587789de8c043bb0d00ecc0b7642ad246", size = 39720492, upload-time = "2025-11-13T11:01:04.436Z" }, + { url = "https://files.pythonhosted.org/packages/76/71/546158d24a46d38b928a2406cacdb05cee0086f000e954358c32842dc8e2/httpstan-4.13.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6b4e9de61b199318bef1141871bf73742a1b657cf95491f21c45c735a8ffeaa3", size = 45548152, upload-time = "2025-11-13T11:00:55.747Z" }, + { url = "https://files.pythonhosted.org/packages/2b/b0/87f3b199e9312e7d75e68da1202f32d0a1d4046eb32702226d003ea4aeee/httpstan-4.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21dffe0201a5866059d2ed75322c36613421ced1a33d769b992afcc7953766f0", size = 45548094, upload-time = "2025-02-05T21:27:18.117Z" }, + { url = "https://files.pythonhosted.org/packages/74/31/7bb2fb766da4ddda88bce7a6105a1d3877eaf29ba6e77ade2484698f5bd8/httpstan-4.13.0-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:40a5d03845c49ef2c5d60e1eda297b2e72571d1ba2f074aaf064db41e5cecedf", size = 39553950, upload-time = "2025-11-14T15:58:39.366Z" }, + { url = "https://files.pythonhosted.org/packages/1a/78/7720e37e44515011d977e3e3e25b3b800d6c7a71afcd94a71ef38f734801/httpstan-4.13.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bb8dba968b0c4f4e6332e4c14ed004c07d2bce78b2bf6d47e787433c986197b2", size = 45546271, upload-time = "2025-11-14T15:59:07.637Z" }, +] + [[package]] name = "idna" version = "3.13" @@ -754,6 +1018,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/70/bc/6f1c2f612465f5fa89b95bead1f44dcb607670fd42891d8fdcd5d039f4f4/markupsafe-3.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:32001d6a8fc98c8cb5c947787c5d08b0a50663d139f1305bac5885d98d9b40fa", size = 14146, upload-time = "2025-09-27T18:37:28.327Z" }, ] +[[package]] +name = "marshmallow" +version = "3.26.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, +] + [[package]] name = "matplotlib" version = "3.10.8" @@ -829,6 +1105,105 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, ] +[[package]] +name = "multidict" +version = "6.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/c2/c2d94cbe6ac1753f3fc980da97b3d930efe1da3af3c9f5125354436c073d/multidict-6.7.1.tar.gz", hash = "sha256:ec6652a1bee61c53a3e5776b6049172c53b6aaba34f18c9ad04f82712bac623d", size = 102010, upload-time = "2026-01-26T02:46:45.979Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/9c/f20e0e2cf80e4b2e4b1c365bf5fe104ee633c751a724246262db8f1a0b13/multidict-6.7.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a90f75c956e32891a4eda3639ce6dd86e87105271f43d43442a3aedf3cddf172", size = 76893, upload-time = "2026-01-26T02:43:52.754Z" }, + { url = "https://files.pythonhosted.org/packages/fe/cf/18ef143a81610136d3da8193da9d80bfe1cb548a1e2d1c775f26b23d024a/multidict-6.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fccb473e87eaa1382689053e4a4618e7ba7b9b9b8d6adf2027ee474597128cd", size = 45456, upload-time = "2026-01-26T02:43:53.893Z" }, + { url = "https://files.pythonhosted.org/packages/a9/65/1caac9d4cd32e8433908683446eebc953e82d22b03d10d41a5f0fefe991b/multidict-6.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0fa96985700739c4c7853a43c0b3e169360d6855780021bfc6d0f1ce7c123e7", size = 43872, upload-time = "2026-01-26T02:43:55.041Z" }, + { url = "https://files.pythonhosted.org/packages/cf/3b/d6bd75dc4f3ff7c73766e04e705b00ed6dbbaccf670d9e05a12b006f5a21/multidict-6.7.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cb2a55f408c3043e42b40cc8eecd575afa27b7e0b956dfb190de0f8499a57a53", size = 251018, upload-time = "2026-01-26T02:43:56.198Z" }, + { url = "https://files.pythonhosted.org/packages/fd/80/c959c5933adedb9ac15152e4067c702a808ea183a8b64cf8f31af8ad3155/multidict-6.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb0ce7b2a32d09892b3dd6cc44877a0d02a33241fafca5f25c8b6b62374f8b75", size = 258883, upload-time = "2026-01-26T02:43:57.499Z" }, + { url = "https://files.pythonhosted.org/packages/86/85/7ed40adafea3d4f1c8b916e3b5cc3a8e07dfcdcb9cd72800f4ed3ca1b387/multidict-6.7.1-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c3a32d23520ee37bf327d1e1a656fec76a2edd5c038bf43eddfa0572ec49c60b", size = 242413, upload-time = "2026-01-26T02:43:58.755Z" }, + { url = "https://files.pythonhosted.org/packages/d2/57/b8565ff533e48595503c785f8361ff9a4fde4d67de25c207cd0ba3befd03/multidict-6.7.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9c90fed18bffc0189ba814749fdcc102b536e83a9f738a9003e569acd540a733", size = 268404, upload-time = "2026-01-26T02:44:00.216Z" }, + { url = "https://files.pythonhosted.org/packages/e0/50/9810c5c29350f7258180dfdcb2e52783a0632862eb334c4896ac717cebcb/multidict-6.7.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:da62917e6076f512daccfbbde27f46fed1c98fee202f0559adec8ee0de67f71a", size = 269456, upload-time = "2026-01-26T02:44:02.202Z" }, + { url = "https://files.pythonhosted.org/packages/f3/8d/5e5be3ced1d12966fefb5c4ea3b2a5b480afcea36406559442c6e31d4a48/multidict-6.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfde23ef6ed9db7eaee6c37dcec08524cb43903c60b285b172b6c094711b3961", size = 256322, upload-time = "2026-01-26T02:44:03.56Z" }, + { url = "https://files.pythonhosted.org/packages/31/6e/d8a26d81ac166a5592782d208dd90dfdc0a7a218adaa52b45a672b46c122/multidict-6.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3758692429e4e32f1ba0df23219cd0b4fc0a52f476726fff9337d1a57676a582", size = 253955, upload-time = "2026-01-26T02:44:04.845Z" }, + { url = "https://files.pythonhosted.org/packages/59/4c/7c672c8aad41534ba619bcd4ade7a0dc87ed6b8b5c06149b85d3dd03f0cd/multidict-6.7.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:398c1478926eca669f2fd6a5856b6de9c0acf23a2cb59a14c0ba5844fa38077e", size = 251254, upload-time = "2026-01-26T02:44:06.133Z" }, + { url = "https://files.pythonhosted.org/packages/7b/bd/84c24de512cbafbdbc39439f74e967f19570ce7924e3007174a29c348916/multidict-6.7.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c102791b1c4f3ab36ce4101154549105a53dc828f016356b3e3bcae2e3a039d3", size = 252059, upload-time = "2026-01-26T02:44:07.518Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ba/f5449385510825b73d01c2d4087bf6d2fccc20a2d42ac34df93191d3dd03/multidict-6.7.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a088b62bd733e2ad12c50dad01b7d0166c30287c166e137433d3b410add807a6", size = 263588, upload-time = "2026-01-26T02:44:09.382Z" }, + { url = "https://files.pythonhosted.org/packages/d7/11/afc7c677f68f75c84a69fe37184f0f82fce13ce4b92f49f3db280b7e92b3/multidict-6.7.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d51ff4785d58d3f6c91bdbffcb5e1f7ddfda557727043aa20d20ec4f65e324a", size = 259642, upload-time = "2026-01-26T02:44:10.73Z" }, + { url = "https://files.pythonhosted.org/packages/2b/17/ebb9644da78c4ab36403739e0e6e0e30ebb135b9caf3440825001a0bddcb/multidict-6.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc5907494fccf3e7d3f94f95c91d6336b092b5fc83811720fae5e2765890dfba", size = 251377, upload-time = "2026-01-26T02:44:12.042Z" }, + { url = "https://files.pythonhosted.org/packages/ca/a4/840f5b97339e27846c46307f2530a2805d9d537d8b8bd416af031cad7fa0/multidict-6.7.1-cp312-cp312-win32.whl", hash = "sha256:28ca5ce2fd9716631133d0e9a9b9a745ad7f60bac2bccafb56aa380fc0b6c511", size = 41887, upload-time = "2026-01-26T02:44:14.245Z" }, + { url = "https://files.pythonhosted.org/packages/80/31/0b2517913687895f5904325c2069d6a3b78f66cc641a86a2baf75a05dcbb/multidict-6.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcee94dfbd638784645b066074b338bc9cc155d4b4bffa4adce1615c5a426c19", size = 46053, upload-time = "2026-01-26T02:44:15.371Z" }, + { url = "https://files.pythonhosted.org/packages/0c/5b/aba28e4ee4006ae4c7df8d327d31025d760ffa992ea23812a601d226e682/multidict-6.7.1-cp312-cp312-win_arm64.whl", hash = "sha256:ba0a9fb644d0c1a2194cf7ffb043bd852cea63a57f66fbd33959f7dae18517bf", size = 43307, upload-time = "2026-01-26T02:44:16.852Z" }, + { url = "https://files.pythonhosted.org/packages/f2/22/929c141d6c0dba87d3e1d38fbdf1ba8baba86b7776469f2bc2d3227a1e67/multidict-6.7.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2b41f5fed0ed563624f1c17630cb9941cf2309d4df00e494b551b5f3e3d67a23", size = 76174, upload-time = "2026-01-26T02:44:18.509Z" }, + { url = "https://files.pythonhosted.org/packages/c7/75/bc704ae15fee974f8fccd871305e254754167dce5f9e42d88a2def741a1d/multidict-6.7.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:84e61e3af5463c19b67ced91f6c634effb89ef8bfc5ca0267f954451ed4bb6a2", size = 45116, upload-time = "2026-01-26T02:44:19.745Z" }, + { url = "https://files.pythonhosted.org/packages/79/76/55cd7186f498ed080a18440c9013011eb548f77ae1b297206d030eb1180a/multidict-6.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:935434b9853c7c112eee7ac891bc4cb86455aa631269ae35442cb316790c1445", size = 43524, upload-time = "2026-01-26T02:44:21.571Z" }, + { url = "https://files.pythonhosted.org/packages/e9/3c/414842ef8d5a1628d68edee29ba0e5bcf235dbfb3ccd3ea303a7fe8c72ff/multidict-6.7.1-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:432feb25a1cb67fe82a9680b4d65fb542e4635cb3166cd9c01560651ad60f177", size = 249368, upload-time = "2026-01-26T02:44:22.803Z" }, + { url = "https://files.pythonhosted.org/packages/f6/32/befed7f74c458b4a525e60519fe8d87eef72bb1e99924fa2b0f9d97a221e/multidict-6.7.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e82d14e3c948952a1a85503817e038cba5905a3352de76b9a465075d072fba23", size = 256952, upload-time = "2026-01-26T02:44:24.306Z" }, + { url = "https://files.pythonhosted.org/packages/03/d6/c878a44ba877f366630c860fdf74bfb203c33778f12b6ac274936853c451/multidict-6.7.1-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4cfb48c6ea66c83bcaaf7e4dfa7ec1b6bbcf751b7db85a328902796dfde4c060", size = 240317, upload-time = "2026-01-26T02:44:25.772Z" }, + { url = "https://files.pythonhosted.org/packages/68/49/57421b4d7ad2e9e60e25922b08ceb37e077b90444bde6ead629095327a6f/multidict-6.7.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:1d540e51b7e8e170174555edecddbd5538105443754539193e3e1061864d444d", size = 267132, upload-time = "2026-01-26T02:44:27.648Z" }, + { url = "https://files.pythonhosted.org/packages/b7/fe/ec0edd52ddbcea2a2e89e174f0206444a61440b40f39704e64dc807a70bd/multidict-6.7.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:273d23f4b40f3dce4d6c8a821c741a86dec62cded82e1175ba3d99be128147ed", size = 268140, upload-time = "2026-01-26T02:44:29.588Z" }, + { url = "https://files.pythonhosted.org/packages/b0/73/6e1b01cbeb458807aa0831742232dbdd1fa92bfa33f52a3f176b4ff3dc11/multidict-6.7.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d624335fd4fa1c08a53f8b4be7676ebde19cd092b3895c421045ca87895b429", size = 254277, upload-time = "2026-01-26T02:44:30.902Z" }, + { url = "https://files.pythonhosted.org/packages/6a/b2/5fb8c124d7561a4974c342bc8c778b471ebbeb3cc17df696f034a7e9afe7/multidict-6.7.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:12fad252f8b267cc75b66e8fc51b3079604e8d43a75428ffe193cd9e2195dfd6", size = 252291, upload-time = "2026-01-26T02:44:32.31Z" }, + { url = "https://files.pythonhosted.org/packages/5a/96/51d4e4e06bcce92577fcd488e22600bd38e4fd59c20cb49434d054903bd2/multidict-6.7.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:03ede2a6ffbe8ef936b92cb4529f27f42be7f56afcdab5ab739cd5f27fb1cbf9", size = 250156, upload-time = "2026-01-26T02:44:33.734Z" }, + { url = "https://files.pythonhosted.org/packages/db/6b/420e173eec5fba721a50e2a9f89eda89d9c98fded1124f8d5c675f7a0c0f/multidict-6.7.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:90efbcf47dbe33dcf643a1e400d67d59abeac5db07dc3f27d6bdeae497a2198c", size = 249742, upload-time = "2026-01-26T02:44:35.222Z" }, + { url = "https://files.pythonhosted.org/packages/44/a3/ec5b5bd98f306bc2aa297b8c6f11a46714a56b1e6ef5ebda50a4f5d7c5fb/multidict-6.7.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:5c4b9bfc148f5a91be9244d6264c53035c8a0dcd2f51f1c3c6e30e30ebaa1c84", size = 262221, upload-time = "2026-01-26T02:44:36.604Z" }, + { url = "https://files.pythonhosted.org/packages/cd/f7/e8c0d0da0cd1e28d10e624604e1a36bcc3353aaebdfdc3a43c72bc683a12/multidict-6.7.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:401c5a650f3add2472d1d288c26deebc540f99e2fb83e9525007a74cd2116f1d", size = 258664, upload-time = "2026-01-26T02:44:38.008Z" }, + { url = "https://files.pythonhosted.org/packages/52/da/151a44e8016dd33feed44f730bd856a66257c1ee7aed4f44b649fb7edeb3/multidict-6.7.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:97891f3b1b3ffbded884e2916cacf3c6fc87b66bb0dde46f7357404750559f33", size = 249490, upload-time = "2026-01-26T02:44:39.386Z" }, + { url = "https://files.pythonhosted.org/packages/87/af/a3b86bf9630b732897f6fc3f4c4714b90aa4361983ccbdcd6c0339b21b0c/multidict-6.7.1-cp313-cp313-win32.whl", hash = "sha256:e1c5988359516095535c4301af38d8a8838534158f649c05dd1050222321bcb3", size = 41695, upload-time = "2026-01-26T02:44:41.318Z" }, + { url = "https://files.pythonhosted.org/packages/b2/35/e994121b0e90e46134673422dd564623f93304614f5d11886b1b3e06f503/multidict-6.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:960c83bf01a95b12b08fd54324a4eb1d5b52c88932b5cba5d6e712bb3ed12eb5", size = 45884, upload-time = "2026-01-26T02:44:42.488Z" }, + { url = "https://files.pythonhosted.org/packages/ca/61/42d3e5dbf661242a69c97ea363f2d7b46c567da8eadef8890022be6e2ab0/multidict-6.7.1-cp313-cp313-win_arm64.whl", hash = "sha256:563fe25c678aaba333d5399408f5ec3c383ca5b663e7f774dd179a520b8144df", size = 43122, upload-time = "2026-01-26T02:44:43.664Z" }, + { url = "https://files.pythonhosted.org/packages/6d/b3/e6b21c6c4f314bb956016b0b3ef2162590a529b84cb831c257519e7fde44/multidict-6.7.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:c76c4bec1538375dad9d452d246ca5368ad6e1c9039dadcf007ae59c70619ea1", size = 83175, upload-time = "2026-01-26T02:44:44.894Z" }, + { url = "https://files.pythonhosted.org/packages/fb/76/23ecd2abfe0957b234f6c960f4ade497f55f2c16aeb684d4ecdbf1c95791/multidict-6.7.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:57b46b24b5d5ebcc978da4ec23a819a9402b4228b8a90d9c656422b4bdd8a963", size = 48460, upload-time = "2026-01-26T02:44:46.106Z" }, + { url = "https://files.pythonhosted.org/packages/c4/57/a0ed92b23f3a042c36bc4227b72b97eca803f5f1801c1ab77c8a212d455e/multidict-6.7.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e954b24433c768ce78ab7929e84ccf3422e46deb45a4dc9f93438f8217fa2d34", size = 46930, upload-time = "2026-01-26T02:44:47.278Z" }, + { url = "https://files.pythonhosted.org/packages/b5/66/02ec7ace29162e447f6382c495dc95826bf931d3818799bbef11e8f7df1a/multidict-6.7.1-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3bd231490fa7217cc832528e1cd8752a96f0125ddd2b5749390f7c3ec8721b65", size = 242582, upload-time = "2026-01-26T02:44:48.604Z" }, + { url = "https://files.pythonhosted.org/packages/58/18/64f5a795e7677670e872673aca234162514696274597b3708b2c0d276cce/multidict-6.7.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:253282d70d67885a15c8a7716f3a73edf2d635793ceda8173b9ecc21f2fb8292", size = 250031, upload-time = "2026-01-26T02:44:50.544Z" }, + { url = "https://files.pythonhosted.org/packages/c8/ed/e192291dbbe51a8290c5686f482084d31bcd9d09af24f63358c3d42fd284/multidict-6.7.1-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:0b4c48648d7649c9335cf1927a8b87fa692de3dcb15faa676c6a6f1f1aabda43", size = 228596, upload-time = "2026-01-26T02:44:51.951Z" }, + { url = "https://files.pythonhosted.org/packages/1e/7e/3562a15a60cf747397e7f2180b0a11dc0c38d9175a650e75fa1b4d325e15/multidict-6.7.1-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:98bc624954ec4d2c7cb074b8eefc2b5d0ce7d482e410df446414355d158fe4ca", size = 257492, upload-time = "2026-01-26T02:44:53.902Z" }, + { url = "https://files.pythonhosted.org/packages/24/02/7d0f9eae92b5249bb50ac1595b295f10e263dd0078ebb55115c31e0eaccd/multidict-6.7.1-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:1b99af4d9eec0b49927b4402bcbb58dea89d3e0db8806a4086117019939ad3dd", size = 255899, upload-time = "2026-01-26T02:44:55.316Z" }, + { url = "https://files.pythonhosted.org/packages/00/e3/9b60ed9e23e64c73a5cde95269ef1330678e9c6e34dd4eb6b431b85b5a10/multidict-6.7.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6aac4f16b472d5b7dc6f66a0d49dd57b0e0902090be16594dc9ebfd3d17c47e7", size = 247970, upload-time = "2026-01-26T02:44:56.783Z" }, + { url = "https://files.pythonhosted.org/packages/3e/06/538e58a63ed5cfb0bd4517e346b91da32fde409d839720f664e9a4ae4f9d/multidict-6.7.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:21f830fe223215dffd51f538e78c172ed7c7f60c9b96a2bf05c4848ad49921c3", size = 245060, upload-time = "2026-01-26T02:44:58.195Z" }, + { url = "https://files.pythonhosted.org/packages/b2/2f/d743a3045a97c895d401e9bd29aaa09b94f5cbdf1bd561609e5a6c431c70/multidict-6.7.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:f5dd81c45b05518b9aa4da4aa74e1c93d715efa234fd3e8a179df611cc85e5f4", size = 235888, upload-time = "2026-01-26T02:44:59.57Z" }, + { url = "https://files.pythonhosted.org/packages/38/83/5a325cac191ab28b63c52f14f1131f3b0a55ba3b9aa65a6d0bf2a9b921a0/multidict-6.7.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:eb304767bca2bb92fb9c5bd33cedc95baee5bb5f6c88e63706533a1c06ad08c8", size = 243554, upload-time = "2026-01-26T02:45:01.054Z" }, + { url = "https://files.pythonhosted.org/packages/20/1f/9d2327086bd15da2725ef6aae624208e2ef828ed99892b17f60c344e57ed/multidict-6.7.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:c9035dde0f916702850ef66460bc4239d89d08df4d02023a5926e7446724212c", size = 252341, upload-time = "2026-01-26T02:45:02.484Z" }, + { url = "https://files.pythonhosted.org/packages/e8/2c/2a1aa0280cf579d0f6eed8ee5211c4f1730bd7e06c636ba2ee6aafda302e/multidict-6.7.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:af959b9beeb66c822380f222f0e0a1889331597e81f1ded7f374f3ecb0fd6c52", size = 246391, upload-time = "2026-01-26T02:45:03.862Z" }, + { url = "https://files.pythonhosted.org/packages/e5/03/7ca022ffc36c5a3f6e03b179a5ceb829be9da5783e6fe395f347c0794680/multidict-6.7.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:41f2952231456154ee479651491e94118229844dd7226541788be783be2b5108", size = 243422, upload-time = "2026-01-26T02:45:05.296Z" }, + { url = "https://files.pythonhosted.org/packages/dc/1d/b31650eab6c5778aceed46ba735bd97f7c7d2f54b319fa916c0f96e7805b/multidict-6.7.1-cp313-cp313t-win32.whl", hash = "sha256:df9f19c28adcb40b6aae30bbaa1478c389efd50c28d541d76760199fc1037c32", size = 47770, upload-time = "2026-01-26T02:45:06.754Z" }, + { url = "https://files.pythonhosted.org/packages/ac/5b/2d2d1d522e51285bd61b1e20df8f47ae1a9d80839db0b24ea783b3832832/multidict-6.7.1-cp313-cp313t-win_amd64.whl", hash = "sha256:d54ecf9f301853f2c5e802da559604b3e95bb7a3b01a9c295c6ee591b9882de8", size = 53109, upload-time = "2026-01-26T02:45:08.044Z" }, + { url = "https://files.pythonhosted.org/packages/3d/a3/cc409ba012c83ca024a308516703cf339bdc4b696195644a7215a5164a24/multidict-6.7.1-cp313-cp313t-win_arm64.whl", hash = "sha256:5a37ca18e360377cfda1d62f5f382ff41f2b8c4ccb329ed974cc2e1643440118", size = 45573, upload-time = "2026-01-26T02:45:09.349Z" }, + { url = "https://files.pythonhosted.org/packages/91/cc/db74228a8be41884a567e88a62fd589a913708fcf180d029898c17a9a371/multidict-6.7.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:8f333ec9c5eb1b7105e3b84b53141e66ca05a19a605368c55450b6ba208cb9ee", size = 75190, upload-time = "2026-01-26T02:45:10.651Z" }, + { url = "https://files.pythonhosted.org/packages/d5/22/492f2246bb5b534abd44804292e81eeaf835388901f0c574bac4eeec73c5/multidict-6.7.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:a407f13c188f804c759fc6a9f88286a565c242a76b27626594c133b82883b5c2", size = 44486, upload-time = "2026-01-26T02:45:11.938Z" }, + { url = "https://files.pythonhosted.org/packages/f1/4f/733c48f270565d78b4544f2baddc2fb2a245e5a8640254b12c36ac7ac68e/multidict-6.7.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:0e161ddf326db5577c3a4cc2d8648f81456e8a20d40415541587a71620d7a7d1", size = 43219, upload-time = "2026-01-26T02:45:14.346Z" }, + { url = "https://files.pythonhosted.org/packages/24/bb/2c0c2287963f4259c85e8bcbba9182ced8d7fca65c780c38e99e61629d11/multidict-6.7.1-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:1e3a8bb24342a8201d178c3b4984c26ba81a577c80d4d525727427460a50c22d", size = 245132, upload-time = "2026-01-26T02:45:15.712Z" }, + { url = "https://files.pythonhosted.org/packages/a7/f9/44d4b3064c65079d2467888794dea218d1601898ac50222ab8a9a8094460/multidict-6.7.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:97231140a50f5d447d3164f994b86a0bed7cd016e2682f8650d6a9158e14fd31", size = 252420, upload-time = "2026-01-26T02:45:17.293Z" }, + { url = "https://files.pythonhosted.org/packages/8b/13/78f7275e73fa17b24c9a51b0bd9d73ba64bb32d0ed51b02a746eb876abe7/multidict-6.7.1-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6b10359683bd8806a200fd2909e7c8ca3a7b24ec1d8132e483d58e791d881048", size = 233510, upload-time = "2026-01-26T02:45:19.356Z" }, + { url = "https://files.pythonhosted.org/packages/4b/25/8167187f62ae3cbd52da7893f58cb036b47ea3fb67138787c76800158982/multidict-6.7.1-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:283ddac99f7ac25a4acadbf004cb5ae34480bbeb063520f70ce397b281859362", size = 264094, upload-time = "2026-01-26T02:45:20.834Z" }, + { url = "https://files.pythonhosted.org/packages/a1/e7/69a3a83b7b030cf283fb06ce074a05a02322359783424d7edf0f15fe5022/multidict-6.7.1-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:538cec1e18c067d0e6103aa9a74f9e832904c957adc260e61cd9d8cf0c3b3d37", size = 260786, upload-time = "2026-01-26T02:45:22.818Z" }, + { url = "https://files.pythonhosted.org/packages/fe/3b/8ec5074bcfc450fe84273713b4b0a0dd47c0249358f5d82eb8104ffe2520/multidict-6.7.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7eee46ccb30ff48a1e35bb818cc90846c6be2b68240e42a78599166722cea709", size = 248483, upload-time = "2026-01-26T02:45:24.368Z" }, + { url = "https://files.pythonhosted.org/packages/48/5a/d5a99e3acbca0e29c5d9cba8f92ceb15dce78bab963b308ae692981e3a5d/multidict-6.7.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fa263a02f4f2dd2d11a7b1bb4362aa7cb1049f84a9235d31adf63f30143469a0", size = 248403, upload-time = "2026-01-26T02:45:25.982Z" }, + { url = "https://files.pythonhosted.org/packages/35/48/e58cd31f6c7d5102f2a4bf89f96b9cf7e00b6c6f3d04ecc44417c00a5a3c/multidict-6.7.1-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:2e1425e2f99ec5bd36c15a01b690a1a2456209c5deed58f95469ffb46039ccbb", size = 240315, upload-time = "2026-01-26T02:45:27.487Z" }, + { url = "https://files.pythonhosted.org/packages/94/33/1cd210229559cb90b6786c30676bb0c58249ff42f942765f88793b41fdce/multidict-6.7.1-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:497394b3239fc6f0e13a78a3e1b61296e72bf1c5f94b4c4eb80b265c37a131cd", size = 245528, upload-time = "2026-01-26T02:45:28.991Z" }, + { url = "https://files.pythonhosted.org/packages/64/f2/6e1107d226278c876c783056b7db43d800bb64c6131cec9c8dfb6903698e/multidict-6.7.1-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:233b398c29d3f1b9676b4b6f75c518a06fcb2ea0b925119fb2c1bc35c05e1601", size = 258784, upload-time = "2026-01-26T02:45:30.503Z" }, + { url = "https://files.pythonhosted.org/packages/4d/c1/11f664f14d525e4a1b5327a82d4de61a1db604ab34c6603bb3c2cc63ad34/multidict-6.7.1-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:93b1818e4a6e0930454f0f2af7dfce69307ca03cdcfb3739bf4d91241967b6c1", size = 251980, upload-time = "2026-01-26T02:45:32.603Z" }, + { url = "https://files.pythonhosted.org/packages/e1/9f/75a9ac888121d0c5bbd4ecf4eead45668b1766f6baabfb3b7f66a410e231/multidict-6.7.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f33dc2a3abe9249ea5d8360f969ec7f4142e7ac45ee7014d8f8d5acddf178b7b", size = 243602, upload-time = "2026-01-26T02:45:34.043Z" }, + { url = "https://files.pythonhosted.org/packages/9a/e7/50bf7b004cc8525d80dbbbedfdc7aed3e4c323810890be4413e589074032/multidict-6.7.1-cp314-cp314-win32.whl", hash = "sha256:3ab8b9d8b75aef9df299595d5388b14530839f6422333357af1339443cff777d", size = 40930, upload-time = "2026-01-26T02:45:36.278Z" }, + { url = "https://files.pythonhosted.org/packages/e0/bf/52f25716bbe93745595800f36fb17b73711f14da59ed0bb2eba141bc9f0f/multidict-6.7.1-cp314-cp314-win_amd64.whl", hash = "sha256:5e01429a929600e7dab7b166062d9bb54a5eed752384c7384c968c2afab8f50f", size = 45074, upload-time = "2026-01-26T02:45:37.546Z" }, + { url = "https://files.pythonhosted.org/packages/97/ab/22803b03285fa3a525f48217963da3a65ae40f6a1b6f6cf2768879e208f9/multidict-6.7.1-cp314-cp314-win_arm64.whl", hash = "sha256:4885cb0e817aef5d00a2e8451d4665c1808378dc27c2705f1bf4ef8505c0d2e5", size = 42471, upload-time = "2026-01-26T02:45:38.889Z" }, + { url = "https://files.pythonhosted.org/packages/e0/6d/f9293baa6146ba9507e360ea0292b6422b016907c393e2f63fc40ab7b7b5/multidict-6.7.1-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:0458c978acd8e6ea53c81eefaddbbee9c6c5e591f41b3f5e8e194780fe026581", size = 82401, upload-time = "2026-01-26T02:45:40.254Z" }, + { url = "https://files.pythonhosted.org/packages/7a/68/53b5494738d83558d87c3c71a486504d8373421c3e0dbb6d0db48ad42ee0/multidict-6.7.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:c0abd12629b0af3cf590982c0b413b1e7395cd4ec026f30986818ab95bfaa94a", size = 48143, upload-time = "2026-01-26T02:45:41.635Z" }, + { url = "https://files.pythonhosted.org/packages/37/e8/5284c53310dcdc99ce5d66563f6e5773531a9b9fe9ec7a615e9bc306b05f/multidict-6.7.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:14525a5f61d7d0c94b368a42cff4c9a4e7ba2d52e2672a7b23d84dc86fb02b0c", size = 46507, upload-time = "2026-01-26T02:45:42.99Z" }, + { url = "https://files.pythonhosted.org/packages/e4/fc/6800d0e5b3875568b4083ecf5f310dcf91d86d52573160834fb4bfcf5e4f/multidict-6.7.1-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:17307b22c217b4cf05033dabefe68255a534d637c6c9b0cc8382718f87be4262", size = 239358, upload-time = "2026-01-26T02:45:44.376Z" }, + { url = "https://files.pythonhosted.org/packages/41/75/4ad0973179361cdf3a113905e6e088173198349131be2b390f9fa4da5fc6/multidict-6.7.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a7e590ff876a3eaf1c02a4dfe0724b6e69a9e9de6d8f556816f29c496046e59", size = 246884, upload-time = "2026-01-26T02:45:47.167Z" }, + { url = "https://files.pythonhosted.org/packages/c3/9c/095bb28b5da139bd41fb9a5d5caff412584f377914bd8787c2aa98717130/multidict-6.7.1-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:5fa6a95dfee63893d80a34758cd0e0c118a30b8dcb46372bf75106c591b77889", size = 225878, upload-time = "2026-01-26T02:45:48.698Z" }, + { url = "https://files.pythonhosted.org/packages/07/d0/c0a72000243756e8f5a277b6b514fa005f2c73d481b7d9e47cd4568aa2e4/multidict-6.7.1-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a0543217a6a017692aa6ae5cc39adb75e587af0f3a82288b1492eb73dd6cc2a4", size = 253542, upload-time = "2026-01-26T02:45:50.164Z" }, + { url = "https://files.pythonhosted.org/packages/c0/6b/f69da15289e384ecf2a68837ec8b5ad8c33e973aa18b266f50fe55f24b8c/multidict-6.7.1-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f99fe611c312b3c1c0ace793f92464d8cd263cc3b26b5721950d977b006b6c4d", size = 252403, upload-time = "2026-01-26T02:45:51.779Z" }, + { url = "https://files.pythonhosted.org/packages/a2/76/b9669547afa5a1a25cd93eaca91c0da1c095b06b6d2d8ec25b713588d3a1/multidict-6.7.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9004d8386d133b7e6135679424c91b0b854d2d164af6ea3f289f8f2761064609", size = 244889, upload-time = "2026-01-26T02:45:53.27Z" }, + { url = "https://files.pythonhosted.org/packages/7e/a9/a50d2669e506dad33cfc45b5d574a205587b7b8a5f426f2fbb2e90882588/multidict-6.7.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:e628ef0e6859ffd8273c69412a2465c4be4a9517d07261b33334b5ec6f3c7489", size = 241982, upload-time = "2026-01-26T02:45:54.919Z" }, + { url = "https://files.pythonhosted.org/packages/c5/bb/1609558ad8b456b4827d3c5a5b775c93b87878fd3117ed3db3423dfbce1b/multidict-6.7.1-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:841189848ba629c3552035a6a7f5bf3b02eb304e9fea7492ca220a8eda6b0e5c", size = 232415, upload-time = "2026-01-26T02:45:56.981Z" }, + { url = "https://files.pythonhosted.org/packages/d8/59/6f61039d2aa9261871e03ab9dc058a550d240f25859b05b67fd70f80d4b3/multidict-6.7.1-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ce1bbd7d780bb5a0da032e095c951f7014d6b0a205f8318308140f1a6aba159e", size = 240337, upload-time = "2026-01-26T02:45:58.698Z" }, + { url = "https://files.pythonhosted.org/packages/a1/29/fdc6a43c203890dc2ae9249971ecd0c41deaedfe00d25cb6564b2edd99eb/multidict-6.7.1-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:b26684587228afed0d50cf804cc71062cc9c1cdf55051c4c6345d372947b268c", size = 248788, upload-time = "2026-01-26T02:46:00.862Z" }, + { url = "https://files.pythonhosted.org/packages/a9/14/a153a06101323e4cf086ecee3faadba52ff71633d471f9685c42e3736163/multidict-6.7.1-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:9f9af11306994335398293f9958071019e3ab95e9a707dc1383a35613f6abcb9", size = 242842, upload-time = "2026-01-26T02:46:02.824Z" }, + { url = "https://files.pythonhosted.org/packages/41/5f/604ae839e64a4a6efc80db94465348d3b328ee955e37acb24badbcd24d83/multidict-6.7.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:b4938326284c4f1224178a560987b6cf8b4d38458b113d9b8c1db1a836e640a2", size = 240237, upload-time = "2026-01-26T02:46:05.898Z" }, + { url = "https://files.pythonhosted.org/packages/5f/60/c3a5187bf66f6fb546ff4ab8fb5a077cbdd832d7b1908d4365c7f74a1917/multidict-6.7.1-cp314-cp314t-win32.whl", hash = "sha256:98655c737850c064a65e006a3df7c997cd3b220be4ec8fe26215760b9697d4d7", size = 48008, upload-time = "2026-01-26T02:46:07.468Z" }, + { url = "https://files.pythonhosted.org/packages/0c/f7/addf1087b860ac60e6f382240f64fb99f8bfb532bb06f7c542b83c29ca61/multidict-6.7.1-cp314-cp314t-win_amd64.whl", hash = "sha256:497bde6223c212ba11d462853cfa4f0ae6ef97465033e7dc9940cdb3ab5b48e5", size = 53542, upload-time = "2026-01-26T02:46:08.809Z" }, + { url = "https://files.pythonhosted.org/packages/4c/81/4629d0aa32302ef7b2ec65c75a728cc5ff4fa410c50096174c1632e70b3e/multidict-6.7.1-cp314-cp314t-win_arm64.whl", hash = "sha256:2bbd113e0d4af5db41d5ebfe9ccaff89de2120578164f86a5d17d5a576d1e5b2", size = 44719, upload-time = "2026-01-26T02:46:11.146Z" }, + { url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" }, +] + [[package]] name = "nak-torch" version = "0.1.0" @@ -846,8 +1221,8 @@ examples = [ { name = "matplotlib" }, { name = "posteriordb" }, { name = "pyro-ppl" }, + { name = "pystan" }, { name = "scipy" }, - { name = "stanpy" }, ] [package.dev-dependencies] @@ -869,8 +1244,8 @@ requires-dist = [ { name = "numpy", specifier = ">=2.4.1" }, { name = "posteriordb", marker = "extra == 'examples'", specifier = ">=0.2.0" }, { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, + { name = "pystan", marker = "extra == 'examples'", specifier = ">=3.10.1" }, { name = "scipy", marker = "extra == 'examples'", specifier = ">=1.17.1" }, - { name = "stanpy", marker = "extra == 'examples'", specifier = ">=0.2.11" }, { name = "torch", specifier = ">=2.10" }, { name = "tqdm", specifier = ">=4.67.1" }, ] @@ -1179,6 +1554,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b6/61/fae042894f4296ec49e3f193aff5d7c18440da9e48102c3315e1bc4519a7/parso-0.8.6-py2.py3-none-any.whl", hash = "sha256:2c549f800b70a5c4952197248825584cb00f033b29c692671d3bf08bf380baff", size = 106894, upload-time = "2026-02-09T15:45:21.391Z" }, ] +[[package]] +name = "pastel" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/f1/4594f5e0fcddb6953e5b8fe00da8c317b8b41b547e2b3ae2da7512943c62/pastel-0.2.1.tar.gz", hash = "sha256:e6581ac04e973cac858828c6202c1e1e81fee1dc7de7683f3e1ffe0bfd8a573d", size = 7555, upload-time = "2020-09-16T19:21:12.43Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/18/a8444036c6dd65ba3624c63b734d3ba95ba63ace513078e1580590075d21/pastel-0.2.1-py2.py3-none-any.whl", hash = "sha256:4349225fcdf6c2bb34d483e523475de5bb04a5c10ef711263452cb37d7dd4364", size = 5955, upload-time = "2020-09-16T19:21:11.409Z" }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -1301,6 +1685,90 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, ] +[[package]] +name = "propcache" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/da/e9fc233cf63743258bff22b3dfa7ea5baef7b5bc324af47a0ad89b8ffc6f/propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d", size = 46442, upload-time = "2025-10-08T19:49:02.291Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/0f/f17b1b2b221d5ca28b4b876e8bb046ac40466513960646bda8e1853cdfa2/propcache-0.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e153e9cd40cc8945138822807139367f256f89c6810c2634a4f6902b52d3b4e2", size = 80061, upload-time = "2025-10-08T19:46:46.075Z" }, + { url = "https://files.pythonhosted.org/packages/76/47/8ccf75935f51448ba9a16a71b783eb7ef6b9ee60f5d14c7f8a8a79fbeed7/propcache-0.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cd547953428f7abb73c5ad82cbb32109566204260d98e41e5dfdc682eb7f8403", size = 46037, upload-time = "2025-10-08T19:46:47.23Z" }, + { url = "https://files.pythonhosted.org/packages/0a/b6/5c9a0e42df4d00bfb4a3cbbe5cf9f54260300c88a0e9af1f47ca5ce17ac0/propcache-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f048da1b4f243fc44f205dfd320933a951b8d89e0afd4c7cacc762a8b9165207", size = 47324, upload-time = "2025-10-08T19:46:48.384Z" }, + { url = "https://files.pythonhosted.org/packages/9e/d3/6c7ee328b39a81ee877c962469f1e795f9db87f925251efeb0545e0020d0/propcache-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec17c65562a827bba85e3872ead335f95405ea1674860d96483a02f5c698fa72", size = 225505, upload-time = "2025-10-08T19:46:50.055Z" }, + { url = "https://files.pythonhosted.org/packages/01/5d/1c53f4563490b1d06a684742cc6076ef944bc6457df6051b7d1a877c057b/propcache-0.4.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:405aac25c6394ef275dee4c709be43745d36674b223ba4eb7144bf4d691b7367", size = 230242, upload-time = "2025-10-08T19:46:51.815Z" }, + { url = "https://files.pythonhosted.org/packages/20/e1/ce4620633b0e2422207c3cb774a0ee61cac13abc6217763a7b9e2e3f4a12/propcache-0.4.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0013cb6f8dde4b2a2f66903b8ba740bdfe378c943c4377a200551ceb27f379e4", size = 238474, upload-time = "2025-10-08T19:46:53.208Z" }, + { url = "https://files.pythonhosted.org/packages/46/4b/3aae6835b8e5f44ea6a68348ad90f78134047b503765087be2f9912140ea/propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15932ab57837c3368b024473a525e25d316d8353016e7cc0e5ba9eb343fbb1cf", size = 221575, upload-time = "2025-10-08T19:46:54.511Z" }, + { url = "https://files.pythonhosted.org/packages/6e/a5/8a5e8678bcc9d3a1a15b9a29165640d64762d424a16af543f00629c87338/propcache-0.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:031dce78b9dc099f4c29785d9cf5577a3faf9ebf74ecbd3c856a7b92768c3df3", size = 216736, upload-time = "2025-10-08T19:46:56.212Z" }, + { url = "https://files.pythonhosted.org/packages/f1/63/b7b215eddeac83ca1c6b934f89d09a625aa9ee4ba158338854c87210cc36/propcache-0.4.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ab08df6c9a035bee56e31af99be621526bd237bea9f32def431c656b29e41778", size = 213019, upload-time = "2025-10-08T19:46:57.595Z" }, + { url = "https://files.pythonhosted.org/packages/57/74/f580099a58c8af587cac7ba19ee7cb418506342fbbe2d4a4401661cca886/propcache-0.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4d7af63f9f93fe593afbf104c21b3b15868efb2c21d07d8732c0c4287e66b6a6", size = 220376, upload-time = "2025-10-08T19:46:59.067Z" }, + { url = "https://files.pythonhosted.org/packages/c4/ee/542f1313aff7eaf19c2bb758c5d0560d2683dac001a1c96d0774af799843/propcache-0.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cfc27c945f422e8b5071b6e93169679e4eb5bf73bbcbf1ba3ae3a83d2f78ebd9", size = 226988, upload-time = "2025-10-08T19:47:00.544Z" }, + { url = "https://files.pythonhosted.org/packages/8f/18/9c6b015dd9c6930f6ce2229e1f02fb35298b847f2087ea2b436a5bfa7287/propcache-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:35c3277624a080cc6ec6f847cbbbb5b49affa3598c4535a0a4682a697aaa5c75", size = 215615, upload-time = "2025-10-08T19:47:01.968Z" }, + { url = "https://files.pythonhosted.org/packages/80/9e/e7b85720b98c45a45e1fca6a177024934dc9bc5f4d5dd04207f216fc33ed/propcache-0.4.1-cp312-cp312-win32.whl", hash = "sha256:671538c2262dadb5ba6395e26c1731e1d52534bfe9ae56d0b5573ce539266aa8", size = 38066, upload-time = "2025-10-08T19:47:03.503Z" }, + { url = "https://files.pythonhosted.org/packages/54/09/d19cff2a5aaac632ec8fc03737b223597b1e347416934c1b3a7df079784c/propcache-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:cb2d222e72399fcf5890d1d5cc1060857b9b236adff2792ff48ca2dfd46c81db", size = 41655, upload-time = "2025-10-08T19:47:04.973Z" }, + { url = "https://files.pythonhosted.org/packages/68/ab/6b5c191bb5de08036a8c697b265d4ca76148efb10fa162f14af14fb5f076/propcache-0.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:204483131fb222bdaaeeea9f9e6c6ed0cac32731f75dfc1d4a567fc1926477c1", size = 37789, upload-time = "2025-10-08T19:47:06.077Z" }, + { url = "https://files.pythonhosted.org/packages/bf/df/6d9c1b6ac12b003837dde8a10231a7344512186e87b36e855bef32241942/propcache-0.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:43eedf29202c08550aac1d14e0ee619b0430aaef78f85864c1a892294fbc28cf", size = 77750, upload-time = "2025-10-08T19:47:07.648Z" }, + { url = "https://files.pythonhosted.org/packages/8b/e8/677a0025e8a2acf07d3418a2e7ba529c9c33caf09d3c1f25513023c1db56/propcache-0.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d62cdfcfd89ccb8de04e0eda998535c406bf5e060ffd56be6c586cbcc05b3311", size = 44780, upload-time = "2025-10-08T19:47:08.851Z" }, + { url = "https://files.pythonhosted.org/packages/89/a4/92380f7ca60f99ebae761936bc48a72a639e8a47b29050615eef757cb2a7/propcache-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cae65ad55793da34db5f54e4029b89d3b9b9490d8abe1b4c7ab5d4b8ec7ebf74", size = 46308, upload-time = "2025-10-08T19:47:09.982Z" }, + { url = "https://files.pythonhosted.org/packages/2d/48/c5ac64dee5262044348d1d78a5f85dd1a57464a60d30daee946699963eb3/propcache-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:333ddb9031d2704a301ee3e506dc46b1fe5f294ec198ed6435ad5b6a085facfe", size = 208182, upload-time = "2025-10-08T19:47:11.319Z" }, + { url = "https://files.pythonhosted.org/packages/c6/0c/cd762dd011a9287389a6a3eb43aa30207bde253610cca06824aeabfe9653/propcache-0.4.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:fd0858c20f078a32cf55f7e81473d96dcf3b93fd2ccdb3d40fdf54b8573df3af", size = 211215, upload-time = "2025-10-08T19:47:13.146Z" }, + { url = "https://files.pythonhosted.org/packages/30/3e/49861e90233ba36890ae0ca4c660e95df565b2cd15d4a68556ab5865974e/propcache-0.4.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:678ae89ebc632c5c204c794f8dab2837c5f159aeb59e6ed0539500400577298c", size = 218112, upload-time = "2025-10-08T19:47:14.913Z" }, + { url = "https://files.pythonhosted.org/packages/f1/8b/544bc867e24e1bd48f3118cecd3b05c694e160a168478fa28770f22fd094/propcache-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d472aeb4fbf9865e0c6d622d7f4d54a4e101a89715d8904282bb5f9a2f476c3f", size = 204442, upload-time = "2025-10-08T19:47:16.277Z" }, + { url = "https://files.pythonhosted.org/packages/50/a6/4282772fd016a76d3e5c0df58380a5ea64900afd836cec2c2f662d1b9bb3/propcache-0.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4d3df5fa7e36b3225954fba85589da77a0fe6a53e3976de39caf04a0db4c36f1", size = 199398, upload-time = "2025-10-08T19:47:17.962Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ec/d8a7cd406ee1ddb705db2139f8a10a8a427100347bd698e7014351c7af09/propcache-0.4.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:ee17f18d2498f2673e432faaa71698032b0127ebf23ae5974eeaf806c279df24", size = 196920, upload-time = "2025-10-08T19:47:19.355Z" }, + { url = "https://files.pythonhosted.org/packages/f6/6c/f38ab64af3764f431e359f8baf9e0a21013e24329e8b85d2da32e8ed07ca/propcache-0.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:580e97762b950f993ae618e167e7be9256b8353c2dcd8b99ec100eb50f5286aa", size = 203748, upload-time = "2025-10-08T19:47:21.338Z" }, + { url = "https://files.pythonhosted.org/packages/d6/e3/fa846bd70f6534d647886621388f0a265254d30e3ce47e5c8e6e27dbf153/propcache-0.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:501d20b891688eb8e7aa903021f0b72d5a55db40ffaab27edefd1027caaafa61", size = 205877, upload-time = "2025-10-08T19:47:23.059Z" }, + { url = "https://files.pythonhosted.org/packages/e2/39/8163fc6f3133fea7b5f2827e8eba2029a0277ab2c5beee6c1db7b10fc23d/propcache-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a0bd56e5b100aef69bd8562b74b46254e7c8812918d3baa700c8a8009b0af66", size = 199437, upload-time = "2025-10-08T19:47:24.445Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/caa9089970ca49c7c01662bd0eeedfe85494e863e8043565aeb6472ce8fe/propcache-0.4.1-cp313-cp313-win32.whl", hash = "sha256:bcc9aaa5d80322bc2fb24bb7accb4a30f81e90ab8d6ba187aec0744bc302ad81", size = 37586, upload-time = "2025-10-08T19:47:25.736Z" }, + { url = "https://files.pythonhosted.org/packages/f5/ab/f76ec3c3627c883215b5c8080debb4394ef5a7a29be811f786415fc1e6fd/propcache-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:381914df18634f5494334d201e98245c0596067504b9372d8cf93f4bb23e025e", size = 40790, upload-time = "2025-10-08T19:47:26.847Z" }, + { url = "https://files.pythonhosted.org/packages/59/1b/e71ae98235f8e2ba5004d8cb19765a74877abf189bc53fc0c80d799e56c3/propcache-0.4.1-cp313-cp313-win_arm64.whl", hash = "sha256:8873eb4460fd55333ea49b7d189749ecf6e55bf85080f11b1c4530ed3034cba1", size = 37158, upload-time = "2025-10-08T19:47:27.961Z" }, + { url = "https://files.pythonhosted.org/packages/83/ce/a31bbdfc24ee0dcbba458c8175ed26089cf109a55bbe7b7640ed2470cfe9/propcache-0.4.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:92d1935ee1f8d7442da9c0c4fa7ac20d07e94064184811b685f5c4fada64553b", size = 81451, upload-time = "2025-10-08T19:47:29.445Z" }, + { url = "https://files.pythonhosted.org/packages/25/9c/442a45a470a68456e710d96cacd3573ef26a1d0a60067e6a7d5e655621ed/propcache-0.4.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:473c61b39e1460d386479b9b2f337da492042447c9b685f28be4f74d3529e566", size = 46374, upload-time = "2025-10-08T19:47:30.579Z" }, + { url = "https://files.pythonhosted.org/packages/f4/bf/b1d5e21dbc3b2e889ea4327044fb16312a736d97640fb8b6aa3f9c7b3b65/propcache-0.4.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c0ef0aaafc66fbd87842a3fe3902fd889825646bc21149eafe47be6072725835", size = 48396, upload-time = "2025-10-08T19:47:31.79Z" }, + { url = "https://files.pythonhosted.org/packages/f4/04/5b4c54a103d480e978d3c8a76073502b18db0c4bc17ab91b3cb5092ad949/propcache-0.4.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f95393b4d66bfae908c3ca8d169d5f79cd65636ae15b5e7a4f6e67af675adb0e", size = 275950, upload-time = "2025-10-08T19:47:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/b4/c1/86f846827fb969c4b78b0af79bba1d1ea2156492e1b83dea8b8a6ae27395/propcache-0.4.1-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c07fda85708bc48578467e85099645167a955ba093be0a2dcba962195676e859", size = 273856, upload-time = "2025-10-08T19:47:34.906Z" }, + { url = "https://files.pythonhosted.org/packages/36/1d/fc272a63c8d3bbad6878c336c7a7dea15e8f2d23a544bda43205dfa83ada/propcache-0.4.1-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:af223b406d6d000830c6f65f1e6431783fc3f713ba3e6cc8c024d5ee96170a4b", size = 280420, upload-time = "2025-10-08T19:47:36.338Z" }, + { url = "https://files.pythonhosted.org/packages/07/0c/01f2219d39f7e53d52e5173bcb09c976609ba30209912a0680adfb8c593a/propcache-0.4.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a78372c932c90ee474559c5ddfffd718238e8673c340dc21fe45c5b8b54559a0", size = 263254, upload-time = "2025-10-08T19:47:37.692Z" }, + { url = "https://files.pythonhosted.org/packages/2d/18/cd28081658ce597898f0c4d174d4d0f3c5b6d4dc27ffafeef835c95eb359/propcache-0.4.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:564d9f0d4d9509e1a870c920a89b2fec951b44bf5ba7d537a9e7c1ccec2c18af", size = 261205, upload-time = "2025-10-08T19:47:39.659Z" }, + { url = "https://files.pythonhosted.org/packages/7a/71/1f9e22eb8b8316701c2a19fa1f388c8a3185082607da8e406a803c9b954e/propcache-0.4.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:17612831fda0138059cc5546f4d12a2aacfb9e47068c06af35c400ba58ba7393", size = 247873, upload-time = "2025-10-08T19:47:41.084Z" }, + { url = "https://files.pythonhosted.org/packages/4a/65/3d4b61f36af2b4eddba9def857959f1016a51066b4f1ce348e0cf7881f58/propcache-0.4.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:41a89040cb10bd345b3c1a873b2bf36413d48da1def52f268a055f7398514874", size = 262739, upload-time = "2025-10-08T19:47:42.51Z" }, + { url = "https://files.pythonhosted.org/packages/2a/42/26746ab087faa77c1c68079b228810436ccd9a5ce9ac85e2b7307195fd06/propcache-0.4.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e35b88984e7fa64aacecea39236cee32dd9bd8c55f57ba8a75cf2399553f9bd7", size = 263514, upload-time = "2025-10-08T19:47:43.927Z" }, + { url = "https://files.pythonhosted.org/packages/94/13/630690fe201f5502d2403dd3cfd451ed8858fe3c738ee88d095ad2ff407b/propcache-0.4.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6f8b465489f927b0df505cbe26ffbeed4d6d8a2bbc61ce90eb074ff129ef0ab1", size = 257781, upload-time = "2025-10-08T19:47:45.448Z" }, + { url = "https://files.pythonhosted.org/packages/92/f7/1d4ec5841505f423469efbfc381d64b7b467438cd5a4bbcbb063f3b73d27/propcache-0.4.1-cp313-cp313t-win32.whl", hash = "sha256:2ad890caa1d928c7c2965b48f3a3815c853180831d0e5503d35cf00c472f4717", size = 41396, upload-time = "2025-10-08T19:47:47.202Z" }, + { url = "https://files.pythonhosted.org/packages/48/f0/615c30622316496d2cbbc29f5985f7777d3ada70f23370608c1d3e081c1f/propcache-0.4.1-cp313-cp313t-win_amd64.whl", hash = "sha256:f7ee0e597f495cf415bcbd3da3caa3bd7e816b74d0d52b8145954c5e6fd3ff37", size = 44897, upload-time = "2025-10-08T19:47:48.336Z" }, + { url = "https://files.pythonhosted.org/packages/fd/ca/6002e46eccbe0e33dcd4069ef32f7f1c9e243736e07adca37ae8c4830ec3/propcache-0.4.1-cp313-cp313t-win_arm64.whl", hash = "sha256:929d7cbe1f01bb7baffb33dc14eb5691c95831450a26354cd210a8155170c93a", size = 39789, upload-time = "2025-10-08T19:47:49.876Z" }, + { url = "https://files.pythonhosted.org/packages/8e/5c/bca52d654a896f831b8256683457ceddd490ec18d9ec50e97dfd8fc726a8/propcache-0.4.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:3f7124c9d820ba5548d431afb4632301acf965db49e666aa21c305cbe8c6de12", size = 78152, upload-time = "2025-10-08T19:47:51.051Z" }, + { url = "https://files.pythonhosted.org/packages/65/9b/03b04e7d82a5f54fb16113d839f5ea1ede58a61e90edf515f6577c66fa8f/propcache-0.4.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:c0d4b719b7da33599dfe3b22d3db1ef789210a0597bc650b7cee9c77c2be8c5c", size = 44869, upload-time = "2025-10-08T19:47:52.594Z" }, + { url = "https://files.pythonhosted.org/packages/b2/fa/89a8ef0468d5833a23fff277b143d0573897cf75bd56670a6d28126c7d68/propcache-0.4.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9f302f4783709a78240ebc311b793f123328716a60911d667e0c036bc5dcbded", size = 46596, upload-time = "2025-10-08T19:47:54.073Z" }, + { url = "https://files.pythonhosted.org/packages/86/bd/47816020d337f4a746edc42fe8d53669965138f39ee117414c7d7a340cfe/propcache-0.4.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c80ee5802e3fb9ea37938e7eecc307fb984837091d5fd262bb37238b1ae97641", size = 206981, upload-time = "2025-10-08T19:47:55.715Z" }, + { url = "https://files.pythonhosted.org/packages/df/f6/c5fa1357cc9748510ee55f37173eb31bfde6d94e98ccd9e6f033f2fc06e1/propcache-0.4.1-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ed5a841e8bb29a55fb8159ed526b26adc5bdd7e8bd7bf793ce647cb08656cdf4", size = 211490, upload-time = "2025-10-08T19:47:57.499Z" }, + { url = "https://files.pythonhosted.org/packages/80/1e/e5889652a7c4a3846683401a48f0f2e5083ce0ec1a8a5221d8058fbd1adf/propcache-0.4.1-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:55c72fd6ea2da4c318e74ffdf93c4fe4e926051133657459131a95c846d16d44", size = 215371, upload-time = "2025-10-08T19:47:59.317Z" }, + { url = "https://files.pythonhosted.org/packages/b2/f2/889ad4b2408f72fe1a4f6a19491177b30ea7bf1a0fd5f17050ca08cfc882/propcache-0.4.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8326e144341460402713f91df60ade3c999d601e7eb5ff8f6f7862d54de0610d", size = 201424, upload-time = "2025-10-08T19:48:00.67Z" }, + { url = "https://files.pythonhosted.org/packages/27/73/033d63069b57b0812c8bd19f311faebeceb6ba31b8f32b73432d12a0b826/propcache-0.4.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:060b16ae65bc098da7f6d25bf359f1f31f688384858204fe5d652979e0015e5b", size = 197566, upload-time = "2025-10-08T19:48:02.604Z" }, + { url = "https://files.pythonhosted.org/packages/dc/89/ce24f3dc182630b4e07aa6d15f0ff4b14ed4b9955fae95a0b54c58d66c05/propcache-0.4.1-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:89eb3fa9524f7bec9de6e83cf3faed9d79bffa560672c118a96a171a6f55831e", size = 193130, upload-time = "2025-10-08T19:48:04.499Z" }, + { url = "https://files.pythonhosted.org/packages/a9/24/ef0d5fd1a811fb5c609278d0209c9f10c35f20581fcc16f818da959fc5b4/propcache-0.4.1-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:dee69d7015dc235f526fe80a9c90d65eb0039103fe565776250881731f06349f", size = 202625, upload-time = "2025-10-08T19:48:06.213Z" }, + { url = "https://files.pythonhosted.org/packages/f5/02/98ec20ff5546f68d673df2f7a69e8c0d076b5abd05ca882dc7ee3a83653d/propcache-0.4.1-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:5558992a00dfd54ccbc64a32726a3357ec93825a418a401f5cc67df0ac5d9e49", size = 204209, upload-time = "2025-10-08T19:48:08.432Z" }, + { url = "https://files.pythonhosted.org/packages/a0/87/492694f76759b15f0467a2a93ab68d32859672b646aa8a04ce4864e7932d/propcache-0.4.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c9b822a577f560fbd9554812526831712c1436d2c046cedee4c3796d3543b144", size = 197797, upload-time = "2025-10-08T19:48:09.968Z" }, + { url = "https://files.pythonhosted.org/packages/ee/36/66367de3575db1d2d3f3d177432bd14ee577a39d3f5d1b3d5df8afe3b6e2/propcache-0.4.1-cp314-cp314-win32.whl", hash = "sha256:ab4c29b49d560fe48b696cdcb127dd36e0bc2472548f3bf56cc5cb3da2b2984f", size = 38140, upload-time = "2025-10-08T19:48:11.232Z" }, + { url = "https://files.pythonhosted.org/packages/0c/2a/a758b47de253636e1b8aef181c0b4f4f204bf0dd964914fb2af90a95b49b/propcache-0.4.1-cp314-cp314-win_amd64.whl", hash = "sha256:5a103c3eb905fcea0ab98be99c3a9a5ab2de60228aa5aceedc614c0281cf6153", size = 41257, upload-time = "2025-10-08T19:48:12.707Z" }, + { url = "https://files.pythonhosted.org/packages/34/5e/63bd5896c3fec12edcbd6f12508d4890d23c265df28c74b175e1ef9f4f3b/propcache-0.4.1-cp314-cp314-win_arm64.whl", hash = "sha256:74c1fb26515153e482e00177a1ad654721bf9207da8a494a0c05e797ad27b992", size = 38097, upload-time = "2025-10-08T19:48:13.923Z" }, + { url = "https://files.pythonhosted.org/packages/99/85/9ff785d787ccf9bbb3f3106f79884a130951436f58392000231b4c737c80/propcache-0.4.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:824e908bce90fb2743bd6b59db36eb4f45cd350a39637c9f73b1c1ea66f5b75f", size = 81455, upload-time = "2025-10-08T19:48:15.16Z" }, + { url = "https://files.pythonhosted.org/packages/90/85/2431c10c8e7ddb1445c1f7c4b54d886e8ad20e3c6307e7218f05922cad67/propcache-0.4.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:c2b5e7db5328427c57c8e8831abda175421b709672f6cfc3d630c3b7e2146393", size = 46372, upload-time = "2025-10-08T19:48:16.424Z" }, + { url = "https://files.pythonhosted.org/packages/01/20/b0972d902472da9bcb683fa595099911f4d2e86e5683bcc45de60dd05dc3/propcache-0.4.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6f6ff873ed40292cd4969ef5310179afd5db59fdf055897e282485043fc80ad0", size = 48411, upload-time = "2025-10-08T19:48:17.577Z" }, + { url = "https://files.pythonhosted.org/packages/e2/e3/7dc89f4f21e8f99bad3d5ddb3a3389afcf9da4ac69e3deb2dcdc96e74169/propcache-0.4.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:49a2dc67c154db2c1463013594c458881a069fcf98940e61a0569016a583020a", size = 275712, upload-time = "2025-10-08T19:48:18.901Z" }, + { url = "https://files.pythonhosted.org/packages/20/67/89800c8352489b21a8047c773067644e3897f02ecbbd610f4d46b7f08612/propcache-0.4.1-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:005f08e6a0529984491e37d8dbc3dd86f84bd78a8ceb5fa9a021f4c48d4984be", size = 273557, upload-time = "2025-10-08T19:48:20.762Z" }, + { url = "https://files.pythonhosted.org/packages/e2/a1/b52b055c766a54ce6d9c16d9aca0cad8059acd9637cdf8aa0222f4a026ef/propcache-0.4.1-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5c3310452e0d31390da9035c348633b43d7e7feb2e37be252be6da45abd1abcc", size = 280015, upload-time = "2025-10-08T19:48:22.592Z" }, + { url = "https://files.pythonhosted.org/packages/48/c8/33cee30bd890672c63743049f3c9e4be087e6780906bfc3ec58528be59c1/propcache-0.4.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4c3c70630930447f9ef1caac7728c8ad1c56bc5015338b20fed0d08ea2480b3a", size = 262880, upload-time = "2025-10-08T19:48:23.947Z" }, + { url = "https://files.pythonhosted.org/packages/0c/b1/8f08a143b204b418285c88b83d00edbd61afbc2c6415ffafc8905da7038b/propcache-0.4.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8e57061305815dfc910a3634dcf584f08168a8836e6999983569f51a8544cd89", size = 260938, upload-time = "2025-10-08T19:48:25.656Z" }, + { url = "https://files.pythonhosted.org/packages/cf/12/96e4664c82ca2f31e1c8dff86afb867348979eb78d3cb8546a680287a1e9/propcache-0.4.1-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:521a463429ef54143092c11a77e04056dd00636f72e8c45b70aaa3140d639726", size = 247641, upload-time = "2025-10-08T19:48:27.207Z" }, + { url = "https://files.pythonhosted.org/packages/18/ed/e7a9cfca28133386ba52278136d42209d3125db08d0a6395f0cba0c0285c/propcache-0.4.1-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:120c964da3fdc75e3731aa392527136d4ad35868cc556fd09bb6d09172d9a367", size = 262510, upload-time = "2025-10-08T19:48:28.65Z" }, + { url = "https://files.pythonhosted.org/packages/f5/76/16d8bf65e8845dd62b4e2b57444ab81f07f40caa5652b8969b87ddcf2ef6/propcache-0.4.1-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:d8f353eb14ee3441ee844ade4277d560cdd68288838673273b978e3d6d2c8f36", size = 263161, upload-time = "2025-10-08T19:48:30.133Z" }, + { url = "https://files.pythonhosted.org/packages/e7/70/c99e9edb5d91d5ad8a49fa3c1e8285ba64f1476782fed10ab251ff413ba1/propcache-0.4.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ab2943be7c652f09638800905ee1bab2c544e537edb57d527997a24c13dc1455", size = 257393, upload-time = "2025-10-08T19:48:31.567Z" }, + { url = "https://files.pythonhosted.org/packages/08/02/87b25304249a35c0915d236575bc3574a323f60b47939a2262b77632a3ee/propcache-0.4.1-cp314-cp314t-win32.whl", hash = "sha256:05674a162469f31358c30bcaa8883cb7829fa3110bf9c0991fe27d7896c42d85", size = 42546, upload-time = "2025-10-08T19:48:32.872Z" }, + { url = "https://files.pythonhosted.org/packages/cb/ef/3c6ecf8b317aa982f309835e8f96987466123c6e596646d4e6a1dfcd080f/propcache-0.4.1-cp314-cp314t-win_amd64.whl", hash = "sha256:990f6b3e2a27d683cb7602ed6c86f15ee6b43b1194736f9baaeb93d0016633b1", size = 46259, upload-time = "2025-10-08T19:48:34.226Z" }, + { url = "https://files.pythonhosted.org/packages/c4/2d/346e946d4951f37eca1e4f55be0f0174c52cd70720f84029b02f296f4a38/propcache-0.4.1-cp314-cp314t-win_arm64.whl", hash = "sha256:ecef2343af4cc68e05131e45024ba34f6095821988a9d0a02aa7c73fcc448aa9", size = 40428, upload-time = "2025-10-08T19:48:35.441Z" }, + { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, +] + [[package]] name = "psutil" version = "7.2.2" @@ -1365,6 +1833,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, ] +[[package]] +name = "pylev" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/f2/404d2bfa30fb4ee7c7a7435d593f9f698b25d191cafec69dd0c726f02f11/pylev-1.4.0.tar.gz", hash = "sha256:9e77e941042ad3a4cc305dcdf2b2dec1aec2fbe3dd9015d2698ad02b173006d1", size = 4710, upload-time = "2021-05-30T20:07:59.989Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/78/95cfe72991d22994f0ec5a3b742b31c95a28344d33e06b69406b68398a29/pylev-1.4.0-py2.py3-none-any.whl", hash = "sha256:7b2e2aa7b00e05bb3f7650eb506fc89f474f70493271a35c242d9a92188ad3dd", size = 6080, upload-time = "2021-05-30T20:07:58.473Z" }, +] + [[package]] name = "pyparsing" version = "3.3.2" @@ -1399,6 +1876,55 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/37/def183a2a2c8619d92649d62fe0622c4c6c62f60e4151e8fbaa409e7d5ab/pyro_ppl-1.9.1-py3-none-any.whl", hash = "sha256:91fb2c8740d9d3bd548180ac5ecfa04552ed8c471a1ab66870180663b8f09852", size = 755956, upload-time = "2024-06-02T00:37:37.486Z" }, ] +[[package]] +name = "pysimdjson" +version = "7.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/24/65e3cad88e74ef8ca59fefded953eb78ebface8a3199c3a97fe318a7387b/pysimdjson-7.0.2.tar.gz", hash = "sha256:44cf276e48912a3b9c7ca362c14da8420a7ac15a9f1a16ec95becff86db3904a", size = 1397812, upload-time = "2025-06-28T20:37:24.071Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/81/2a7bee8961e9519084ee290bb7135844f1f786ec8a26f62d48e7fd23a08b/pysimdjson-7.0.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:8ea5ffbdfde6a26b05bec12263ffacf8435d2e51c3793b44aa090fb38e709434", size = 1877768, upload-time = "2025-06-28T20:36:38.463Z" }, + { url = "https://files.pythonhosted.org/packages/b3/55/dfa21b647ff1a54e5925664ebfe3f1f800375546f0665347f3041a52bf5a/pysimdjson-7.0.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4fbe295c84bd9406ac8fc38ab76a6ff1187df11be9348e5937f9dcc42f41c8f8", size = 1656024, upload-time = "2025-06-28T20:36:39.847Z" }, + { url = "https://files.pythonhosted.org/packages/64/bd/06b744b0b33f4932ad4ed51fdb8ec5eeca6f7980ad502839dbfbe5ac60c9/pysimdjson-7.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:abbbd51ef301083c9ee885d1ba8d3c2081c462d56c2d0e2f603cc917a44f7ed5", size = 2771741, upload-time = "2025-06-28T20:36:41.249Z" }, + { url = "https://files.pythonhosted.org/packages/90/a4/c13afff7d4cd2fd001508f0d411063a8a9c451d694178b5230d50c8caf98/pysimdjson-7.0.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:14ca76010e5d82f4c0de90586a940e57c28beee937b4a53ef239b88ebee7190e", size = 2823997, upload-time = "2025-06-28T20:36:42.704Z" }, + { url = "https://files.pythonhosted.org/packages/58/da/459c89f3dbb8344f6b2a374850d13522cc9a89726faea4319568034f1f1f/pysimdjson-7.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1de838fc7aa473db24ddacc0b285928bd74d5830755f8471b17c34e78e94840", size = 3248858, upload-time = "2025-06-28T20:36:43.969Z" }, + { url = "https://files.pythonhosted.org/packages/d6/90/c9274cb68412b2b119a0d72c71d57b01f05397b59afc7cec9ff0b28a88d5/pysimdjson-7.0.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:061259784a9a4746d40a3a3f20542a19bd0e403e49af4aa3bd9a1626429ce704", size = 2529651, upload-time = "2025-06-28T20:36:45.266Z" }, + { url = "https://files.pythonhosted.org/packages/95/3b/8f3a3866daa6776ea3d3986b0c21cc678bd0bb5872a19a18170fae396e90/pysimdjson-7.0.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:27c2e4cde872b8d3a05dc855341508d11d056bb3b25eddbc17e533417a848a52", size = 3664874, upload-time = "2025-06-28T20:36:46.541Z" }, + { url = "https://files.pythonhosted.org/packages/1e/21/376e54868918d8b4831fb8653c1976615f99a11d95e0502ecaaa7a306d32/pysimdjson-7.0.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:41a18886861d47b63ef6231796a30ccc547bf3772a06fa60b681ee8f00a614ce", size = 3579057, upload-time = "2025-06-28T20:36:47.843Z" }, + { url = "https://files.pythonhosted.org/packages/5f/92/29bf4549ec6d692aca1cc11b1ff8a8bf8f742dd09e834f649e2567eb1438/pysimdjson-7.0.2-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:fdbd392590613ddbc4922ab5374282dddefa94471fc7a97bc2c1df6a450dd671", size = 3818097, upload-time = "2025-06-28T20:36:49.319Z" }, + { url = "https://files.pythonhosted.org/packages/9a/f8/ff0a6e3ee124eef780f164c95ea95ccca1ac04e4cff483e728aa029e7b36/pysimdjson-7.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cb217ddaedd5f28ca7db16e4ea972f02c6db380827ec312c7e6a9371ca5e4d7c", size = 4201879, upload-time = "2025-06-28T20:36:50.801Z" }, + { url = "https://files.pythonhosted.org/packages/4a/b0/7f60a32fef8b97407f07c80d367fb161c9245bd3c1de1597c9f4cb1c6536/pysimdjson-7.0.2-cp312-cp312-win32.whl", hash = "sha256:bf5af81e19b0cef57679523759f9219e2641e5156a4ee5b854e49e3e6b1690ab", size = 1529773, upload-time = "2025-06-28T20:36:51.97Z" }, + { url = "https://files.pythonhosted.org/packages/28/e7/b127c677f6aa8991ba6f9ea99a08aa167ab93a1844f6da35c65fa4b98179/pysimdjson-7.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:782ee03679eaea5b28d9bc9279bc0f0f03d251c17571396f3ed50ba86023d88f", size = 1574523, upload-time = "2025-06-28T20:36:53.103Z" }, + { url = "https://files.pythonhosted.org/packages/65/65/bf171e0dde8a40a56c6fde4e700daa3b172f1781b26478e92c34317f1225/pysimdjson-7.0.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a721cc23cd6240430b2c862caff79a411abc987290859cd0f9c5a3e29efa1d2c", size = 1877151, upload-time = "2025-06-28T20:36:54.199Z" }, + { url = "https://files.pythonhosted.org/packages/e2/2d/242c1bebadb960b704066288ae28660da3de7fb5d8f52f655e080e7ffbbf/pysimdjson-7.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fdbbf4246cac27dac38043da8f4d82a46d434b5bc3a4e54c0a55de1dd92631ae", size = 1655651, upload-time = "2025-06-28T20:36:55.336Z" }, + { url = "https://files.pythonhosted.org/packages/49/86/3b25e77ae2998342d2bd376eb58baf17b35e6c2fdb9184e8bc8c31ebfafe/pysimdjson-7.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77bbf9afdea8a9aa220cbf29115cc32e81207f9e8e07963ea145ba8d2e8f4053", size = 2771613, upload-time = "2025-06-28T20:36:56.732Z" }, + { url = "https://files.pythonhosted.org/packages/49/d9/3db962802aa5c95a8f89023dcf00eefa30817e9b9862668d5efb91c44d81/pysimdjson-7.0.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43d42ef0660181b67bd833c13bdcbb2743abd40bc348db8f9e788b5d88717459", size = 2819981, upload-time = "2025-06-28T20:36:57.923Z" }, + { url = "https://files.pythonhosted.org/packages/f2/a0/bfbc3c9a1b216cacad74863229c06c576f108e4f67cb6daa3c4d6071a9ff/pysimdjson-7.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13f2820c95d9c74139407921aeec8099e67546ccfcb309561881e877e4a3aa97", size = 3246918, upload-time = "2025-06-28T20:36:59.458Z" }, + { url = "https://files.pythonhosted.org/packages/ed/fc/1d21538d1fd3e4f2f7a96de605fbcdb1f150ff0eb49ac08f005da83e17c7/pysimdjson-7.0.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f81638ce66a7393ad1b4f5fae6666c417cc01e5ecb81c86ff727349599bbc83f", size = 2524078, upload-time = "2025-06-28T20:37:00.659Z" }, + { url = "https://files.pythonhosted.org/packages/2d/d3/76c05b4d116adcb947955c68700c9e67ee7f748a38d37ba72e5b1109ef1d/pysimdjson-7.0.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5ffe83c4dbfdabea5f2231cc64ff1a62b7ecd18f64cb04a61439a5c24d08a0cd", size = 3662263, upload-time = "2025-06-28T20:37:01.835Z" }, + { url = "https://files.pythonhosted.org/packages/5f/4c/7f4c326f4022babab518e1295446c58c7f72b7bfb242b47e9fae421c3783/pysimdjson-7.0.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:08b576531375fa6b9479b43b5358e5e172490bef8969b0f53d6b6be7c5d7b88a", size = 3576295, upload-time = "2025-06-28T20:37:02.989Z" }, + { url = "https://files.pythonhosted.org/packages/1c/9a/c4df622caf46284dd1a4d6e403dccea2a874623563c63d6e1cec4f54259a/pysimdjson-7.0.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1b7e26580d0030b6f7bb6fddc12e7756f4ffae3a9e4f7a8c3522d783173ac459", size = 3813976, upload-time = "2025-06-28T20:37:04.186Z" }, + { url = "https://files.pythonhosted.org/packages/75/b9/e21a5d1f4060ffeca6026a94599f6b68bf62221dd02a7af5962c73040edc/pysimdjson-7.0.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4a8fb78454cd2936f8e27e8948b56b6e44a766eaa162fef02a1436c2d4570053", size = 4197725, upload-time = "2025-06-28T20:37:05.591Z" }, + { url = "https://files.pythonhosted.org/packages/d8/ed/7e4511cabdcb2931cce174ce0ecf17cf4de6039b4d908daca4d313875f1e/pysimdjson-7.0.2-cp313-cp313-win32.whl", hash = "sha256:ef56eacf050e194d4058d6ed818dbbe40d9ec5dcb182ba93a451cad2467aad27", size = 1529585, upload-time = "2025-06-28T20:37:07.016Z" }, + { url = "https://files.pythonhosted.org/packages/e3/fa/3642b49521007362c9eb228ed472927e020b84d6413efa8fd69fd9f7c6b9/pysimdjson-7.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:4ae000c2d45a1af0303fe151e5204188fcbb23acc6cbdf04ac1062ab80538a1b", size = 1574251, upload-time = "2025-06-28T20:37:08.327Z" }, +] + +[[package]] +name = "pystan" +version = "3.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "clikit" }, + { name = "httpstan" }, + { name = "numpy" }, + { name = "pysimdjson" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8f/9e/cb47952a6ed3c5b12c60c85a0d9b8b07c81f781bbcd87a23abb107c47389/pystan-3.10.1.tar.gz", hash = "sha256:acac030ee6e95afd21373f63d443e6219d373b3c3be8263f0683c554f064551c", size = 13775, upload-time = "2026-03-12T18:28:14.445Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/59/8855e2e51815f49407c27a5f3beb00e44267607139da5ad5b5aaec6aaa60/pystan-3.10.1-py3-none-any.whl", hash = "sha256:52286149c123f769430a940f4d10c4faddcc2bcee056f33fba1337d344e93fb0", size = 13878, upload-time = "2026-03-12T18:28:13.361Z" }, +] + [[package]] name = "pytest" version = "9.0.2" @@ -1617,15 +2143,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, ] -[[package]] -name = "stanpy" -version = "0.2.11" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e3/cc/74da52b7de8ee281a2fdffcfba578a8b6b317a7801631d3800cb6a21ec80/stanpy-0.2.11.tar.gz", hash = "sha256:6b6354d042a705b9657392a1cee8c17ebacc2e43de8ed5dfd44e6cab52822530", size = 7809464, upload-time = "2022-04-09T10:49:48.038Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/06/19356ea9f4f20d997d5c2dc5f32d72cd220038f3d5dd2f33957d56cb6ba3/stanpy-0.2.11-py2.py3-none-any.whl", hash = "sha256:64fec89761e56a520d124f9487c365f78545145f0d1fee64c1e085d1f6c4adff", size = 28068, upload-time = "2022-04-09T10:49:34.812Z" }, -] - [[package]] name = "sympy" version = "1.14.0" @@ -1792,3 +2309,120 @@ sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a3 wheels = [ { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, ] + +[[package]] +name = "webargs" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "marshmallow" }, + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/37/64/17afc4e6f47eef154a553c6e56adcc9f1ac3003305c7df978d11aa62937e/webargs-8.7.1.tar.gz", hash = "sha256:799bf9039c76c23fd8dc1951107a75a9e561203c15d6ae8f89c1e46e234636c1", size = 97351, upload-time = "2025-10-29T16:07:50.066Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/ef/b0d17f3943429358184449771b592e0e1d33bbeaa6ed326434a95eac187b/webargs-8.7.1-py3-none-any.whl", hash = "sha256:a184aed9d2509e6e14ab99ee3e9dc3a614c7070affe94cd4dfdb0d002e0a6e5f", size = 32500, upload-time = "2025-10-29T16:07:47.895Z" }, +] + +[[package]] +name = "yarl" +version = "1.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "multidict" }, + { name = "propcache" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/6e/beb1beec874a72f23815c1434518bfc4ed2175065173fb138c3705f658d4/yarl-1.23.0.tar.gz", hash = "sha256:53b1ea6ca88ebd4420379c330aea57e258408dd0df9af0992e5de2078dc9f5d5", size = 194676, upload-time = "2026-03-01T22:07:53.373Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/8a/94615bc31022f711add374097ad4144d569e95ff3c38d39215d07ac153a0/yarl-1.23.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1932b6b8bba8d0160a9d1078aae5838a66039e8832d41d2992daa9a3a08f7860", size = 124737, upload-time = "2026-03-01T22:05:12.897Z" }, + { url = "https://files.pythonhosted.org/packages/e3/6f/c6554045d59d64052698add01226bc867b52fe4a12373415d7991fdca95d/yarl-1.23.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:411225bae281f114067578891bc75534cfb3d92a3b4dfef7a6ca78ba354e6069", size = 87029, upload-time = "2026-03-01T22:05:14.376Z" }, + { url = "https://files.pythonhosted.org/packages/19/2a/725ecc166d53438bc88f76822ed4b1e3b10756e790bafd7b523fe97c322d/yarl-1.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13a563739ae600a631c36ce096615fe307f131344588b0bc0daec108cdb47b25", size = 86310, upload-time = "2026-03-01T22:05:15.71Z" }, + { url = "https://files.pythonhosted.org/packages/99/30/58260ed98e6ff7f90ba84442c1ddd758c9170d70327394a6227b310cd60f/yarl-1.23.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9cbf44c5cb4a7633d078788e1b56387e3d3cf2b8139a3be38040b22d6c3221c8", size = 97587, upload-time = "2026-03-01T22:05:17.384Z" }, + { url = "https://files.pythonhosted.org/packages/76/0a/8b08aac08b50682e65759f7f8dde98ae8168f72487e7357a5d684c581ef9/yarl-1.23.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:53ad387048f6f09a8969631e4de3f1bf70c50e93545d64af4f751b2498755072", size = 92528, upload-time = "2026-03-01T22:05:18.804Z" }, + { url = "https://files.pythonhosted.org/packages/52/07/0b7179101fe5f8385ec6c6bb5d0cb9f76bd9fb4a769591ab6fb5cdbfc69a/yarl-1.23.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4a59ba56f340334766f3a4442e0efd0af895fae9e2b204741ef885c446b3a1a8", size = 105339, upload-time = "2026-03-01T22:05:20.235Z" }, + { url = "https://files.pythonhosted.org/packages/d3/8a/36d82869ab5ec829ca8574dfcb92b51286fcfb1e9c7a73659616362dc880/yarl-1.23.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:803a3c3ce4acc62eaf01eaca1208dcf0783025ef27572c3336502b9c232005e7", size = 105061, upload-time = "2026-03-01T22:05:22.268Z" }, + { url = "https://files.pythonhosted.org/packages/66/3e/868e5c3364b6cee19ff3e1a122194fa4ce51def02c61023970442162859e/yarl-1.23.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3d2bff8f37f8d0f96c7ec554d16945050d54462d6e95414babaa18bfafc7f51", size = 100132, upload-time = "2026-03-01T22:05:23.638Z" }, + { url = "https://files.pythonhosted.org/packages/cf/26/9c89acf82f08a52cb52d6d39454f8d18af15f9d386a23795389d1d423823/yarl-1.23.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c75eb09e8d55bceb4367e83496ff8ef2bc7ea6960efb38e978e8073ea59ecb67", size = 99289, upload-time = "2026-03-01T22:05:25.749Z" }, + { url = "https://files.pythonhosted.org/packages/6f/54/5b0db00d2cb056922356104468019c0a132e89c8d3ab67d8ede9f4483d2a/yarl-1.23.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877b0738624280e34c55680d6054a307aa94f7d52fa0e3034a9cc6e790871da7", size = 96950, upload-time = "2026-03-01T22:05:27.318Z" }, + { url = "https://files.pythonhosted.org/packages/f6/40/10fa93811fd439341fad7e0718a86aca0de9548023bbb403668d6555acab/yarl-1.23.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b5405bb8f0e783a988172993cfc627e4d9d00432d6bbac65a923041edacf997d", size = 93960, upload-time = "2026-03-01T22:05:28.738Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d2/8ae2e6cd77d0805f4526e30ec43b6f9a3dfc542d401ac4990d178e4bf0cf/yarl-1.23.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c3a3598a832590c5a3ce56ab5576361b5688c12cb1d39429cf5dba30b510760", size = 104703, upload-time = "2026-03-01T22:05:30.438Z" }, + { url = "https://files.pythonhosted.org/packages/2f/0c/b3ceacf82c3fe21183ce35fa2acf5320af003d52bc1fcf5915077681142e/yarl-1.23.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8419ebd326430d1cbb7efb5292330a2cf39114e82df5cc3d83c9a0d5ebeaf2f2", size = 98325, upload-time = "2026-03-01T22:05:31.835Z" }, + { url = "https://files.pythonhosted.org/packages/9d/e0/12900edd28bdab91a69bd2554b85ad7b151f64e8b521fe16f9ad2f56477a/yarl-1.23.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:be61f6fff406ca40e3b1d84716fde398fc08bc63dd96d15f3a14230a0973ed86", size = 105067, upload-time = "2026-03-01T22:05:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/15/61/74bb1182cf79c9bbe4eb6b1f14a57a22d7a0be5e9cedf8e2d5c2086474c3/yarl-1.23.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ceb13c5c858d01321b5d9bb65e4cf37a92169ea470b70fec6f236b2c9dd7e34", size = 100285, upload-time = "2026-03-01T22:05:35.4Z" }, + { url = "https://files.pythonhosted.org/packages/69/7f/cd5ef733f2550de6241bd8bd8c3febc78158b9d75f197d9c7baa113436af/yarl-1.23.0-cp312-cp312-win32.whl", hash = "sha256:fffc45637bcd6538de8b85f51e3df3223e4ad89bccbfca0481c08c7fc8b7ed7d", size = 82359, upload-time = "2026-03-01T22:05:36.811Z" }, + { url = "https://files.pythonhosted.org/packages/f5/be/25216a49daeeb7af2bec0db22d5e7df08ed1d7c9f65d78b14f3b74fd72fc/yarl-1.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:f69f57305656a4852f2a7203efc661d8c042e6cc67f7acd97d8667fb448a426e", size = 87674, upload-time = "2026-03-01T22:05:38.171Z" }, + { url = "https://files.pythonhosted.org/packages/d2/35/aeab955d6c425b227d5b7247eafb24f2653fedc32f95373a001af5dfeb9e/yarl-1.23.0-cp312-cp312-win_arm64.whl", hash = "sha256:6e87a6e8735b44816e7db0b2fbc9686932df473c826b0d9743148432e10bb9b9", size = 81879, upload-time = "2026-03-01T22:05:40.006Z" }, + { url = "https://files.pythonhosted.org/packages/9a/4b/a0a6e5d0ee8a2f3a373ddef8a4097d74ac901ac363eea1440464ccbe0898/yarl-1.23.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:16c6994ac35c3e74fb0ae93323bf8b9c2a9088d55946109489667c510a7d010e", size = 123796, upload-time = "2026-03-01T22:05:41.412Z" }, + { url = "https://files.pythonhosted.org/packages/67/b6/8925d68af039b835ae876db5838e82e76ec87b9782ecc97e192b809c4831/yarl-1.23.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4a42e651629dafb64fd5b0286a3580613702b5809ad3f24934ea87595804f2c5", size = 86547, upload-time = "2026-03-01T22:05:42.841Z" }, + { url = "https://files.pythonhosted.org/packages/ae/50/06d511cc4b8e0360d3c94af051a768e84b755c5eb031b12adaaab6dec6e5/yarl-1.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7c6b9461a2a8b47c65eef63bb1c76a4f1c119618ffa99ea79bc5bb1e46c5821b", size = 85854, upload-time = "2026-03-01T22:05:44.85Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f4/4e30b250927ffdab4db70da08b9b8d2194d7c7b400167b8fbeca1e4701ca/yarl-1.23.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2569b67d616eab450d262ca7cb9f9e19d2f718c70a8b88712859359d0ab17035", size = 98351, upload-time = "2026-03-01T22:05:46.836Z" }, + { url = "https://files.pythonhosted.org/packages/86/fc/4118c5671ea948208bdb1492d8b76bdf1453d3e73df051f939f563e7dcc5/yarl-1.23.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e9d9a4d06d3481eab79803beb4d9bd6f6a8e781ec078ac70d7ef2dcc29d1bea5", size = 92711, upload-time = "2026-03-01T22:05:48.316Z" }, + { url = "https://files.pythonhosted.org/packages/56/11/1ed91d42bd9e73c13dc9e7eb0dd92298d75e7ac4dd7f046ad0c472e231cd/yarl-1.23.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f514f6474e04179d3d33175ed3f3e31434d3130d42ec153540d5b157deefd735", size = 106014, upload-time = "2026-03-01T22:05:50.028Z" }, + { url = "https://files.pythonhosted.org/packages/ce/c9/74e44e056a23fbc33aca71779ef450ca648a5bc472bdad7a82339918f818/yarl-1.23.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:fda207c815b253e34f7e1909840fd14299567b1c0eb4908f8c2ce01a41265401", size = 105557, upload-time = "2026-03-01T22:05:51.416Z" }, + { url = "https://files.pythonhosted.org/packages/66/fe/b1e10b08d287f518994f1e2ff9b6d26f0adeecd8dd7d533b01bab29a3eda/yarl-1.23.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34b6cf500e61c90f305094911f9acc9c86da1a05a7a3f5be9f68817043f486e4", size = 101559, upload-time = "2026-03-01T22:05:52.872Z" }, + { url = "https://files.pythonhosted.org/packages/72/59/c5b8d94b14e3d3c2a9c20cb100119fd534ab5a14b93673ab4cc4a4141ea5/yarl-1.23.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d7504f2b476d21653e4d143f44a175f7f751cd41233525312696c76aa3dbb23f", size = 100502, upload-time = "2026-03-01T22:05:54.954Z" }, + { url = "https://files.pythonhosted.org/packages/77/4f/96976cb54cbfc5c9fd73ed4c51804f92f209481d1fb190981c0f8a07a1d7/yarl-1.23.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:578110dd426f0d209d1509244e6d4a3f1a3e9077655d98c5f22583d63252a08a", size = 98027, upload-time = "2026-03-01T22:05:56.409Z" }, + { url = "https://files.pythonhosted.org/packages/63/6e/904c4f476471afdbad6b7e5b70362fb5810e35cd7466529a97322b6f5556/yarl-1.23.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:609d3614d78d74ebe35f54953c5bbd2ac647a7ddb9c30a5d877580f5e86b22f2", size = 95369, upload-time = "2026-03-01T22:05:58.141Z" }, + { url = "https://files.pythonhosted.org/packages/9d/40/acfcdb3b5f9d68ef499e39e04d25e141fe90661f9d54114556cf83be8353/yarl-1.23.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:4966242ec68afc74c122f8459abd597afd7d8a60dc93d695c1334c5fd25f762f", size = 105565, upload-time = "2026-03-01T22:06:00.286Z" }, + { url = "https://files.pythonhosted.org/packages/5e/c6/31e28f3a6ba2869c43d124f37ea5260cac9c9281df803c354b31f4dd1f3c/yarl-1.23.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:e0fd068364a6759bc794459f0a735ab151d11304346332489c7972bacbe9e72b", size = 99813, upload-time = "2026-03-01T22:06:01.712Z" }, + { url = "https://files.pythonhosted.org/packages/08/1f/6f65f59e72d54aa467119b63fc0b0b1762eff0232db1f4720cd89e2f4a17/yarl-1.23.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:39004f0ad156da43e86aa71f44e033de68a44e5a31fc53507b36dd253970054a", size = 105632, upload-time = "2026-03-01T22:06:03.188Z" }, + { url = "https://files.pythonhosted.org/packages/a3/c4/18b178a69935f9e7a338127d5b77d868fdc0f0e49becd286d51b3a18c61d/yarl-1.23.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e5723c01a56c5028c807c701aa66722916d2747ad737a046853f6c46f4875543", size = 101895, upload-time = "2026-03-01T22:06:04.651Z" }, + { url = "https://files.pythonhosted.org/packages/8f/54/f5b870b5505663911dba950a8e4776a0dbd51c9c54c0ae88e823e4b874a0/yarl-1.23.0-cp313-cp313-win32.whl", hash = "sha256:1b6b572edd95b4fa8df75de10b04bc81acc87c1c7d16bcdd2035b09d30acc957", size = 82356, upload-time = "2026-03-01T22:06:06.04Z" }, + { url = "https://files.pythonhosted.org/packages/7a/84/266e8da36879c6edcd37b02b547e2d9ecdfea776be49598e75696e3316e1/yarl-1.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:baaf55442359053c7d62f6f8413a62adba3205119bcb6f49594894d8be47e5e3", size = 87515, upload-time = "2026-03-01T22:06:08.107Z" }, + { url = "https://files.pythonhosted.org/packages/00/fd/7e1c66efad35e1649114fa13f17485f62881ad58edeeb7f49f8c5e748bf9/yarl-1.23.0-cp313-cp313-win_arm64.whl", hash = "sha256:fb4948814a2a98e3912505f09c9e7493b1506226afb1f881825368d6fb776ee3", size = 81785, upload-time = "2026-03-01T22:06:10.181Z" }, + { url = "https://files.pythonhosted.org/packages/9c/fc/119dd07004f17ea43bb91e3ece6587759edd7519d6b086d16bfbd3319982/yarl-1.23.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:aecfed0b41aa72b7881712c65cf764e39ce2ec352324f5e0837c7048d9e6daaa", size = 130719, upload-time = "2026-03-01T22:06:11.708Z" }, + { url = "https://files.pythonhosted.org/packages/e6/0d/9f2348502fbb3af409e8f47730282cd6bc80dec6630c1e06374d882d6eb2/yarl-1.23.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a41bcf68efd19073376eb8cf948b8d9be0af26256403e512bb18f3966f1f9120", size = 89690, upload-time = "2026-03-01T22:06:13.429Z" }, + { url = "https://files.pythonhosted.org/packages/50/93/e88f3c80971b42cfc83f50a51b9d165a1dbf154b97005f2994a79f212a07/yarl-1.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cde9a2ecd91668bcb7f077c4966d8ceddb60af01b52e6e3e2680e4cf00ad1a59", size = 89851, upload-time = "2026-03-01T22:06:15.53Z" }, + { url = "https://files.pythonhosted.org/packages/1c/07/61c9dd8ba8f86473263b4036f70fb594c09e99c0d9737a799dfd8bc85651/yarl-1.23.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5023346c4ee7992febc0068e7593de5fa2bf611848c08404b35ebbb76b1b0512", size = 95874, upload-time = "2026-03-01T22:06:17.553Z" }, + { url = "https://files.pythonhosted.org/packages/9e/e9/f9ff8ceefba599eac6abddcfb0b3bee9b9e636e96dbf54342a8577252379/yarl-1.23.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:d1009abedb49ae95b136a8904a3f71b342f849ffeced2d3747bf29caeda218c4", size = 88710, upload-time = "2026-03-01T22:06:19.004Z" }, + { url = "https://files.pythonhosted.org/packages/eb/78/0231bfcc5d4c8eec220bc2f9ef82cb4566192ea867a7c5b4148f44f6cbcd/yarl-1.23.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a8d00f29b42f534cc8aa3931cfe773b13b23e561e10d2b26f27a8d309b0e82a1", size = 101033, upload-time = "2026-03-01T22:06:21.203Z" }, + { url = "https://files.pythonhosted.org/packages/cd/9b/30ea5239a61786f18fd25797151a17fbb3be176977187a48d541b5447dd4/yarl-1.23.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:95451e6ce06c3e104556d73b559f5da6c34a069b6b62946d3ad66afcd51642ea", size = 100817, upload-time = "2026-03-01T22:06:22.738Z" }, + { url = "https://files.pythonhosted.org/packages/62/e2/a4980481071791bc83bce2b7a1a1f7adcabfa366007518b4b845e92eeee3/yarl-1.23.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:531ef597132086b6cf96faa7c6c1dcd0361dd5f1694e5cc30375907b9b7d3ea9", size = 97482, upload-time = "2026-03-01T22:06:24.21Z" }, + { url = "https://files.pythonhosted.org/packages/e5/1e/304a00cf5f6100414c4b5a01fc7ff9ee724b62158a08df2f8170dfc72a2d/yarl-1.23.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:88f9fb0116fbfcefcab70f85cf4b74a2b6ce5d199c41345296f49d974ddb4123", size = 95949, upload-time = "2026-03-01T22:06:25.697Z" }, + { url = "https://files.pythonhosted.org/packages/68/03/093f4055ed4cae649ac53bca3d180bd37102e9e11d048588e9ab0c0108d0/yarl-1.23.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:e7b0460976dc75cb87ad9cc1f9899a4b97751e7d4e77ab840fc9b6d377b8fd24", size = 95839, upload-time = "2026-03-01T22:06:27.309Z" }, + { url = "https://files.pythonhosted.org/packages/b9/28/4c75ebb108f322aa8f917ae10a8ffa4f07cae10a8a627b64e578617df6a0/yarl-1.23.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:115136c4a426f9da976187d238e84139ff6b51a20839aa6e3720cd1026d768de", size = 90696, upload-time = "2026-03-01T22:06:29.048Z" }, + { url = "https://files.pythonhosted.org/packages/23/9c/42c2e2dd91c1a570402f51bdf066bfdb1241c2240ba001967bad778e77b7/yarl-1.23.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:ead11956716a940c1abc816b7df3fa2b84d06eaed8832ca32f5c5e058c65506b", size = 100865, upload-time = "2026-03-01T22:06:30.525Z" }, + { url = "https://files.pythonhosted.org/packages/74/05/1bcd60a8a0a914d462c305137246b6f9d167628d73568505fce3f1cb2e65/yarl-1.23.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:fe8f8f5e70e6dbdfca9882cd9deaac058729bcf323cf7a58660901e55c9c94f6", size = 96234, upload-time = "2026-03-01T22:06:32.692Z" }, + { url = "https://files.pythonhosted.org/packages/90/b2/f52381aac396d6778ce516b7bc149c79e65bfc068b5de2857ab69eeea3b7/yarl-1.23.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:a0e317df055958a0c1e79e5d2aa5a5eaa4a6d05a20d4b0c9c3f48918139c9fc6", size = 100295, upload-time = "2026-03-01T22:06:34.268Z" }, + { url = "https://files.pythonhosted.org/packages/e5/e8/638bae5bbf1113a659b2435d8895474598afe38b4a837103764f603aba56/yarl-1.23.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6f0fd84de0c957b2d280143522c4f91a73aada1923caee763e24a2b3fda9f8a5", size = 97784, upload-time = "2026-03-01T22:06:35.864Z" }, + { url = "https://files.pythonhosted.org/packages/80/25/a3892b46182c586c202629fc2159aa13975d3741d52ebd7347fd501d48d5/yarl-1.23.0-cp313-cp313t-win32.whl", hash = "sha256:93a784271881035ab4406a172edb0faecb6e7d00f4b53dc2f55919d6c9688595", size = 88313, upload-time = "2026-03-01T22:06:37.39Z" }, + { url = "https://files.pythonhosted.org/packages/43/68/8c5b36aa5178900b37387937bc2c2fe0e9505537f713495472dcf6f6fccc/yarl-1.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dd00607bffbf30250fe108065f07453ec124dbf223420f57f5e749b04295e090", size = 94932, upload-time = "2026-03-01T22:06:39.579Z" }, + { url = "https://files.pythonhosted.org/packages/c6/cc/d79ba8292f51f81f4dc533a8ccfb9fc6992cabf0998ed3245de7589dc07c/yarl-1.23.0-cp313-cp313t-win_arm64.whl", hash = "sha256:ac09d42f48f80c9ee1635b2fcaa819496a44502737660d3c0f2ade7526d29144", size = 84786, upload-time = "2026-03-01T22:06:41.988Z" }, + { url = "https://files.pythonhosted.org/packages/90/98/b85a038d65d1b92c3903ab89444f48d3cee490a883477b716d7a24b1a78c/yarl-1.23.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:21d1b7305a71a15b4794b5ff22e8eef96ff4a6d7f9657155e5aa419444b28912", size = 124455, upload-time = "2026-03-01T22:06:43.615Z" }, + { url = "https://files.pythonhosted.org/packages/39/54/bc2b45559f86543d163b6e294417a107bb87557609007c007ad889afec18/yarl-1.23.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:85610b4f27f69984932a7abbe52703688de3724d9f72bceb1cca667deff27474", size = 86752, upload-time = "2026-03-01T22:06:45.425Z" }, + { url = "https://files.pythonhosted.org/packages/24/f9/e8242b68362bffe6fb536c8db5076861466fc780f0f1b479fc4ffbebb128/yarl-1.23.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:23f371bd662cf44a7630d4d113101eafc0cfa7518a2760d20760b26021454719", size = 86291, upload-time = "2026-03-01T22:06:46.974Z" }, + { url = "https://files.pythonhosted.org/packages/ea/d8/d1cb2378c81dd729e98c716582b1ccb08357e8488e4c24714658cc6630e8/yarl-1.23.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c4a80f77dc1acaaa61f0934176fccca7096d9b1ff08c8ba9cddf5ae034a24319", size = 99026, upload-time = "2026-03-01T22:06:48.459Z" }, + { url = "https://files.pythonhosted.org/packages/0a/ff/7196790538f31debe3341283b5b0707e7feb947620fc5e8236ef28d44f72/yarl-1.23.0-cp314-cp314-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:bd654fad46d8d9e823afbb4f87c79160b5a374ed1ff5bde24e542e6ba8f41434", size = 92355, upload-time = "2026-03-01T22:06:50.306Z" }, + { url = "https://files.pythonhosted.org/packages/c1/56/25d58c3eddde825890a5fe6aa1866228377354a3c39262235234ab5f616b/yarl-1.23.0-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:682bae25f0a0dd23a056739f23a134db9f52a63e2afd6bfb37ddc76292bbd723", size = 106417, upload-time = "2026-03-01T22:06:52.1Z" }, + { url = "https://files.pythonhosted.org/packages/51/8a/882c0e7bc8277eb895b31bce0138f51a1ba551fc2e1ec6753ffc1e7c1377/yarl-1.23.0-cp314-cp314-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a82836cab5f197a0514235aaf7ffccdc886ccdaa2324bc0aafdd4ae898103039", size = 106422, upload-time = "2026-03-01T22:06:54.424Z" }, + { url = "https://files.pythonhosted.org/packages/42/2b/fef67d616931055bf3d6764885990a3ac647d68734a2d6a9e1d13de437a2/yarl-1.23.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1c57676bdedc94cd3bc37724cf6f8cd2779f02f6aba48de45feca073e714fe52", size = 101915, upload-time = "2026-03-01T22:06:55.895Z" }, + { url = "https://files.pythonhosted.org/packages/18/6a/530e16aebce27c5937920f3431c628a29a4b6b430fab3fd1c117b26ff3f6/yarl-1.23.0-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c7f8dc16c498ff06497c015642333219871effba93e4a2e8604a06264aca5c5c", size = 100690, upload-time = "2026-03-01T22:06:58.21Z" }, + { url = "https://files.pythonhosted.org/packages/88/08/93749219179a45e27b036e03260fda05190b911de8e18225c294ac95bbc9/yarl-1.23.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:5ee586fb17ff8f90c91cf73c6108a434b02d69925f44f5f8e0d7f2f260607eae", size = 98750, upload-time = "2026-03-01T22:06:59.794Z" }, + { url = "https://files.pythonhosted.org/packages/d9/cf/ea424a004969f5d81a362110a6ac1496d79efdc6d50c2c4b2e3ea0fc2519/yarl-1.23.0-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:17235362f580149742739cc3828b80e24029d08cbb9c4bda0242c7b5bc610a8e", size = 94685, upload-time = "2026-03-01T22:07:01.375Z" }, + { url = "https://files.pythonhosted.org/packages/e2/b7/14341481fe568e2b0408bcf1484c652accafe06a0ade9387b5d3fd9df446/yarl-1.23.0-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:0793e2bd0cf14234983bbb371591e6bea9e876ddf6896cdcc93450996b0b5c85", size = 106009, upload-time = "2026-03-01T22:07:03.151Z" }, + { url = "https://files.pythonhosted.org/packages/0a/e6/5c744a9b54f4e8007ad35bce96fbc9218338e84812d36f3390cea616881a/yarl-1.23.0-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:3650dc2480f94f7116c364096bc84b1d602f44224ef7d5c7208425915c0475dd", size = 100033, upload-time = "2026-03-01T22:07:04.701Z" }, + { url = "https://files.pythonhosted.org/packages/0c/23/e3bfc188d0b400f025bc49d99793d02c9abe15752138dcc27e4eaf0c4a9e/yarl-1.23.0-cp314-cp314-musllinux_1_2_s390x.whl", hash = "sha256:f40e782d49630ad384db66d4d8b73ff4f1b8955dc12e26b09a3e3af064b3b9d6", size = 106483, upload-time = "2026-03-01T22:07:06.231Z" }, + { url = "https://files.pythonhosted.org/packages/72/42/f0505f949a90b3f8b7a363d6cbdf398f6e6c58946d85c6d3a3bc70595b26/yarl-1.23.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:94f8575fbdf81749008d980c17796097e645574a3b8c28ee313931068dad14fe", size = 102175, upload-time = "2026-03-01T22:07:08.4Z" }, + { url = "https://files.pythonhosted.org/packages/aa/65/b39290f1d892a9dd671d1c722014ca062a9c35d60885d57e5375db0404b5/yarl-1.23.0-cp314-cp314-win32.whl", hash = "sha256:c8aa34a5c864db1087d911a0b902d60d203ea3607d91f615acd3f3108ac32169", size = 83871, upload-time = "2026-03-01T22:07:09.968Z" }, + { url = "https://files.pythonhosted.org/packages/a9/5b/9b92f54c784c26e2a422e55a8d2607ab15b7ea3349e28359282f84f01d43/yarl-1.23.0-cp314-cp314-win_amd64.whl", hash = "sha256:63e92247f383c85ab00dd0091e8c3fa331a96e865459f5ee80353c70a4a42d70", size = 89093, upload-time = "2026-03-01T22:07:11.501Z" }, + { url = "https://files.pythonhosted.org/packages/e0/7d/8a84dc9381fd4412d5e7ff04926f9865f6372b4c2fd91e10092e65d29eb8/yarl-1.23.0-cp314-cp314-win_arm64.whl", hash = "sha256:70efd20be968c76ece7baa8dafe04c5be06abc57f754d6f36f3741f7aa7a208e", size = 83384, upload-time = "2026-03-01T22:07:13.069Z" }, + { url = "https://files.pythonhosted.org/packages/dd/8d/d2fad34b1c08aa161b74394183daa7d800141aaaee207317e82c790b418d/yarl-1.23.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:9a18d6f9359e45722c064c97464ec883eb0e0366d33eda61cb19a244bf222679", size = 131019, upload-time = "2026-03-01T22:07:14.903Z" }, + { url = "https://files.pythonhosted.org/packages/19/ff/33009a39d3ccf4b94d7d7880dfe17fb5816c5a4fe0096d9b56abceea9ac7/yarl-1.23.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:2803ed8b21ca47a43da80a6fd1ed3019d30061f7061daa35ac54f63933409412", size = 89894, upload-time = "2026-03-01T22:07:17.372Z" }, + { url = "https://files.pythonhosted.org/packages/0c/f1/dab7ac5e7306fb79c0190766a3c00b4cb8d09a1f390ded68c85a5934faf5/yarl-1.23.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:394906945aa8b19fc14a61cf69743a868bb8c465efe85eee687109cc540b98f4", size = 89979, upload-time = "2026-03-01T22:07:19.361Z" }, + { url = "https://files.pythonhosted.org/packages/aa/b1/08e95f3caee1fad6e65017b9f26c1d79877b502622d60e517de01e72f95d/yarl-1.23.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:71d006bee8397a4a89f469b8deb22469fe7508132d3c17fa6ed871e79832691c", size = 95943, upload-time = "2026-03-01T22:07:21.266Z" }, + { url = "https://files.pythonhosted.org/packages/c0/cc/6409f9018864a6aa186c61175b977131f373f1988e198e031236916e87e4/yarl-1.23.0-cp314-cp314t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:62694e275c93d54f7ccedcfef57d42761b2aad5234b6be1f3e3026cae4001cd4", size = 88786, upload-time = "2026-03-01T22:07:23.129Z" }, + { url = "https://files.pythonhosted.org/packages/76/40/cc22d1d7714b717fde2006fad2ced5efe5580606cb059ae42117542122f3/yarl-1.23.0-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a31de1613658308efdb21ada98cbc86a97c181aa050ba22a808120bb5be3ab94", size = 101307, upload-time = "2026-03-01T22:07:24.689Z" }, + { url = "https://files.pythonhosted.org/packages/8f/0d/476c38e85ddb4c6ec6b20b815bdd779aa386a013f3d8b85516feee55c8dc/yarl-1.23.0-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:fb1e8b8d66c278b21d13b0a7ca22c41dd757a7c209c6b12c313e445c31dd3b28", size = 100904, upload-time = "2026-03-01T22:07:26.287Z" }, + { url = "https://files.pythonhosted.org/packages/72/32/0abe4a76d59adf2081dcb0397168553ece4616ada1c54d1c49d8936c74f8/yarl-1.23.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50f9d8d531dfb767c565f348f33dd5139a6c43f5cbdf3f67da40d54241df93f6", size = 97728, upload-time = "2026-03-01T22:07:27.906Z" }, + { url = "https://files.pythonhosted.org/packages/b7/35/7b30f4810fba112f60f5a43237545867504e15b1c7647a785fbaf588fac2/yarl-1.23.0-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:575aa4405a656e61a540f4a80eaa5260f2a38fff7bfdc4b5f611840d76e9e277", size = 95964, upload-time = "2026-03-01T22:07:30.198Z" }, + { url = "https://files.pythonhosted.org/packages/2d/86/ed7a73ab85ef00e8bb70b0cb5421d8a2a625b81a333941a469a6f4022828/yarl-1.23.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:041b1a4cefacf65840b4e295c6985f334ba83c30607441ae3cf206a0eed1a2e4", size = 95882, upload-time = "2026-03-01T22:07:32.132Z" }, + { url = "https://files.pythonhosted.org/packages/19/90/d56967f61a29d8498efb7afb651e0b2b422a1e9b47b0ab5f4e40a19b699b/yarl-1.23.0-cp314-cp314t-musllinux_1_2_armv7l.whl", hash = "sha256:d38c1e8231722c4ce40d7593f28d92b5fc72f3e9774fe73d7e800ec32299f63a", size = 90797, upload-time = "2026-03-01T22:07:34.404Z" }, + { url = "https://files.pythonhosted.org/packages/72/00/8b8f76909259f56647adb1011d7ed8b321bcf97e464515c65016a47ecdf0/yarl-1.23.0-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:d53834e23c015ee83a99377db6e5e37d8484f333edb03bd15b4bc312cc7254fb", size = 101023, upload-time = "2026-03-01T22:07:35.953Z" }, + { url = "https://files.pythonhosted.org/packages/ac/e2/cab11b126fb7d440281b7df8e9ddbe4851e70a4dde47a202b6642586b8d9/yarl-1.23.0-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:2e27c8841126e017dd2a054a95771569e6070b9ee1b133366d8b31beb5018a41", size = 96227, upload-time = "2026-03-01T22:07:37.594Z" }, + { url = "https://files.pythonhosted.org/packages/c2/9b/2c893e16bfc50e6b2edf76c1a9eb6cb0c744346197e74c65e99ad8d634d0/yarl-1.23.0-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:76855800ac56f878847a09ce6dba727c93ca2d89c9e9d63002d26b916810b0a2", size = 100302, upload-time = "2026-03-01T22:07:39.334Z" }, + { url = "https://files.pythonhosted.org/packages/28/ec/5498c4e3a6d5f1003beb23405671c2eb9cdbf3067d1c80f15eeafe301010/yarl-1.23.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e09fd068c2e169a7070d83d3bde728a4d48de0549f975290be3c108c02e499b4", size = 98202, upload-time = "2026-03-01T22:07:41.717Z" }, + { url = "https://files.pythonhosted.org/packages/fe/c3/cd737e2d45e70717907f83e146f6949f20cc23cd4bf7b2688727763aa458/yarl-1.23.0-cp314-cp314t-win32.whl", hash = "sha256:73309162a6a571d4cbd3b6a1dcc703c7311843ae0d1578df6f09be4e98df38d4", size = 90558, upload-time = "2026-03-01T22:07:43.433Z" }, + { url = "https://files.pythonhosted.org/packages/e1/19/3774d162f6732d1cfb0b47b4140a942a35ca82bb19b6db1f80e9e7bdc8f8/yarl-1.23.0-cp314-cp314t-win_amd64.whl", hash = "sha256:4503053d296bc6e4cbd1fad61cf3b6e33b939886c4f249ba7c78b602214fabe2", size = 97610, upload-time = "2026-03-01T22:07:45.773Z" }, + { url = "https://files.pythonhosted.org/packages/51/47/3fa2286c3cb162c71cdb34c4224d5745a1ceceb391b2bd9b19b668a8d724/yarl-1.23.0-cp314-cp314t-win_arm64.whl", hash = "sha256:44bb7bef4ea409384e3f8bc36c063d77ea1b8d4a5b2706956c0d6695f07dcc25", size = 86041, upload-time = "2026-03-01T22:07:49.026Z" }, + { url = "https://files.pythonhosted.org/packages/69/68/c8739671f5699c7dc470580a4f821ef37c32c4cb0b047ce223a7f115757f/yarl-1.23.0-py3-none-any.whl", hash = "sha256:a2df6afe50dea8ae15fa34c9f824a3ee958d785fd5d089063d960bae1daa0a3f", size = 48288, upload-time = "2026-03-01T22:07:51.388Z" }, +] From 39a26d54a5f0d3f53e6340958c64aca8cd922093 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 27 Apr 2026 08:22:29 -0400 Subject: [PATCH 35/60] Add basic stan model --- examples/stan/.gitignore | 1 + examples/stan/schools.py | 85 +++++++++++++++++++++++++++++++ src/nak_torch/tools/__init__.py | 6 +++ src/nak_torch/tools/stan_tools.py | 49 ++++++++++++++++++ 4 files changed, 141 insertions(+) create mode 100644 examples/stan/.gitignore create mode 100644 examples/stan/schools.py create mode 100644 src/nak_torch/tools/stan_tools.py diff --git a/examples/stan/.gitignore b/examples/stan/.gitignore new file mode 100644 index 0000000..c795b05 --- /dev/null +++ b/examples/stan/.gitignore @@ -0,0 +1 @@ +build \ No newline at end of file diff --git a/examples/stan/schools.py b/examples/stan/schools.py new file mode 100644 index 0000000..7f5a61d --- /dev/null +++ b/examples/stan/schools.py @@ -0,0 +1,85 @@ +# %% +import nest_asyncio +import torch +import nak_torch +from nak_torch.tools import stan_tools +from nak_torch.algorithms import MSIP, SVGD +from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree + +nest_asyncio.apply() +import stan # noqa: E402 + +# %% +# Example from https://github.com/stan-dev/pystan +schools_code = """ +data { + int J; // number of schools + array[J] real y; // estimated treatment effects + array[J] real sigma; // standard error of effect estimates +} +parameters { + real mu; // population treatment effect + real log_tau; // standard deviation in treatment effects + vector[J] eta; // unscaled deviation from mu by school +} +transformed parameters { + vector[J] theta = mu + exp(log_tau) * eta; // school treatment effects +} +model { + target += normal_lpdf(eta | 0, 1); // prior log-density + target += normal_lpdf(log_tau | 5, 1); + target += normal_lpdf(mu | 0, 10); + target += normal_lpdf(y | theta, sigma); // log-likelihood +} +""" + +schools_data = { + "J": 8, + "y": [28, 8, -3, 7, -1, 1, 18, 12], + "sigma": [15, 10, 16, 11, 9, 11, 10, 18], +} + +posterior = stan.build(schools_code, data=schools_data) + +# %% +# Ten dimensional (mu, tau, eta): theta is a constrained parameter. +model = stan_tools.StanModel(posterior, dim=10) + +# %% +# Test evaluation of the pdf and logpdf +pts = torch.randn((100, model.dim)) +pdfs = model.log_dens_batch(pts, None) +grad_log_pdfs = model.grad_log_dens_batch(pts, None) +grad_log_pdfs_2, pdfs_2 = model.grad_val_log_dens_batch(pts, None) + +# %% +GRADIENT_DECAY = 0.95 +N_PARTICLES = 100 +KERNEL_DIAG_INFL = 1e-6 +KERNEL_LENGTHSCALE = 1e-2 +target_msip_fr = MSIPFredholm(GRADIENT_DECAY, model.grad_val_log_dens_batch) +init_eta = torch.randn((N_PARTICLES, 8)) +init_log_tau = torch.randn((N_PARTICLES, 1)) + 5 +init_mu = torch.randn((N_PARTICLES, 1)) * 10 +init_particles = torch.column_stack((init_mu, init_log_tau, init_eta)) +msip = MSIP( + model.dim, + N_PARTICLES, + kernel_diag_infl=KERNEL_DIAG_INFL, + kernel_lengthscale=KERNEL_LENGTHSCALE, +) + +# %% +N_STEPS = 1000 +LR = 1e-3 +trajectories_msip_fr = nak_torch.nak( + target_msip_fr, + msip, + N_STEPS, + LR, + init_particles=init_particles, + bounds=(-100.0, 100.0), +) +trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr + +# %% diff --git a/src/nak_torch/tools/__init__.py b/src/nak_torch/tools/__init__.py index 5b7b872..bb4daaa 100644 --- a/src/nak_torch/tools/__init__.py +++ b/src/nak_torch/tools/__init__.py @@ -19,7 +19,13 @@ "adaptive_step", "metrics", ] + if importlib.util.find_spec("pyro") is not None: from . import pyro_tools # noqa: F401 __all__.append("pyro_tools") + +if importlib.util.find_spec("stanpy") is not None: + from . import stan_tools # noqa: F401 + + __all__.append("stan_tools") diff --git a/src/nak_torch/tools/stan_tools.py b/src/nak_torch/tools/stan_tools.py new file mode 100644 index 0000000..97e5deb --- /dev/null +++ b/src/nak_torch/tools/stan_tools.py @@ -0,0 +1,49 @@ +from typing import Optional + +import stan.model +import torch + +from nak_torch.tools.types import BatchPtType, BatchType, NAKTarget + + +class StanModel(NAKTarget): + dim: int + + def __init__(self, model: stan.model.Model, dim: Optional[int] = None): + if dim is None: + all_dims = model.dims + if any(len(d) > 1 for d in all_dims): + raise ValueError( + f"Can currently only handle models with one-dimensional variables. Got dims {all_dims}" + ) + self.dim = sum(1 if len(x) == 0 else x[0] for x in all_dims) + else: + self.dim = dim + self.model = model + + def log_dens_batch(self, theta: BatchPtType, _) -> BatchType: + out = torch.empty(theta.shape[0], dtype=theta.dtype, device="cpu") + for theta_idx in range(theta.shape[0]): + th = theta[theta_idx].cpu().tolist() + out[theta_idx] = self.model.log_prob(th) + return out.to(device=theta.device) + + def grad_log_dens_batch(self, theta: BatchPtType, _) -> BatchPtType: + out = torch.empty_like(theta, device="cpu") + for theta_idx in range(theta.shape[0]): + th = theta[theta_idx].cpu().tolist() + out_i = self.model.grad_log_prob(th) + out[theta_idx] = torch.as_tensor(out_i) + return out.to(device=theta.device) + + def grad_val_log_dens_batch( + self, theta: BatchPtType, _ + ) -> tuple[BatchPtType, BatchType]: + out_grad = torch.empty_like(theta, device="cpu") + out = torch.empty(theta.shape[0], device="cpu") + for theta_idx in range(theta.shape[0]): + th = theta[theta_idx].cpu().tolist() + out_grad_i = self.model.grad_log_prob(th) + out_grad[theta_idx] = torch.as_tensor(out_grad_i) + out[theta_idx] = self.model.log_prob(th) + return out_grad.to(device=theta.device), out.to(device=theta.device) From d1c17979f7e5fca4039405cbbfab4416e776f9dd Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 27 Apr 2026 10:50:36 -0400 Subject: [PATCH 36/60] Change version dep of stan --- pyproject.toml | 5 ++++- uv.lock | 8 ++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3bdd0bc..32c5f4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ examples = [ "matplotlib>=3.10.8", "posteriordb>=0.2.0", "pyro-ppl>=1.9.1", - "pystan>=3.10.1", + "pystan", "scipy>=1.17.1", ] @@ -56,3 +56,6 @@ testpaths = [ # Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or # McCabe complexity (`C901`) by default. ignore = ["F722"] + +[tool.uv.sources] +pystan = { git = "https://github.com/dannys4/pystan", branch = "change_function_interface" } diff --git a/uv.lock b/uv.lock index cd9ce29..82c15e4 100644 --- a/uv.lock +++ b/uv.lock @@ -1244,7 +1244,7 @@ requires-dist = [ { name = "numpy", specifier = ">=2.4.1" }, { name = "posteriordb", marker = "extra == 'examples'", specifier = ">=0.2.0" }, { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, - { name = "pystan", marker = "extra == 'examples'", specifier = ">=3.10.1" }, + { name = "pystan", marker = "extra == 'examples'", git = "https://github.com/dannys4/pystan?branch=change_function_interface" }, { name = "scipy", marker = "extra == 'examples'", specifier = ">=1.17.1" }, { name = "torch", specifier = ">=2.10" }, { name = "tqdm", specifier = ">=4.67.1" }, @@ -1911,7 +1911,7 @@ wheels = [ [[package]] name = "pystan" version = "3.10.1" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/dannys4/pystan?branch=change_function_interface#d53224c76971aee716e2e7e2d68e3ad250f79e9e" } dependencies = [ { name = "aiohttp" }, { name = "clikit" }, @@ -1920,10 +1920,6 @@ dependencies = [ { name = "pysimdjson" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/9e/cb47952a6ed3c5b12c60c85a0d9b8b07c81f781bbcd87a23abb107c47389/pystan-3.10.1.tar.gz", hash = "sha256:acac030ee6e95afd21373f63d443e6219d373b3c3be8263f0683c554f064551c", size = 13775, upload-time = "2026-03-12T18:28:14.445Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/59/8855e2e51815f49407c27a5f3beb00e44267607139da5ad5b5aaec6aaa60/pystan-3.10.1-py3-none-any.whl", hash = "sha256:52286149c123f769430a940f4d10c4faddcc2bcee056f33fba1337d344e93fb0", size = 13878, upload-time = "2026-03-12T18:28:13.361Z" }, -] [[package]] name = "pytest" From 5e1621893834c3de2ea9f5dc1daacc432763706e Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 27 Apr 2026 10:51:14 -0400 Subject: [PATCH 37/60] Adapt to different stan interface --- src/nak_torch/tools/stan_tools.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/src/nak_torch/tools/stan_tools.py b/src/nak_torch/tools/stan_tools.py index 97e5deb..295e7e3 100644 --- a/src/nak_torch/tools/stan_tools.py +++ b/src/nak_torch/tools/stan_tools.py @@ -22,28 +22,20 @@ def __init__(self, model: stan.model.Model, dim: Optional[int] = None): self.model = model def log_dens_batch(self, theta: BatchPtType, _) -> BatchType: - out = torch.empty(theta.shape[0], dtype=theta.dtype, device="cpu") - for theta_idx in range(theta.shape[0]): - th = theta[theta_idx].cpu().tolist() - out[theta_idx] = self.model.log_prob(th) - return out.to(device=theta.device) + device, dtype = theta.device, theta.dtype + out_np = self.model.log_prob(theta.cpu().numpy()) + return torch.as_tensor(out_np, device=device, dtype=dtype) def grad_log_dens_batch(self, theta: BatchPtType, _) -> BatchPtType: - out = torch.empty_like(theta, device="cpu") - for theta_idx in range(theta.shape[0]): - th = theta[theta_idx].cpu().tolist() - out_i = self.model.grad_log_prob(th) - out[theta_idx] = torch.as_tensor(out_i) - return out.to(device=theta.device) + device, dtype = theta.device, theta.dtype + out_np = self.model.grad_log_prob(theta.cpu().numpy()) + return torch.as_tensor(out_np, device=device, dtype=dtype) def grad_val_log_dens_batch( self, theta: BatchPtType, _ ) -> tuple[BatchPtType, BatchType]: - out_grad = torch.empty_like(theta, device="cpu") - out = torch.empty(theta.shape[0], device="cpu") - for theta_idx in range(theta.shape[0]): - th = theta[theta_idx].cpu().tolist() - out_grad_i = self.model.grad_log_prob(th) - out_grad[theta_idx] = torch.as_tensor(out_grad_i) - out[theta_idx] = self.model.log_prob(th) - return out_grad.to(device=theta.device), out.to(device=theta.device) + device, dtype = theta.device, theta.dtype + out_grad_np, out_val_np = self.model.grad_val_log_prob(theta.cpu().numpy()) + return torch.as_tensor( + out_grad_np, device=device, dtype=dtype + ), torch.as_tensor(out_val_np, device=device, dtype=dtype) From 3847703554e18095b4f26199355a43c5accd9080 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 27 Apr 2026 10:51:27 -0400 Subject: [PATCH 38/60] make schools example work --- examples/stan/schools.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/stan/schools.py b/examples/stan/schools.py index 7f5a61d..fdcc9ab 100644 --- a/examples/stan/schools.py +++ b/examples/stan/schools.py @@ -67,6 +67,7 @@ N_PARTICLES, kernel_diag_infl=KERNEL_DIAG_INFL, kernel_lengthscale=KERNEL_LENGTHSCALE, + kernel_lengthscale_quantile=0.05 ) # %% @@ -83,3 +84,7 @@ trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr # %% +msip_fr_end = trajectories_pts_msip_fr[-1] +eta_end = msip_fr_end[:,:8] - init_particles[:,:8] +mean_sq_shift = (msip_fr_end - init_particles).square().sum() / init_particles.square().sum() +print(mean_sq_shift) \ No newline at end of file From d332836c735472cb7528908a3c3602a92b087027 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 27 Apr 2026 17:21:37 -0400 Subject: [PATCH 39/60] Start on PosteriorDB example --- .gitignore | 1 + examples/stan/pdb_schools.py | 18 ++++++++++++++++++ examples/stan/schools.py | 2 +- 3 files changed, 20 insertions(+), 1 deletion(-) create mode 100644 examples/stan/pdb_schools.py diff --git a/.gitignore b/.gitignore index 3173d81..4452a8b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__ .vscode *.xml *.pdf +.env \ No newline at end of file diff --git a/examples/stan/pdb_schools.py b/examples/stan/pdb_schools.py new file mode 100644 index 0000000..5781d83 --- /dev/null +++ b/examples/stan/pdb_schools.py @@ -0,0 +1,18 @@ +# %% +import nest_asyncio +import os + +nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter +from posteriordb import PosteriorDatabaseGithub # noqa: E402 + +# %% +if "GITHUB_PAT" not in os.environ.keys(): + raise ValueError("Expected GITHUB_PAT to be in environment. Please add this into, e.g., your .env file.") + +my_pdb = PosteriorDatabaseGithub() +pos = my_pdb.posterior_names() + +# %% +posterior = my_pdb.posterior("eight_schools-eight_schools_centered") + +# %% diff --git a/examples/stan/schools.py b/examples/stan/schools.py index fdcc9ab..4c8c570 100644 --- a/examples/stan/schools.py +++ b/examples/stan/schools.py @@ -6,7 +6,7 @@ from nak_torch.algorithms import MSIP, SVGD from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree -nest_asyncio.apply() +nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter import stan # noqa: E402 # %% From 8be95e6c7529ab93b329f894f1c26f41cede02ad Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 27 Apr 2026 19:13:17 -0400 Subject: [PATCH 40/60] Work on fixing logistic regression --- examples/logistic_regression/covertype.py | 14 +++++++----- src/nak_torch/tools/types.py | 18 +++++++++------ tests/test_logistic_regression.py | 27 +++++++++++++++++++++++ 3 files changed, 46 insertions(+), 13 deletions(-) create mode 100644 tests/test_logistic_regression.py diff --git a/examples/logistic_regression/covertype.py b/examples/logistic_regression/covertype.py index 924ac6b..6b5dcbf 100644 --- a/examples/logistic_regression/covertype.py +++ b/examples/logistic_regression/covertype.py @@ -51,7 +51,7 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): # %% data_path = DATA_PATH regression_model = LogisticRegressionModel( - data_path, None, hyperprior_b=0.01, train_proportion=0.8 + data_path, None, hyperprior_b=0.01, train_proportion=0.8, reduction="sum" ) log_dens = regression_model.to_log_dens(use_compiled=True) @@ -101,13 +101,14 @@ def spherical_quad( KERNEL_LENGTHSCALE = 0.1 GRADIENT_DECAY = 0.9 KERNEL_DIAG_INFL = 1e-5 +KERNEL_QUANTILE = 0.01 msip = MSIP( dim=STATE_DIM, n_particles=N_PARTICLES, kernel_diag_infl=KERNEL_DIAG_INFL, kernel_lengthscale=KERNEL_LENGTHSCALE, - kernel_lengthscale_quantile=0.01, + kernel_lengthscale_quantile=KERNEL_QUANTILE, ) target_msip_f = MSIPFredholm(GRADIENT_DECAY, grad_val_log_p) @@ -116,7 +117,7 @@ def spherical_quad( # %% BOUNDS = (-100.0, 100.0) N_STEPS = 6000 -LR_MSIP = 0.05 +LR_MSIP = 0.01 # trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( # target_msip_f, # msip, @@ -186,17 +187,18 @@ def accuracy(coeffs): ) # %% +LR_SVGD = 0.05 trajectories_pts_svgd = nak_torch.nak( target_svgd, svgd, n_steps=N_STEPS, - lr=LR_MSIP, + lr=LR_SVGD, init_particles=init_particles, get_target_args=iter(train_data_loader), - bounds=BOUNDS, + # bounds=BOUNDS, ) # %% -accuracy(trajectories_pts_svgd[-1].mean(dim=0)) +accuracy(trajectories_pts_svgd[-1].mean(0)) # %% diff --git a/src/nak_torch/tools/types.py b/src/nak_torch/tools/types.py index 6f5535c..74257ee 100644 --- a/src/nak_torch/tools/types.py +++ b/src/nak_torch/tools/types.py @@ -259,6 +259,7 @@ def log_dens( params: BatchPtType, data_labels: Optional[tuple[BatchPtType, LabelsType]] = None, ) -> BatchType: + total_N = self.train_data.shape[0] if data_labels is None: data, labels = self.train_data, self.train_labels else: @@ -270,20 +271,23 @@ def log_dens( raise ValueError( f"Got params.shape[1] = {params.shape[1]}, expected {self.dim}" ) - prior_diff = params.clone() - if self.prior_mean is not None: - prior_diff -= self.prior_mean coeffs = params[:, :-1] log_precision = params[:, -1] + prior_diff = coeffs.clone() + if self.prior_mean is not None: + prior_diff -= self.prior_mean precision = torch.exp(log_precision) - hyperprior_term = log_hyperprior(precision) - prior_term = prior_diff.square().sum(dim=-1).mul_(0.5 * precision).neg_() + # Correct for change-of-variables precision using chain rule: + # exp(log_precision) -> log(d_log_precision) = log(d_precision) + log_precision + hyperprior_term = log_hyperprior(precision) + log_precision + prior_term = prior_diff.square().sum(dim=-1).mul_(-0.5 * precision) # log-normalization constant of prior w.r.t. alpha = precision - prior_term += 0.5 * self.dim * log_precision + num_coeffs = coeffs.shape[1] + prior_term = prior_term.add_((0.5 * num_coeffs) * log_precision) logits = coeffs @ data.T likelihood = bernoulli_loglikelihood_logit_v(logits, labels) if self.use_mean_reduction: - likelihood /= labels.numel() + likelihood *= total_N / labels.numel() post = likelihood + prior_term + hyperprior_term return post if is_batch else post[0] diff --git a/tests/test_logistic_regression.py b/tests/test_logistic_regression.py new file mode 100644 index 0000000..316d709 --- /dev/null +++ b/tests/test_logistic_regression.py @@ -0,0 +1,27 @@ +import numpy as np +import numpy.matlib as nm + +# Copied under MIT License from https://github.com/DartML/Stein-Variational-Gradient-Descent/blob/8d8f94974e1b91384dc44991ed5ad9a26212f136/python/bayesian_logistic_regression.py + +def dlnprob(theta, data, labels, a0 = 1.0, b0 = 0.01, total_N = None): + if total_N is None: + total_N = data.shape[0] + Xs = data + Ys = labels + + w = theta[:, :-1] # logistic weights + alpha = np.exp(theta[:, -1]) # the last column is logalpha + d = w.shape[1] + + wt = np.multiply((alpha / 2), np.sum(w ** 2, axis=1)) + + coff = np.matmul(Xs, w.T) + y_hat = 1.0 / (1.0 + np.exp(-1 * coff)) + + dw_data = np.matmul(((nm.repmat(np.vstack(Ys), 1, theta.shape[0]) + 1) / 2.0 - y_hat).T, Xs) # Y \in {-1,1} + dw_prior = -np.multiply(nm.repmat(np.vstack(alpha), 1, d) , w) + dw = dw_data * float(total_N / Xs.shape[0]) + dw_prior # re-scale + + dalpha = d / 2.0 - wt + (a0 - 1) - b0 * alpha + 1 # the last term is the jacobian term + + return np.hstack([dw, np.vstack(dalpha)]) # % first order derivative \ No newline at end of file From 5a150c9ffd19b47c8c63964680fa1000e530a8dc Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 28 Apr 2026 10:04:15 -0400 Subject: [PATCH 41/60] Add test for the logistic regression model --- tests/test_logistic_regression.py | 69 +++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 21 deletions(-) diff --git a/tests/test_logistic_regression.py b/tests/test_logistic_regression.py index 316d709..38b1e96 100644 --- a/tests/test_logistic_regression.py +++ b/tests/test_logistic_regression.py @@ -1,27 +1,54 @@ +import pytest import numpy as np import numpy.matlib as nm +import torch +import nak_torch # Copied under MIT License from https://github.com/DartML/Stein-Variational-Gradient-Descent/blob/8d8f94974e1b91384dc44991ed5ad9a26212f136/python/bayesian_logistic_regression.py -def dlnprob(theta, data, labels, a0 = 1.0, b0 = 0.01, total_N = None): - if total_N is None: - total_N = data.shape[0] - Xs = data - Ys = labels - w = theta[:, :-1] # logistic weights - alpha = np.exp(theta[:, -1]) # the last column is logalpha - d = w.shape[1] - - wt = np.multiply((alpha / 2), np.sum(w ** 2, axis=1)) - - coff = np.matmul(Xs, w.T) - y_hat = 1.0 / (1.0 + np.exp(-1 * coff)) - - dw_data = np.matmul(((nm.repmat(np.vstack(Ys), 1, theta.shape[0]) + 1) / 2.0 - y_hat).T, Xs) # Y \in {-1,1} - dw_prior = -np.multiply(nm.repmat(np.vstack(alpha), 1, d) , w) - dw = dw_data * float(total_N / Xs.shape[0]) + dw_prior # re-scale - - dalpha = d / 2.0 - wt + (a0 - 1) - b0 * alpha + 1 # the last term is the jacobian term - - return np.hstack([dw, np.vstack(dalpha)]) # % first order derivative \ No newline at end of file +def grad_logistic_regression_posterior( + theta, data, labels, a0=1.0, b0=0.01, total_N=None +): + if total_N is None: + total_N = data.shape[0] + Xs = data + Ys = labels + + w = theta[:, :-1] # logistic weights + alpha = np.exp(theta[:, -1]) # the last column is logalpha + d = w.shape[1] + + wt = np.multiply((alpha / 2), np.sum(w**2, axis=1)) + coff = np.matmul(Xs, w.T) + y_hat = 1.0 / (1.0 + np.exp(-1 * coff)) + + dw_data = np.matmul( + ((nm.repmat(np.vstack(Ys), 1, theta.shape[0]) + 1) / 2.0 - y_hat).T, Xs + ) # Y \in {-1,1} + dw_prior = -np.multiply(nm.repmat(np.vstack(alpha), 1, d), w) + dw = dw_data * float(total_N / Xs.shape[0]) + dw_prior # re-scale + + dalpha = ( + d / 2.0 - wt + (a0 - 1) - b0 * alpha + 1 + ) # the last term is the jacobian term + + return np.hstack([dw, np.vstack(dalpha)]) # % first order derivative + +def test_logistic_regression(): + N_DATA, N_PARTICLE, DIM = 100, 5, 2 + data = np.random.randn(N_DATA, DIM) + labels = np.random.rand(N_DATA) > 0.5 + data_t, labels_t = torch.as_tensor(data), torch.as_tensor(labels) + model = nak_torch.LogisticRegressionModel(data_t, labels_t, hyperprior_b=0.01) + PROP_SUBSET = 0.2 + n_subset = int(N_DATA * PROP_SUBSET) + data_subset, labels_subset = model.train_data[:n_subset], model.train_labels[:n_subset] + theta = np.random.randn(N_PARTICLE, DIM + 1 + 1) + labels_subset_np = labels_subset.numpy() * 2 - 1 + ref_grad_log_dens = grad_logistic_regression_posterior(theta, data_subset.numpy(), labels_subset_np, total_N = N_DATA) + log_dens = model.to_log_dens(False) + grad_log_dens_fcn = torch.func.grad(lambda x,a: log_dens(x,a).sum()) + theta_t = torch.as_tensor(theta) + grad_log_dens = grad_log_dens_fcn(theta_t, (data_subset, labels_subset)) + assert grad_log_dens == pytest.approx(ref_grad_log_dens) From 6758e96bf83e280148da6a37886ea1db6fd0871f Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 28 Apr 2026 11:35:03 -0400 Subject: [PATCH 42/60] Fix grad-informed estimator --- src/nak_torch/algorithms/msip/estimators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nak_torch/algorithms/msip/estimators.py b/src/nak_torch/algorithms/msip/estimators.py index 2c3c446..ab692c3 100644 --- a/src/nak_torch/algorithms/msip/estimators.py +++ b/src/nak_torch/algorithms/msip/estimators.py @@ -115,7 +115,7 @@ def __call__(self, particles, kernel_lengthscale, target_args): v1_integrand = v1_integrand_gf + v1_integrand_gi sigma_sq_score_v0, log_v0 = vmap_recursive_weighted_average_alpha_v( - v1_integrand, quad_wts, log_v=log_dens_evals + v1_integrand, quad_wts, log_dens_evals ) return log_v0, sigma_sq_score_v0 From f4fba16ee1172d51d39cbb816b3af7e398a2a24e Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 28 Apr 2026 11:38:15 -0400 Subject: [PATCH 43/60] switch to usual inv --- src/nak_torch/algorithms/msip/msip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nak_torch/algorithms/msip/msip.py b/src/nak_torch/algorithms/msip/msip.py index 6b7f992..a76eb6d 100644 --- a/src/nak_torch/algorithms/msip/msip.py +++ b/src/nak_torch/algorithms/msip/msip.py @@ -20,7 +20,7 @@ def initialize(self, init_particles, target, target_args): def step(self, lr, particles, target, algorithm_args, target_args): kernel_lengthscale, kernel_matrix, estimator_output = astuple(algorithm_args) - kernel_matrix_inverse = torch.linalg.pinv(kernel_matrix, hermitian=True) + kernel_matrix_inverse = torch.linalg.inv(kernel_matrix) # Update the particles particles_diff = msip_map( From ebbe8a2116a92df763450cba8065a27e9a80b7c0 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 28 Apr 2026 11:38:28 -0400 Subject: [PATCH 44/60] Satisfactory logreg --- examples/logistic_regression/covertype.py | 125 ++++++++++++---------- 1 file changed, 70 insertions(+), 55 deletions(-) diff --git a/examples/logistic_regression/covertype.py b/examples/logistic_regression/covertype.py index 6b5dcbf..3b26f73 100644 --- a/examples/logistic_regression/covertype.py +++ b/examples/logistic_regression/covertype.py @@ -48,10 +48,19 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): if not os.path.isfile(DATA_PATH): download_file() +def accuracy(coeffs): + data, labels = regression_model.test_data, regression_model.test_labels + prob = torch.sigmoid(coeffs[:-1] @ data.T) + pred_labels = prob > 0.5 + return torch.mean((pred_labels == labels).to(torch.float64)) + +accuracy_v = torch.vmap(accuracy) + # %% data_path = DATA_PATH regression_model = LogisticRegressionModel( - data_path, None, hyperprior_b=0.01, train_proportion=0.8, reduction="sum" + data_path, None, hyperprior_b=0.01, + train_proportion=0.8, reduction="mean" ) log_dens = regression_model.to_log_dens(use_compiled=True) @@ -83,7 +92,7 @@ def download_file(data_url: str = DATA_URL, data_path: str = DATA_PATH): @torch.compile(dynamic=False) -def mc_quad_rule(batch_size: int, N_quad: int = 500, dim: int = 56): +def mc_quad_rule(batch_size: int, N_quad: int = 2, dim: int = 56): pts = torch.randn((batch_size, N_quad, dim), dtype=torch.get_default_dtype()) wts = torch.ones((batch_size, N_quad), dtype=torch.get_default_dtype()).div_(N_quad) return pts, wts @@ -98,14 +107,15 @@ def spherical_quad( # %% -KERNEL_LENGTHSCALE = 0.1 -GRADIENT_DECAY = 0.9 -KERNEL_DIAG_INFL = 1e-5 -KERNEL_QUANTILE = 0.01 +N_PARTICLES_MSIP = 50 +KERNEL_LENGTHSCALE = 0.01 +GRADIENT_DECAY = 1.0 +KERNEL_DIAG_INFL = 1e-6 +KERNEL_QUANTILE = None msip = MSIP( dim=STATE_DIM, - n_particles=N_PARTICLES, + n_particles=N_PARTICLES_MSIP, kernel_diag_infl=KERNEL_DIAG_INFL, kernel_lengthscale=KERNEL_LENGTHSCALE, kernel_lengthscale_quantile=KERNEL_QUANTILE, @@ -115,18 +125,20 @@ def spherical_quad( target_msip_gi = MSIPQuadGradientInformed(grad_val_log_p, mc_quad_rule, GRADIENT_DECAY) # %% -BOUNDS = (-100.0, 100.0) +BOUNDS = (-100., 100.) N_STEPS = 6000 -LR_MSIP = 0.01 -# trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( -# target_msip_f, -# msip, -# n_steps=N_STEPS, -# lr=LR_MSIP, -# init_particles=init_particles, -# get_target_args=iter(train_data_loader), -# bounds=BOUNDS, -# ) +LR_MSIP = 0.005 +trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( + target_msip_gi, + msip, + n_steps=N_STEPS, + lr=LR_MSIP, + init_particles=init_particles[:N_PARTICLES_MSIP], + get_target_args=iter(train_data_loader), + bounds=BOUNDS, + keep_all=True +) +trajectories_pts_msip_fr[-1] # %% msip_end = trajectories_pts_msip_fr[-1] @@ -137,43 +149,22 @@ def spherical_quad( lower_tri_dist = dist_end[*lower_tri_idx] plt.hist(lower_tri_dist) -# %% -bce_logit_v = torch.vmap( - torch.nn.functional.binary_cross_entropy_with_logits, in_dims=(0, None) -) - - -# @torch.compile -def bce_logit_t(traj_t): - logits_t = traj_t[:, :-1] @ regression_model.test_data.T - return bce_logit_v(logits_t, regression_model.test_labels) - - -bce_logit_traj = torch.vmap(bce_logit_t) -bse_traj_list = [] -for j in tqdm(range(trajectories_pts_msip_fr.shape[0])): - bse_traj_list.append(bce_logit_t(trajectories_pts_msip_fr[j])) -bce_traj = torch.stack(bse_traj_list) -# logits_t = trajectories_msip[:,:,:-1].reshape(-1, trajectories_msip.shape[-1] - 1) @ regression_model.data -# bce_traj = bce_logit_v(logits_t, regression_model.labels).reshape(*trajectories_msip.shape[:2], -1) -# print("BCE t=0: {}, BCE t=T: {}".format(bce_0.mean(), bce_T.mean())) - -# %% -fig, ax = plt.subplots() -for particle_idx in range(N_PARTICLES): - ax.loglog(bce_traj[:, particle_idx], alpha=0.4) -plt.show() - +# # %% +# bce_logit_v = torch.vmap( +# torch.nn.functional.binary_cross_entropy_with_logits, in_dims=(0, None) +# ) +# bse_traj_list = [] +# for j in tqdm(range(trajectories_pts_msip_fr.shape[0])): +# bse_traj_list.append(bce_logit_v(trajectories_pts_msip_fr[j])) +# bce_traj = torch.stack(bse_traj_list) +# # %% +# fig, ax = plt.subplots() +# for particle_idx in range(N_PARTICLES): +# ax.loglog(bce_traj[:, particle_idx], alpha=0.4) +# plt.show() # %% -def accuracy(coeffs): - data, labels = regression_model.test_data, regression_model.test_labels - prob = torch.sigmoid(coeffs[:-1] @ data.T) - pred_labels = prob > 0.5 - return torch.mean((pred_labels == labels).to(torch.float64)) - -accuracy_v = torch.vmap(accuracy) -# accuracy_v(trajectories_pts_msip_fr[-1]) +msip_accuracies = accuracy_v(trajectories_pts_msip_fr.mean(dim=1)) # %% svgd = SVGD( @@ -188,10 +179,11 @@ def accuracy(coeffs): # %% LR_SVGD = 0.05 +N_STEPS_SVGD = 6000 trajectories_pts_svgd = nak_torch.nak( target_svgd, svgd, - n_steps=N_STEPS, + n_steps=N_STEPS_SVGD, lr=LR_SVGD, init_particles=init_particles, get_target_args=iter(train_data_loader), @@ -199,6 +191,29 @@ def accuracy(coeffs): ) # %% -accuracy(trajectories_pts_svgd[-1].mean(0)) +svgd_end = trajectories_pts_svgd[-1] +dist_end = torch.sqrt( + torch.sum(torch.square_(svgd_end[None, :] - svgd_end[:, None]), -1) +) +lower_tri_idx = torch.tril_indices(*dist_end.shape, -1) +lower_tri_dist = dist_end[*lower_tri_idx] +plt.hist(lower_tri_dist) + +# %% +svgd_accuracies = accuracy_v(torch.mean(trajectories_pts_svgd, dim=1)) # %% +def moving_avg(v,N): + if N == 0: + return v + return torch.column_stack([v[j:-(N - j)] for j in range(N)]).mean(1) + +# %% +smoothed_svgd_accuracies = moving_avg(svgd_accuracies, 0) +smoothed_msip_accuracies = moving_avg(msip_accuracies, 0) +fig, ax = plt.subplots() +ax.plot(torch.arange(len(smoothed_svgd_accuracies)), smoothed_svgd_accuracies, label="SVGD") +ax.plot(torch.arange(len(smoothed_msip_accuracies)), smoothed_msip_accuracies, label="MSIP") +ax.legend() +plt.show() +# %% From c4837166e3c9164d4e850d5819b7226ba081b8ce Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 28 Apr 2026 11:56:02 -0400 Subject: [PATCH 45/60] Remove old comment --- src/nak_torch/algorithms/gradfree_aldi.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/nak_torch/algorithms/gradfree_aldi.py b/src/nak_torch/algorithms/gradfree_aldi.py index b83c497..7422c66 100644 --- a/src/nak_torch/algorithms/gradfree_aldi.py +++ b/src/nak_torch/algorithms/gradfree_aldi.py @@ -15,19 +15,6 @@ from nak_torch.tools.util import sym_sqrtm -# def build_gradfree_aldi_step( -# model: GaussianModel, rng: torch.Generator, compile_step: bool -# ): -# prior_mean = model.prior_mean -# likelihood_precision = model.likelihood_precision -# prior_precision = model.prior_precision -# true_obs = model.true_obs -# if isinstance(true_obs, Tensor): -# true_obs.reshape(1, -1) - -# sqrt_2 = torch.sqrt(torch.tensor(2, dtype=true_obs.dtype, device=true_obs.device)) - - def gradfree_aldi_step( particles: BatchPtType, forecast_observations: Float[Tensor, "batch obs"], From 6d7839fbfda42f8a030ff0aeb3adb97ffb6992bd Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 28 Apr 2026 11:56:10 -0400 Subject: [PATCH 46/60] Update readme to include stan info --- README.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/README.md b/README.md index 6bc1f5a..7054c38 100644 --- a/README.md +++ b/README.md @@ -1 +1,34 @@ # Kernel-based quantization algorithms + +## Installation +We recommend installing with `uv`. Currently, the way to install this locally would be +```bash +$ uv pip install -e git+https://github.com/Nodes-and-Kernels/nak_torch +``` + +If you plan on using the examples, make sure that `[examples]` option is installed. Also, make sure that there is no other installation of `pystan`, which is a dependency---we use a fork of the original package to reduce latency for our algorithms when using a stan posterior. + +## List of Algorithms +### MSIP +We largely focus on _mean-shift interacting particle_ (MSIP) algorithms, and we are working to implement several of these. Currently, we have: + +- MSIP +- MSIPGS + +For these algorithms, we have multiple estimators---each of these produces a certain set of dynamics. In particular, we have: + +- MSIPFredholm +- MSIPGradientFree +- MSIPGradientInformed +- MSIPGMMGaussianKernel + +### Other algorithms +We also include several other typical interacting-particle sampling algorithms. + +- Consensus-based sampler (`CBS`) +- Deep ensembles (`DeepEnsembles`) +- Ensemble Kalman Sampler (`EKS`) +- Gradient-informed affine-invariant Langevin dynamics (`GradALDI`) +- Gradient-free ALDI (`GradFreeALDI`) +- Stein variational gradient descent (`SVGD`) + From 792229c53dfc19262a8f5a09265cda9222d3643d Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Wed, 29 Apr 2026 17:10:23 -0400 Subject: [PATCH 47/60] Finish BNN example prelim --- examples/bnn/bnn_impl.py | 310 ++++++++++++++++++-------------- examples/bnn/msip_svgd_new.py | 217 ++++++++++++++++++++++ src/nak_torch/__init__.py | 3 +- src/nak_torch/tools/__init__.py | 2 + src/nak_torch/tools/util.py | 13 +- 5 files changed, 412 insertions(+), 133 deletions(-) create mode 100644 examples/bnn/msip_svgd_new.py diff --git a/examples/bnn/bnn_impl.py b/examples/bnn/bnn_impl.py index 00ff34d..e4d139d 100644 --- a/examples/bnn/bnn_impl.py +++ b/examples/bnn/bnn_impl.py @@ -1,18 +1,23 @@ +from typing import Optional + import torch +from torch import Tensor import torch.nn.functional as F import numpy as np import matplotlib.pyplot as plt +from jaxtyping import Float from torch import nn from torch.nn.utils import vector_to_parameters from torch.func import functional_call from collections import OrderedDict +from tqdm import tqdm + # ══════════════════════════════════════════════════════════════════════════════ # BNN model # ══════════════════════════════════════════════════════════════════════════════ - class bnn(nn.Module): """ Standard MLP for binary classification. @@ -24,9 +29,14 @@ class bnn(nn.Module): hidden_dim : int width of each hidden layer n_layers : int number of hidden layers """ + + hidden_dim: int + n_layers: int + def __init__(self, d_in: int, hidden_dim: int, n_layers: int = 1): super().__init__() layers = [] + self.hidden_dim, self.n_layers = hidden_dim, n_layers in_dim = d_in for _ in range(n_layers): layers.append(nn.Linear(in_dim, hidden_dim)) @@ -44,6 +54,7 @@ def forward(self, xb: torch.Tensor) -> torch.Tensor: # Dataset loading with train-test splitting # ══════════════════════════════════════════════════════════════════════════════ + def load_dataset(dataset_name, train_ratio=0.8, seed=0): """ Loads datasets/{dataset_name}.npz (keys X: (d,N), Y: (1,N)) @@ -52,13 +63,13 @@ def load_dataset(dataset_name, train_ratio=0.8, seed=0): Returns X_train, Y_train, X_test, Y_test as torch.double tensors with shapes (N, d) and (N,) """ - data = np.load(f"datasets/{dataset_name}.npz") + data = np.load(f"datasets/{dataset_name}.npz") X = torch.from_numpy(data["X"].T).double() Y = torch.from_numpy(data["Y"].T).double().squeeze() N_total = X.shape[0] - rng = np.random.RandomState(seed) - idx = rng.permutation(N_total) + rng = np.random.RandomState(seed) + idx = rng.permutation(N_total) n_train = int(N_total * train_ratio) i_train, i_test = idx[:n_train], idx[n_train:] @@ -66,93 +77,139 @@ def load_dataset(dataset_name, train_ratio=0.8, seed=0): return X[i_train], Y[i_train], X[i_test], Y[i_test] +def theta_to_param_dict(theta_1d, param_info): + out, i = OrderedDict(), 0 + for name, shape, numel in param_info: + out[name] = theta_1d[i : i + numel].view(shape) + i += numel + return out + + # ══════════════════════════════════════════════════════════════════════════════ # Objective functions # ══════════════════════════════════════════════════════════════════════════════ +def bnn_evaluator(theta_1d, model, buffer_dict, param_info, data): + param_dict = theta_to_param_dict(theta_1d, param_info) + pred = functional_call(model, (param_dict, buffer_dict), (data,)) + return pred + + +bnn_evaluator_v = torch.vmap(bnn_evaluator, in_dims=(0, None, None, None, None)) + + +def soft_margin_loss(x, y): + return torch.mean(F.softplus(-x * y)) -def make_objective(X, Y, model_class='bnn', - hidden_dim=10, n_layers=1, - beta=1.0, lambda2=0.01): + +soft_margin_v = torch.vmap(soft_margin_loss, in_dims=(0, None)) + + +class BNNClassifierPosterior: """ - Returns a log-posterior callable theta -> scalar + A log-posterior for BNN for classification Parameters ---------- - X, Y : torch.double tensors (N, d) and (N,) + data, labels : torch.double tensors (N, d) and (N,) model_class : 'bnn' for now, but we can think about something else hidden_dim : int hidden width n_layers : int depth beta : float temperature - lambda2 : float prior weight + weight_decay : float prior weight """ - d = X.shape[1] - - if model_class == 'bnn': - model = bnn(d_in=d, hidden_dim=hidden_dim, n_layers=n_layers).double() - else: - raise ValueError(f"Unknown model_class: {model_class!r}") - - param_info = [(name, p.shape, p.numel()) - for name, p in model.named_parameters()] - buffer_dict = OrderedDict(model.named_buffers()) - total_numel = sum(nu for _, _, nu in param_info) - - print(f" Model : {model_class} | " - f"hidden_dim = {hidden_dim} n_layers = {n_layers} | " - f"dim(theta) = {total_numel}") - - def theta_to_param_dict(theta_1d): - out, i = OrderedDict(), 0 - for name, shape, numel in param_info: - out[name] = theta_1d[i:i + numel].view(shape) - i += numel - return out - - def loss_for_single_theta(theta_1d): - param_dict = theta_to_param_dict(theta_1d) - pred = functional_call(model, (param_dict, buffer_dict), (X,)) - pred = pred.squeeze() - data_loss = soft_margin_loss(pred, Y) - reg = lambda2 * (theta_1d ** 2).sum() - return -(data_loss + reg) / beta - - def objective_function(theta): - if theta.ndim == 1: - return loss_for_single_theta(theta) - elif theta.ndim == 2: - return torch.stack([loss_for_single_theta(theta[i]) - for i in range(theta.shape[0])]) - else: - raise ValueError(f"theta must be 1D or 2D, got {theta.shape}") - - - objective_function.total_numel = total_numel - objective_function.model_class = model_class - objective_function.model = model - objective_function.param_info = param_info - objective_function.buffer_dict = buffer_dict - - return objective_function + model: bnn + param_info: list[tuple[str, torch.Size, int]] + + def __init__( + self, + data: Float[Tensor, "N_samples dim"], + labels: Float[Tensor, " N_samples"], + hidden_dim: int = 10, + n_layers: int = 1, + beta: float = 1.0, + weight_decay: float = 0.01, + ): + self.data = data + self.labels = labels + self.beta = beta + self.lambda2 = weight_decay + d = data.shape[1] + self.model = bnn(d_in=d, hidden_dim=hidden_dim, n_layers=n_layers).double() + self.param_info = [ + (name, p.shape, p.numel()) for name, p in self.model.named_parameters() + ] + self.buffer_dict = OrderedDict(self.model.named_buffers()) + self.dimension = sum(nu for _, _, nu in self.param_info) + + def __repr__(self): + return ( + f" Model : BNNClassifierPosterior | " + f"hidden_dim = {self.model.hidden_dim} n_layers = {self.model.n_layers} | " + f"dim(theta) = {self.dimension}" + ) + + def __call__( + self, theta: Tensor, data_labels: Optional[tuple[Tensor, Tensor]] = None + ): + single_theta = theta.ndim == 1 + if single_theta: + theta = theta.unsqueeze(0) + elif theta.ndim > 2: + raise ValueError(f"theta.ndim must be 1 or 2. Got {theta.ndim}") + + if data_labels is None: + data, labels = self.data, self.labels + else: + data, labels = data_labels + + pred: Tensor = bnn_evaluator_v( + theta, self.model, self.buffer_dict, self.param_info, data + ) + data_loss = soft_margin_v(pred, labels) * self.data.shape[0] + reg = self.lambda2 * (theta**2).sum(-1) + # Post = -(soft_margin(theta) + lambda_2 ||theta||^2) + return data_loss.add_(reg).div_(-self.beta).squeeze_() + + def get_data_loader( + self, + batch_size: int = 1, + shuffle: bool = False, + num_workers: int = 0, + *data_loader_args, + **data_loader_kwargs, + ): + import torch.utils.data as torch_data + + data: torch_data.TensorDataset + data = torch_data.TensorDataset(self.data, self.labels) + return torch_data.DataLoader( + data, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + *data_loader_args, + **data_loader_kwargs, + ) # ══════════════════════════════════════════════════════════════════════════════ # Tools for handling grids # ══════════════════════════════════════════════════════════════════════════════ + def _make_grid(bounds, M_res): - a, b = bounds + a, b = bounds xs, ys = np.linspace(a, b, M_res), np.linspace(a, b, M_res) Xg, Yg = np.meshgrid(xs, ys) - grid = torch.tensor(np.stack([Xg.ravel(), Yg.ravel()], 1), - dtype=torch.double) + grid = torch.tensor(np.stack([Xg.ravel(), Yg.ravel()], 1), dtype=torch.double) return Xg, Yg, grid def _get_ensemble_probs(traj, obj_fn, grid): particles = traj[-1] - model = obj_fn.model - probs = [] + model = obj_fn.model + probs = [] for i in range(len(particles)): vector_to_parameters(particles[i].double(), model.parameters()) with torch.no_grad(): @@ -163,35 +220,44 @@ def _get_ensemble_probs(traj, obj_fn, grid): def _scatter(ax, X, Y): x = X.numpy() y = Y.numpy() - ax.scatter(x[y > 0, 0], x[y > 0, 1], c='gold', s=25, zorder=5) - ax.scatter(x[y < 0, 0], x[y < 0, 1], c='tomato', s=25, zorder=5) - + ax.scatter(x[y > 0, 0], x[y > 0, 1], c="gold", s=25, zorder=5) + ax.scatter(x[y < 0, 0], x[y < 0, 1], c="tomato", s=25, zorder=5) # ══════════════════════════════════════════════════════════════════════════════ # Visualization tools # ══════════════════════════════════════════════════════════════════════════════ -def plot_boundaries(trajectories_dict, objective_fns_dict, - X_tr, Y_tr, bounds=[-0.1, 1.1], M_res=60): + +def plot_boundaries( + trajectories_dict, objective_fns_dict, X_tr, Y_tr, bounds=[-0.1, 1.1], M_res=60 +): Xg, Yg, grid = _make_grid(bounds, M_res) M = len(trajectories_dict) - fig, axes = plt.subplots(1, M, figsize=(6*M, 5)) - if M == 1: axes = [axes] + fig, axes = plt.subplots(1, M, figsize=(6 * M, 5)) + if M == 1: + axes = [axes] for ax, (name, traj) in zip(axes, trajectories_dict.items()): probs = _get_ensemble_probs(traj, objective_fns_dict[name], grid) for pi in probs: - ax.contour(Xg, Yg, np.sign(pi - 0.5).reshape(M_res, M_res), - levels=[0], colors=['steelblue'], linewidths=0.6, alpha=0.3) + ax.contour( + Xg, + Yg, + np.sign(pi - 0.5).reshape(M_res, M_res), + levels=[0], + colors=["steelblue"], + linewidths=0.6, + alpha=0.3, + ) p_bar = probs.mean(0).reshape(M_res, M_res) - ax.contour(Xg, Yg, p_bar, levels=[0.5], colors=['black'], linewidths=2.0) + ax.contour(Xg, Yg, p_bar, levels=[0.5], colors=["black"], linewidths=2.0) _scatter(ax, X_tr, Y_tr) - ax.set_title(name, fontsize=13, fontweight='bold') + ax.set_title(name, fontsize=13, fontweight="bold") ax.set_xlim(*bounds) ax.set_ylim(*bounds) - ax.set_xlabel('x1') - ax.set_ylabel('x2') + ax.set_xlabel("x1") + ax.set_ylabel("x2") plt.suptitle("Decision boundary (final iteration)", fontsize=14) plt.savefig("de_vs_msip_boundaries.pdf") @@ -199,25 +265,32 @@ def plot_boundaries(trajectories_dict, objective_fns_dict, plt.show() - -def plot_mean_prediction(trajectories_dict, objective_fns_dict, - X_tr, Y_tr, bounds=[-0.1, 1.1], M_res=60): +def plot_mean_prediction( + trajectories_dict, objective_fns_dict, X_tr, Y_tr, bounds=[-0.1, 1.1], M_res=60 +): Xg, Yg, grid = _make_grid(bounds, M_res) M = len(trajectories_dict) - fig, axes = plt.subplots(1, M, figsize=(6*M, 5)) - if M == 1: axes = [axes] + fig, axes = plt.subplots(1, M, figsize=(6 * M, 5)) + if M == 1: + axes = [axes] for ax, (name, traj) in zip(axes, trajectories_dict.items()): probs = _get_ensemble_probs(traj, objective_fns_dict[name], grid) p_bar = probs.mean(0).reshape(M_res, M_res) - im = ax.imshow(p_bar, extent=[*bounds, *bounds], origin='lower', - cmap='RdBu_r', vmin=0, vmax=1) + im = ax.imshow( + p_bar, + extent=[*bounds, *bounds], + origin="lower", + cmap="RdBu_r", + vmin=0, + vmax=1, + ) plt.colorbar(im, ax=ax) - ax.contour(Xg, Yg, p_bar, levels=[0.5], colors=['white'], linewidths=2.0) + ax.contour(Xg, Yg, p_bar, levels=[0.5], colors=["white"], linewidths=2.0) _scatter(ax, X_tr, Y_tr) - ax.set_title(f"{name}", fontsize=12, - fontweight='bold') - ax.set_xlabel('x1'); ax.set_ylabel('x2') + ax.set_title(f"{name}", fontsize=12, fontweight="bold") + ax.set_xlabel("x1") + ax.set_ylabel("x2") plt.suptitle("Mean predictive probability", fontsize=14) plt.savefig("de_vs_msip_mean_predictive_p.pdf") @@ -225,53 +298,32 @@ def plot_mean_prediction(trajectories_dict, objective_fns_dict, plt.show() -def plot_diversity_curve(trajectories_dict, subsample=10): - fig, ax = plt.subplots(figsize=(8, 4)) - for name, traj in trajectories_dict.items(): - T, N, D = traj.shape - steps, divs = [], [] - for t in range(0, T, subsample): - P = traj[t] - sq = ((P.unsqueeze(0) - P.unsqueeze(1)) ** 2).sum(-1) - idx = torch.triu_indices(N, N, offset=1) - divs.append(sq[idx[0], idx[1]].sqrt().min().item()) - steps.append(t) - ax.plot(steps, divs, lw=2, label=name) - ax.set_xlabel("Iteration") - ax.set_ylabel("Smallest pairwise distance") - ax.set_title("Particle diversity over training") - plt.savefig("de_vs_msip_div.pdf") - ax.legend() - plt.tight_layout() - plt.show() - - # ══════════════════════════════════════════════════════════════════════════════ # Evaluations on the dataset # ══════════════════════════════════════════════════════════════════════════════ def evaluate(trajectories_dict, objective_fns_dict, X, Y, split_name="test"): - y_true = (Y.numpy() > 0) - grid = X.double() + y_true = Y.numpy() > 0 + grid = X.double() - print(f"\n{'─'*75}") + print(f"\n{'─' * 75}") print(f" Evaluation on {split_name} set ({len(X)} points)") - print(f"{'─'*75}") + print(f"{'─' * 75}") print(f" {'Method':<12} {'Accuracy':>6} {'DAMV':>10}") - print(f"{'─'*75}") + print(f"{'─' * 75}") results = {} for name, traj in trajectories_dict.items(): - probs = _get_ensemble_probs(traj, objective_fns_dict[name], grid) - p_bar = probs.mean(0) - acc = ((p_bar > 0.5) == y_true).mean() - damv = traj[-1].var(dim=0).mean().item() + probs = _get_ensemble_probs(traj, objective_fns_dict[name], grid) + p_bar = probs.mean(0) + acc = ((p_bar > 0.5) == y_true).mean() + damv = traj[-1].var(dim=0).mean().item() print(f" {name:<12} {acc:>6.3f} {damv:>10.3f}") results[name] = dict(p_bar=p_bar, accuracy=acc, damv=damv) - print(f"{'─'*75}\n") + print(f"{'─' * 75}\n") return results @@ -281,8 +333,8 @@ def plot_diversity_curve(trajectories_dict): T, N, D = traj.shape steps, divs = [], [] for t in range(0, T, 1): - P = traj[t] - sq = ((P.unsqueeze(0) - P.unsqueeze(1)) ** 2).sum(-1) + P = traj[t] + sq = ((P.unsqueeze(0) - P.unsqueeze(1)) ** 2).sum(-1) idx = torch.triu_indices(N, N, offset=1) divs.append(sq[idx[0], idx[1]].sqrt().min().item()) steps.append(t) @@ -296,24 +348,20 @@ def plot_diversity_curve(trajectories_dict): plt.show() - def eval_function_trajectories(obj_fn, trajectories, algo_name): T, M, d = trajectories.shape - eval_tensor = torch.zeros(T, M) + eval_tensor = torch.zeros(T, M, dtype=trajectories.dtype) + prog = tqdm(total=T*M) for t in range(T): for m in range(M): - eval_tensor[t, m] = -obj_fn(trajectories[t, m, :]) - - + eval_tensor[t, m] = -obj_fn(trajectories[t, m, :], None) + prog.update(1) + prog.close() plt.plot(eval_tensor.detach().numpy().min(1), label=algo_name) plt.xlabel("Iteration") plt.ylabel("Objective function") - plt.title("Evaluation of the best particle for "+algo_name) - plt.savefig("best_particle_"+algo_name+".pdf") + plt.title("Evaluation of the best particle for " + algo_name) + plt.savefig("best_particle_" + algo_name + ".pdf") plt.legend(fontsize=16) plt.tight_layout() plt.show() - - -def soft_margin_loss(x,y): - return torch.mean(F.softplus(-x * y)) \ No newline at end of file diff --git a/examples/bnn/msip_svgd_new.py b/examples/bnn/msip_svgd_new.py new file mode 100644 index 0000000..fc5596a --- /dev/null +++ b/examples/bnn/msip_svgd_new.py @@ -0,0 +1,217 @@ +# %% +import torch +import nak_torch +from nak_torch.algorithms import MSIP, SVGD +from nak_torch.algorithms.msip import MSIPFredholm +import bnn_impl as bnn +from nak_torch.tools.types import BatchGradLogDensityEvaluator +from nak_torch.tools.util import infinite_iter + +# %% +# Data loading +DATASET = "two_bananas" +PROP_TRAIN = 0.8 +data_train, labels_train, data_test, labels_test = bnn.load_dataset( + DATASET, train_ratio=PROP_TRAIN, seed=0 +) + +# %% +N_LAYERS = 1 +HIDDEN_DIM = 50 +BETA = 1.0 +LAMBDA = 0.2 +BATCH_SIZE = 64 + +bnn_posterior = bnn.BNNClassifierPosterior( + data=data_train, + labels=labels_train, + beta=BETA, + hidden_dim=HIDDEN_DIM, + weight_decay=LAMBDA**2, +) + +data_loader = bnn_posterior.get_data_loader(BATCH_SIZE, shuffle = True) + +# %% +N_PARTICLES = 100 +init_particles = torch.randn(N_PARTICLES, bnn_posterior.dimension).double() / LAMBDA + +# %% +b_post_ev = bnn_posterior(init_particles, None) + +# %% +def tmp_post(x, a = None): + ret = bnn_posterior(x, a) + return ret.sum(), ret +log_dens_grad_val = torch.func.grad(tmp_post, has_aux=True) + +# %% +GRADIENT_DECAY = 1.0 +KERNEL_LENGTHSCALE = 0.1 +target_msip_fr = MSIPFredholm(GRADIENT_DECAY, log_dens_grad_val) + +msip = MSIP( + dim = bnn_posterior.dimension, + n_particles= N_PARTICLES, + kernel_lengthscale = KERNEL_LENGTHSCALE, +) + + +# %% +BOUNDS = (-1000., 1000.) +N_STEPS = 10000 +LR_MSIP = 1e-3 +trajectories_pts_msip_fr, trajectories_wts_msip_fr = nak_torch.nak( + target_msip_fr, + msip, + n_steps=N_STEPS, + lr=LR_MSIP, + init_particles=init_particles, + get_target_args=infinite_iter(data_loader), + bounds=BOUNDS, + keep_all=True +) + +# %% +bnn.eval_function_trajectories(bnn_posterior, trajectories_pts_msip_fr.double(), "MSIP") + +# %% +target_svgd = BatchGradLogDensityEvaluator( + bnn_posterior, is_grad=False, is_batched=True +) + +svgd = SVGD( + dim = bnn_posterior.dimension, + n_particles= N_PARTICLES, + kernel_lengthscale_quantile= 0.5 +) + + +# %% +trajectories_pts_svgd = nak_torch.nak( + target_svgd, + svgd, + n_steps=N_STEPS, + lr=LR_MSIP, + init_particles=init_particles, + get_target_args=infinite_iter(data_loader), + bounds=BOUNDS, + keep_all=True +) + +# %% +bnn.eval_function_trajectories(bnn_posterior, trajectories_pts_svgd.double(), "SVGD") # type: ignore + +# %% +trajectories_dict = {"MSIP": trajectories_pts_msip_fr, "SVGD": trajectories_pts_svgd} +objective_fns_dict = {"MSIP": bnn_posterior, "SVGD": bnn_posterior} + +# Visualization +bnn.plot_boundaries(trajectories_dict, objective_fns_dict, bnn_posterior.data, bnn_posterior.labels) + +bnn.evaluate(trajectories_dict, objective_fns_dict, data_test, labels_test, "test") + +# %% +# ══════════════════════════════════════════════════════════════════════════════ +# Main +# ══════════════════════════════════════════════════════════════════════════════ + +if __name__ == "__main__": + # Config + DATASET = "two_bananas" + MODEL_CLASS = "bnn" + HIDDEN_DIM = 50 + N_LAYERS = 1 + N_TRAIN = 0.8 # train-test split ratio + N_PARTICLES = 250 + N_STEPS = 1000 + BETA = 1.0 # beta in x-> exp(-beta^{-1}V(x)) + LAMBDA2 = 0.00005 # lambda in prior;0005 + # lambda close to 0 means weak prior + LR_SVGD = 100e-2 + LR_MSIP = 100e-2 + SIGMA = 0.5 + + # Data loading + X_train, Y_train, X_test, Y_test = load_dataset( + DATASET, train_ratio=N_TRAIN, seed=0 + ) + + # Objective loading + obj_msip = make_objective( + X_train, + Y_train, + model_class=MODEL_CLASS, + hidden_dim=HIDDEN_DIM, + n_layers=N_LAYERS, + beta=BETA, + lambda2=LAMBDA2, + ) + obj_svgd = make_objective( + X_train, + Y_train, + model_class=MODEL_CLASS, + hidden_dim=HIDDEN_DIM, + n_layers=N_LAYERS, + beta=BETA, + lambda2=LAMBDA2, + ) + + dimension = obj_msip.total_numel + + # Shared inititializtion + init_particles = 5 * torch.randn(N_PARTICLES, dimension).double() + + # Run MSIP + post_log_dens_grad_val = torch.func.grad_and_value(obj_msip) + msip_fredholm = MSIPFredholm(1.0, post_log_dens_grad_val) + + trajectories_msip, wts_msip = msip( + obj_msip, + N_PARTICLES, + N_STEPS, + dim=dimension, + lr=LR_MSIP, + init_particles=init_particles, + kernel_length_scale=SIGMA, + is_log_density_batched=True, + kernel_diag_infl=1e-8, + bounds=(-100.0, 100.0), + gradient_decay=1.0, + keep_all=True, + compile_step=False, + verbose=True, + ) + + # Run SVGD + + trajectories_svgd = svgd( + obj_svgd, + N_PARTICLES, + N_STEPS, + dimension, + LR_SVGD, + seed=None, + device=None, + init_particles=init_particles, + kernel_length_scale=SIGMA, + keep_all=True, + is_log_density_batched=True, + verbose=True, + ) + + trajectories_dict = {"MSIP": trajectories_msip, "SVGD": trajectories_svgd} + objective_fns_dict = {"MSIP": obj_msip, "SVGD": obj_svgd} + + # Optimization diagnostics + eval_function_trajectories(obj_msip, trajectories_msip, "MSIP") + eval_function_trajectories(obj_svgd, trajectories_svgd, "SVGD") + + # Visualization + plot_boundaries(trajectories_dict, objective_fns_dict, X_train, Y_train) + plot_mean_prediction(trajectories_dict, objective_fns_dict, X_train, Y_train) + plot_diversity_curve(trajectories_dict) + + # Evaluation on a dataset + evaluate(trajectories_dict, objective_fns_dict, X_train, Y_train, "train") + evaluate(trajectories_dict, objective_fns_dict, X_test, Y_test, "test") diff --git a/src/nak_torch/__init__.py b/src/nak_torch/__init__.py index 42ac01c..e8405aa 100644 --- a/src/nak_torch/__init__.py +++ b/src/nak_torch/__init__.py @@ -1,6 +1,6 @@ from . import algorithms, tools from .algorithms import nak -from .tools import GaussianModel, metrics, LogisticRegressionModel +from .tools import GaussianModel, metrics, LogisticRegressionModel, infinite_iter __all__ = [ "algorithms", @@ -9,4 +9,5 @@ "LogisticRegressionModel", "metrics", "nak", + "infinite_iter", ] diff --git a/src/nak_torch/tools/__init__.py b/src/nak_torch/tools/__init__.py index 5b7b872..5c0f2b6 100644 --- a/src/nak_torch/tools/__init__.py +++ b/src/nak_torch/tools/__init__.py @@ -7,6 +7,7 @@ from .average import recursive_weighted_average_alpha_v from .torchify import differentiable_density_factory from .types import GaussianModel, LogisticRegressionModel +from .util import infinite_iter __all__ = [ "kernel", @@ -18,6 +19,7 @@ "quadrature", "adaptive_step", "metrics", + "infinite_iter", ] if importlib.util.find_spec("pyro") is not None: from . import pyro_tools # noqa: F401 diff --git a/src/nak_torch/tools/util.py b/src/nak_torch/tools/util.py index c29ec71..ef5aeba 100644 --- a/src/nak_torch/tools/util.py +++ b/src/nak_torch/tools/util.py @@ -1,11 +1,13 @@ import torch from torch import Tensor from jaxtyping import Float -from typing import Optional, Callable +from typing import Iterable, Optional, Callable, TypeVar from .types import BatchGradLogDensity, BatchPtType, DeviceLike import numpy as np import inspect +__all__ = ["sym_sqrtm", "quantile_distance", "infinite_iter"] + def sym_sqrtm(A: Float[Tensor, "n n"], use_inv: bool = False): e, v = torch.linalg.eigh(A) @@ -88,3 +90,12 @@ def quantile_distance(pts: BatchPtType, quantile: float = 0.5) -> Float: ) diffs_list = diffs[diffs_idxs[0], diffs_idxs[1]] return torch.quantile(diffs_list, quantile) + + +IterType = TypeVar("IterType") + + +def infinite_iter(iterable: Iterable[IterType]): + while True: + for x in iter(iterable): + yield x From 86b830fa558fce2f928533f2cd65ffdb2e72351d Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Wed, 29 Apr 2026 17:50:32 -0400 Subject: [PATCH 48/60] Change PDB version --- pyproject.toml | 3 ++- uv.lock | 7 ++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 32c5f4c..275cfc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ nak-torch = "nak_torch:main" examples = [ "ipykernel>=7.2.0", "matplotlib>=3.10.8", - "posteriordb>=0.2.0", + "posteriordb", "pyro-ppl>=1.9.1", "pystan", "scipy>=1.17.1", @@ -59,3 +59,4 @@ ignore = ["F722"] [tool.uv.sources] pystan = { git = "https://github.com/dannys4/pystan", branch = "change_function_interface" } +posteriordb = { git = "https://github.com/stan-dev/posteriordb-python" } diff --git a/uv.lock b/uv.lock index 82c15e4..8cebfc8 100644 --- a/uv.lock +++ b/uv.lock @@ -1242,7 +1242,7 @@ requires-dist = [ { name = "jaxtyping", specifier = ">=0.3.5" }, { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.10.8" }, { name = "numpy", specifier = ">=2.4.1" }, - { name = "posteriordb", marker = "extra == 'examples'", specifier = ">=0.2.0" }, + { name = "posteriordb", marker = "extra == 'examples'", git = "https://github.com/stan-dev/posteriordb-python" }, { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, { name = "pystan", marker = "extra == 'examples'", git = "https://github.com/dannys4/pystan?branch=change_function_interface" }, { name = "scipy", marker = "extra == 'examples'", specifier = ">=1.17.1" }, @@ -1665,13 +1665,10 @@ wheels = [ [[package]] name = "posteriordb" version = "0.2.0" -source = { registry = "https://pypi.org/simple" } +source = { git = "https://github.com/stan-dev/posteriordb-python#929c54afec39f78df43ca370e3f9b09c4c70af65" } dependencies = [ { name = "requests" }, ] -wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/4d/b72e0782abec07f3d8dabf24cf12673d26b173af2046eb4e67365c776ccf/posteriordb-0.2.0-py3-none-any.whl", hash = "sha256:b6d6f3a349d34db6d4a68da899c818a95e5824c5e23824fc0ebe422f4bd6bac1", size = 24059, upload-time = "2020-11-25T12:04:47.729Z" }, -] [[package]] name = "prompt-toolkit" From 05f5e049c488f9c0f332ffcdb7684a368ed9a58f Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Wed, 29 Apr 2026 17:51:04 -0400 Subject: [PATCH 49/60] Move to cap'ing default kernel --- examples/himmelblau.py | 2 +- src/nak_torch/algorithms/msip/msip_adapt.py | 4 ++-- src/nak_torch/algorithms/msip/msip_tools.py | 4 ++-- src/nak_torch/algorithms/svgd.py | 4 ++-- src/nak_torch/tools/kernel.py | 6 +++--- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/himmelblau.py b/examples/himmelblau.py index 3f194f4..a591698 100644 --- a/examples/himmelblau.py +++ b/examples/himmelblau.py @@ -10,7 +10,7 @@ from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre from datetime import datetime -from nak_torch.tools.kernel import kernel_optimal_weight_factory, default_kernel_matrix +from nak_torch.tools.kernel import kernel_optimal_weight_factory, DEFAULT_KERNEL_MATRIX save_gif = False function_name = "himmelblau" diff --git a/src/nak_torch/algorithms/msip/msip_adapt.py b/src/nak_torch/algorithms/msip/msip_adapt.py index b6c0b5d..e0f74b8 100644 --- a/src/nak_torch/algorithms/msip/msip_adapt.py +++ b/src/nak_torch/algorithms/msip/msip_adapt.py @@ -5,7 +5,7 @@ import numpy as np import torch -from nak_torch.tools.kernel import default_kernel_matrix +from nak_torch.tools.kernel import DEFAULT_KERNEL_MATRIX from nak_torch.tools.util import initialize_particles, quantile_distance from .msip_map import msip_map from .estimators import MSIPEstimator @@ -99,7 +99,7 @@ def _choose_running(_: BatchPtType, running: BatchType): choose_running = _choose_running if get_kernel_matrix is None: - get_kernel_matrix = default_kernel_matrix + get_kernel_matrix = DEFAULT_KERNEL_MATRIX msip_estimator = process_msip_density(log_density, **msip_kwargs) particles = initialize_particles(n_particles, dim, init_particles, device, bounds) diff --git a/src/nak_torch/algorithms/msip/msip_tools.py b/src/nak_torch/algorithms/msip/msip_tools.py index fb05eba..1abd03a 100644 --- a/src/nak_torch/algorithms/msip/msip_tools.py +++ b/src/nak_torch/algorithms/msip/msip_tools.py @@ -4,7 +4,7 @@ import torch from nak_torch.tools.func import AlgorithmArgsT, WeightedAdaptiveNAKAlgorithm -from nak_torch.tools.kernel import default_kernel_matrix +from nak_torch.tools.kernel import DEFAULT_KERNEL_MATRIX from nak_torch.tools.util import get_keywords, quantile_distance from .msip_map import msip_map from .estimators import MSIPEstimator, MSIPFredholm @@ -70,7 +70,7 @@ def __init__( self.default_kernel_lengthscale = kernel_lengthscale self.kernel_lengthscale_quantile = kernel_lengthscale_quantile if get_kernel_matrix is None: - self.get_kernel_matrix = default_kernel_matrix + self.get_kernel_matrix = DEFAULT_KERNEL_MATRIX else: self.get_kernel_matrix = get_kernel_matrix diff --git a/src/nak_torch/algorithms/svgd.py b/src/nak_torch/algorithms/svgd.py index 6d7cf32..dc64f24 100644 --- a/src/nak_torch/algorithms/svgd.py +++ b/src/nak_torch/algorithms/svgd.py @@ -10,7 +10,7 @@ from typing import Optional import torch from nak_torch.tools.func import UnweightedAdaptiveNAKAlgorithm -from nak_torch.tools.kernel import default_kernel_elem +from nak_torch.tools.kernel import DEFAULT_KERNEL_ELEM from nak_torch.tools.types import ( BatchGradLogDensityEvaluator, BatchKernelGradValFunction, @@ -97,7 +97,7 @@ def __init__( f"Expected kernel_lengthscale_quantile in [0,1], given {kernel_lengthscale_quantile}" ) if kernel_elem is None: - kernel_elem = default_kernel_elem + kernel_elem = DEFAULT_KERNEL_ELEM self.default_kernel_lengthscale = ( 0.0 if kernel_lengthscale is None else kernel_lengthscale ) diff --git a/src/nak_torch/tools/kernel.py b/src/nak_torch/tools/kernel.py index 1288ff5..5238fc2 100644 --- a/src/nak_torch/tools/kernel.py +++ b/src/nak_torch/tools/kernel.py @@ -15,7 +15,7 @@ ) __all__ = [ - "default_kernel_matrix", + "DEFAULT_KERNEL_MATRIX", "sqexp_kernel_matrix", "sqexp_kernel_elem", "matricize_kernel_elem", @@ -48,8 +48,8 @@ def sqexp_kernel_elem(x: PtType, y: PtType, kernel_length_scale: float) -> Float return ret -default_kernel_elem = sqexp_kernel_elem -default_kernel_matrix = sqexp_kernel_matrix +DEFAULT_KERNEL_ELEM = sqexp_kernel_elem +DEFAULT_KERNEL_MATRIX = sqexp_kernel_matrix def inverse_multi_quadric_kernel_elem( From 7b861f8d02ef0ddc8a7454aa307ced5b05ce4df7 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Wed, 29 Apr 2026 17:51:25 -0400 Subject: [PATCH 50/60] Add sample MMD metric --- src/nak_torch/tools/metrics.py | 116 +++++++++++++++++++++++++++++---- 1 file changed, 104 insertions(+), 12 deletions(-) diff --git a/src/nak_torch/tools/metrics.py b/src/nak_torch/tools/metrics.py index df6d5db..5f4c186 100644 --- a/src/nak_torch/tools/metrics.py +++ b/src/nak_torch/tools/metrics.py @@ -1,5 +1,7 @@ -from typing import Optional +import os +from typing import Callable, Optional from abc import ABC, abstractmethod +import pickle from warnings import warn from jaxtyping import Float @@ -12,12 +14,20 @@ BatchPtType, BatchLogDensity, GradLogDensity, + KernelFunction, + KernelMatrixType, LogDensity, LogDensityGradVal, MatSelfKernelFunction, + PtType, ) -from .kernel import sqexp_kernel_elem, stein_kernel_mat_factory +from .kernel import ( + DEFAULT_KERNEL_ELEM, + matricize_kernel_elem, + sqexp_kernel_elem, + stein_kernel_mat_factory, +) __all__ = ["CrossEntropy", "KernelSteinDiscrepancy", "RelativeESS"] @@ -44,7 +54,7 @@ def __init__( is_log_dens_grad_val: bool = False, ): log_dens_val = ( - log_dens if not is_log_dens_grad_val else lambda x: log_dens(x)[1] + log_dens if not is_log_dens_grad_val else lambda x, a: log_dens(x, a)[1] ) self.log_dens = log_dens_val self.is_log_dens_vectorized = is_log_dens_vectorized @@ -55,12 +65,12 @@ class CrossEntropy(GradFreeMetric): Given target $\pi$ and particle approximation $\mu$, estimate $D_{KL}(\mu || \pi)$. """ - def __call__(self, pts, wts=None): + def __call__(self, pts, wts=None, target_args=None): N = pts.shape[0] N_tens = torch.as_tensor(N, device=pts.device, dtype=pts.dtype) cross_entropy: Float if self.is_log_dens_vectorized: - log_dens_evals = self.log_dens(pts) + log_dens_evals = self.log_dens(pts, target_args) if wts is None: cross_entropy = -log_dens_evals.mean() else: @@ -68,7 +78,7 @@ def __call__(self, pts, wts=None): else: cross_entropy = torch.zeros_like(N_tens) for idx in range(pts.shape[0]): - cross_entropy_eval = self.log_dens(pts[idx]) + cross_entropy_eval = self.log_dens(pts[idx], target_args) if wts is None: cross_entropy_eval /= N_tens else: @@ -83,17 +93,17 @@ class ExclusiveKullbackLeibler(GradFreeMetric): DO NOT USE. MATHEMATICALLY INCORRECT. """ - def __call__(self, pts, wts=None): + def __call__(self, pts, wts=None, target_args=None): warn("Exclusive Kullback Leibler is not mathematically correct.") N = pts.shape[0] N_tens = torch.as_tensor(N, device=pts.device, dtype=pts.dtype) log_dens_evals: Tensor if self.is_log_dens_vectorized: - log_dens_evals = self.log_dens(pts) + log_dens_evals = self.log_dens(pts, target_args) else: log_dens_evals = torch.zeros(N, device=pts.device, dtype=pts.dtype) for idx in range(pts.shape[0]): - log_dens_evals[idx] = self.log_dens(pts[idx]) + log_dens_evals[idx] = self.log_dens(pts[idx], target_args) kl: Float if wts is None: log_ratios = log_dens_evals - N_tens.log() @@ -111,15 +121,15 @@ class RelativeESS(GradFreeMetric): $$rESS(Y,w; \pi) = \frac{1}{\sum_i v_i^2}, v_i = \frac{w_i \pi(y_i)}{\sum_j w_j \pi(y_j)}$.$ """ - def __call__(self, pts, wts=None): + def __call__(self, pts, wts=None, target_args=None): evals: torch.Tensor N = pts.shape[0] if self.is_log_dens_vectorized: - evals = self.log_dens(pts) + evals = self.log_dens(pts, target_args) else: evals = torch.zeros(N, device=pts.device, dtype=pts.dtype) for idx in range(N): - evals[idx] = self.log_dens(pts[idx]) + evals[idx] = self.log_dens(pts[idx], target_args) log_weights = evals if wts is not None: @@ -165,3 +175,85 @@ def __call__(self, pts, wts=None): return stein_mat.mean().sqrt() else: return (wts @ stein_mat @ wts).sqrt() + + +class SampleMMD(Metric): + kernel_sample_reduce: Callable[ + [BatchPtType], BatchType + ] # k(X,y,sigma), X as vector + kernel_mat: Callable[[BatchPtType], KernelMatrixType] + self_mmd: float + kernel_lengthscale: float + + def __init__( + self, + samples: BatchPtType, + kernel_lengthscale: float, + kernel_elem: Optional[KernelFunction] = None, + self_mmd: Optional[Float] = None, + self_mmd_serial: Optional[str] = None, + use_compiled: bool = True, + ): + if kernel_elem is None: + kernel_elem = DEFAULT_KERNEL_ELEM + kernel_vec = torch.vmap(kernel_elem, in_dims=(0, None, None)) + + def _kernel_reduce(y: PtType) -> Float: + return kernel_vec(samples, y, kernel_lengthscale).mean() + + kernel_reduce_elem: Callable[[PtType], Float] + if use_compiled: + kernel_reduce_elem = torch.compile(_kernel_reduce) + else: + kernel_reduce_elem = _kernel_reduce + kernel_reduce = torch.vmap(kernel_reduce_elem) + self.kernel_sample_reduce = kernel_reduce + kernel_mat = matricize_kernel_elem(kernel_elem, use_compiled) + self.kernel_mat = lambda pts: kernel_mat(pts, kernel_lengthscale) + if self_mmd is None: + if self_mmd_serial is None: + self.self_mmd = kernel_reduce(samples).mean() + else: + import fcntl # TODO: Fix for windows, I guess + + self_mmd_dict: dict[float, Float] = {} + file_existed_prev = os.path.exists(self_mmd_serial) + with open(self_mmd_serial, "rb+") as f: + if file_existed_prev: + try: + self_mmd_dict = pickle.load(f) + except EOFError: + self_mmd_dict = {} + self_mmd_pkl: float + if kernel_lengthscale in self_mmd_dict.keys(): + self_mmd_pkl = self_mmd_dict[kernel_lengthscale] + else: + self_mmd_pkl = kernel_reduce(samples).mean() + # Ensure file is locked and up-to-date in case things are written in parallel + with open(self_mmd_serial, "wb+") as f: + fcntl.lockf(f, fcntl.LOCK_EX) + try: + new_self_mmd_dict = pickle.load(f) + except EOFError: + new_self_mmd_dict = self_mmd_dict + self_mmd_dict[kernel_lengthscale] = self_mmd_pkl + pickle.dump(new_self_mmd_dict, f) + fcntl.lockf(f, fcntl.LOCK_UN) + # End lock + self.self_mmd = self_mmd_pkl + else: + self.self_mmd = self_mmd + + def __call__(self, pts, wts=None): + self_mmd = self.self_mmd + cross_kernel_vec = self.kernel_sample_reduce(pts) + pts_kernel_mat = self.kernel_mat(pts) + cross_mmd: Float + pts_mmd: Float + if wts is None: + cross_mmd = cross_kernel_vec.mean() + pts_mmd = pts_kernel_mat.mean() + else: + cross_mmd = cross_kernel_vec @ wts + pts_mmd = wts @ (pts_kernel_mat @ wts) + return torch.sqrt(self_mmd - 2 * cross_mmd + pts_mmd) From bbad43934e4e461b68c0ec19063e4b534491e7d2 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Wed, 29 Apr 2026 19:03:41 -0400 Subject: [PATCH 51/60] minimal PDB example --- examples/stan/pdb_schools.py | 102 +++++++++++++++++++++++++++++- examples/stan/schools.py | 2 +- src/nak_torch/tools/stan_tools.py | 42 ++++++++++-- 3 files changed, 138 insertions(+), 8 deletions(-) diff --git a/examples/stan/pdb_schools.py b/examples/stan/pdb_schools.py index 5781d83..159a590 100644 --- a/examples/stan/pdb_schools.py +++ b/examples/stan/pdb_schools.py @@ -1,9 +1,18 @@ # %% import nest_asyncio import os +import matplotlib.pyplot as plt +import torch +from tqdm import tqdm +import nak_torch +from nak_torch.algorithms import MSIP, SVGD +from nak_torch.algorithms.msip import MSIPFredholm +from nak_torch.tools import stan_tools +from nak_torch.tools.types import BatchGradLogDensityEvaluator nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter -from posteriordb import PosteriorDatabaseGithub # noqa: E402 +import stan # noqa: E402 +from posteriordb import PosteriorDatabaseGithub # noqa: E402 # %% if "GITHUB_PAT" not in os.environ.keys(): @@ -12,7 +21,98 @@ my_pdb = PosteriorDatabaseGithub() pos = my_pdb.posterior_names() +def sample_tau_prior(N_samples, loc: float = 0., scale: float = 5.): + dist = torch.distributions.Cauchy(loc, scale, True) + return dist.rsample((N_samples,)).abs_() + # %% posterior = my_pdb.posterior("eight_schools-eight_schools_centered") # %% +post_model = stan.build(posterior.model.stan_code(), data=posterior.data.values()) + +# %% +stan_model = stan_tools.StanModel(post_model) + +# %% +pts = torch.randn((100, stan_model.dim)) +pdfs = stan_model.log_dens_batch(pts, None) +grad_log_pdfs = stan_model.grad_log_dens_batch(pts, None) +grad_log_pdfs_2, pdfs_2 = stan_model.grad_val_log_dens_batch(pts, None) + +# %% +GRADIENT_DECAY = 1.0 +N_PARTICLES = 100 +KERNEL_DIAG_INFL = 1e-6 +KERNEL_LENGTHSCALE = 1e-1 +BOUNDS = (-100.0, 100.0) +target_msip_fr = MSIPFredholm(GRADIENT_DECAY, stan_model.grad_val_log_dens_batch) +init_eta = torch.randn((N_PARTICLES, 8)) +init_tau = sample_tau_prior(N_PARTICLES).clamp_(*BOUNDS) +init_mu = torch.randn((N_PARTICLES, 1)) * 5 +init_particles = torch.column_stack((init_mu, init_tau, init_eta)) + +msip = MSIP( + stan_model.dim, + N_PARTICLES, + kernel_diag_infl=KERNEL_DIAG_INFL, + kernel_lengthscale=KERNEL_LENGTHSCALE, +) + +# %% +N_STEPS = 1000 +LR = 1e-3 +trajectories_msip_fr = nak_torch.nak( + target_msip_fr, + msip, + N_STEPS, + LR, + init_particles=init_particles, + bounds=BOUNDS, +) +trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr + +# %% +target_svgd = BatchGradLogDensityEvaluator( + stan_model.grad_log_dens_batch, + is_grad=True, + is_batched=True +) + +svgd = SVGD( + stan_model.dim, + N_PARTICLES, + kernel_lengthscale=KERNEL_LENGTHSCALE, + kernel_lengthscale_quantile=0.5 +) + +# %% +N_STEPS = 1000 +LR = 1e-3 +trajectories_pts_svgd = nak_torch.nak( + target_svgd, + svgd, + N_STEPS, + LR, + init_particles=init_particles, + bounds=BOUNDS, +) + + +# %% +draws = stan_tools.get_draws(post_model, posterior) + +# %% +cross_ent = nak_torch.metrics.CrossEntropy(stan_model.log_dens_batch) + +# %% +msip_cross_ent = [cross_ent(pts, None, None) for pts in tqdm(trajectories_pts_msip_fr)] +svgd_cross_ent = [cross_ent(pts, None, None) for pts in tqdm(trajectories_pts_svgd)] + +# %% +plt.plot(msip_cross_ent, label="MSIP") +plt.plot(svgd_cross_ent, label="SVGD") +plt.plot() +plt.title("Cross entropy across iterations") +plt.legend() +# %% diff --git a/examples/stan/schools.py b/examples/stan/schools.py index 4c8c570..adc0215 100644 --- a/examples/stan/schools.py +++ b/examples/stan/schools.py @@ -1,6 +1,6 @@ # %% -import nest_asyncio import torch +import nest_asyncio import nak_torch from nak_torch.tools import stan_tools from nak_torch.algorithms import MSIP, SVGD diff --git a/src/nak_torch/tools/stan_tools.py b/src/nak_torch/tools/stan_tools.py index 295e7e3..a3ba859 100644 --- a/src/nak_torch/tools/stan_tools.py +++ b/src/nak_torch/tools/stan_tools.py @@ -1,10 +1,38 @@ from typing import Optional +import re import stan.model import torch +from posteriordb.posterior import Posterior from nak_torch.tools.types import BatchPtType, BatchType, NAKTarget +__all__ = ["get_draws", "StanModel"] + + +def expanded_var_names(model: stan.model.Model): + names = [] + array_param = re.compile(r"\.\d+$") + for var in model.constrained_param_names: + if array_param.search(var) is None: + names.append(var) + else: + v, n = var.split(".") + names.append(v + "[" + n + "]") + return names + + +def get_draws(model: stan.model.Model, posterior: Posterior): + reference_draws = posterior.reference_draws() + var_names = expanded_var_names(model) + all_draws = torch.concat( + [ + torch.column_stack([torch.as_tensor(chain[v]) for v in var_names]) + for chain in reference_draws + ] + ) + return all_draws + class StanModel(NAKTarget): dim: int @@ -21,21 +49,23 @@ def __init__(self, model: stan.model.Model, dim: Optional[int] = None): self.dim = dim self.model = model - def log_dens_batch(self, theta: BatchPtType, _) -> BatchType: + def log_dens_batch(self, theta: BatchPtType, _=None) -> BatchType: device, dtype = theta.device, theta.dtype - out_np = self.model.log_prob(theta.cpu().numpy()) + out_np = self.model.log_prob(theta.cpu().to(dtype=torch.float64).numpy()) return torch.as_tensor(out_np, device=device, dtype=dtype) - def grad_log_dens_batch(self, theta: BatchPtType, _) -> BatchPtType: + def grad_log_dens_batch(self, theta: BatchPtType, _=None) -> BatchPtType: device, dtype = theta.device, theta.dtype - out_np = self.model.grad_log_prob(theta.cpu().numpy()) + out_np = self.model.grad_log_prob(theta.cpu().to(dtype=torch.float64).numpy()) return torch.as_tensor(out_np, device=device, dtype=dtype) def grad_val_log_dens_batch( - self, theta: BatchPtType, _ + self, theta: BatchPtType, _=None ) -> tuple[BatchPtType, BatchType]: device, dtype = theta.device, theta.dtype - out_grad_np, out_val_np = self.model.grad_val_log_prob(theta.cpu().numpy()) + out_grad_np, out_val_np = self.model.grad_val_log_prob( + theta.cpu().to(dtype=torch.float64).numpy() + ) return torch.as_tensor( out_grad_np, device=device, dtype=dtype ), torch.as_tensor(out_val_np, device=device, dtype=dtype) From 2738e646fcf93426e4c026d3608470d8f7c9aeea Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sat, 2 May 2026 04:11:18 -0400 Subject: [PATCH 52/60] Fix up posteriordb ver --- pyproject.toml | 3 +-- uv.lock | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 275cfc2..32c5f4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ nak-torch = "nak_torch:main" examples = [ "ipykernel>=7.2.0", "matplotlib>=3.10.8", - "posteriordb", + "posteriordb>=0.2.0", "pyro-ppl>=1.9.1", "pystan", "scipy>=1.17.1", @@ -59,4 +59,3 @@ ignore = ["F722"] [tool.uv.sources] pystan = { git = "https://github.com/dannys4/pystan", branch = "change_function_interface" } -posteriordb = { git = "https://github.com/stan-dev/posteriordb-python" } diff --git a/uv.lock b/uv.lock index 8cebfc8..82c15e4 100644 --- a/uv.lock +++ b/uv.lock @@ -1242,7 +1242,7 @@ requires-dist = [ { name = "jaxtyping", specifier = ">=0.3.5" }, { name = "matplotlib", marker = "extra == 'examples'", specifier = ">=3.10.8" }, { name = "numpy", specifier = ">=2.4.1" }, - { name = "posteriordb", marker = "extra == 'examples'", git = "https://github.com/stan-dev/posteriordb-python" }, + { name = "posteriordb", marker = "extra == 'examples'", specifier = ">=0.2.0" }, { name = "pyro-ppl", marker = "extra == 'examples'", specifier = ">=1.9.1" }, { name = "pystan", marker = "extra == 'examples'", git = "https://github.com/dannys4/pystan?branch=change_function_interface" }, { name = "scipy", marker = "extra == 'examples'", specifier = ">=1.17.1" }, @@ -1665,10 +1665,13 @@ wheels = [ [[package]] name = "posteriordb" version = "0.2.0" -source = { git = "https://github.com/stan-dev/posteriordb-python#929c54afec39f78df43ca370e3f9b09c4c70af65" } +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, ] +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/4d/b72e0782abec07f3d8dabf24cf12673d26b173af2046eb4e67365c776ccf/posteriordb-0.2.0-py3-none-any.whl", hash = "sha256:b6d6f3a349d34db6d4a68da899c818a95e5824c5e23824fc0ebe422f4bd6bac1", size = 24059, upload-time = "2020-11-25T12:04:47.729Z" }, +] [[package]] name = "prompt-toolkit" From 15dfab6db281902e8604d23f3a2273e85276570f Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 3 May 2026 10:41:28 +0100 Subject: [PATCH 53/60] Working PDB --- examples/stan/pdb_covid.py | 150 +++++++++++++++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 examples/stan/pdb_covid.py diff --git a/examples/stan/pdb_covid.py b/examples/stan/pdb_covid.py new file mode 100644 index 0000000..0e3b2cb --- /dev/null +++ b/examples/stan/pdb_covid.py @@ -0,0 +1,150 @@ +# %% +from typing import Optional + +import nest_asyncio +import os +import matplotlib.pyplot as plt +import torch +from tqdm import tqdm +import nak_torch +from nak_torch.algorithms import MSIP, SVGD +from nak_torch.algorithms.msip import MSIPFredholm +from nak_torch.tools import stan_tools +from nak_torch.tools.types import BatchGradLogDensityEvaluator, DeviceLike + +nest_asyncio.apply() # See pystan documentation on why you need this when doing jupyter +import stan # noqa: E402 +import posteriordb # noqa: E402 + +def covid_prior_sample( + N_samples: int, + M_y: int = 14, + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceLike] = None, + rng: Optional[torch.Generator] = None, +): + if rng is None: + rng = torch.default_generator + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.get_default_device() + tau = ( + torch.empty((N_samples, 1), dtype=dtype, device=device) + .exponential_(generator=rng) + .div_(0.03) + ) + y = ( + torch.empty((N_samples, M_y), dtype=dtype, device=device) + .exponential_(generator=rng) + .div_(tau) + ) + phi = torch.randn((N_samples, 1), generator=rng).mul_(5.0) + kappa = torch.randn((N_samples, 1), generator=rng).mul_(0.5) + mu = torch.randn((N_samples, M_y), generator=rng).mul_(kappa).add_(3.28) + alpha_hier = torch._standard_gamma( + torch.as_tensor(0.1667, dtype=dtype, device=device).expand(N_samples, 6), + generator=rng, + ) + ifr_noise = torch.randn((N_samples, M_y), generator=rng).mul_(0.1).add_(1.0) + log_tau = tau.log_() + log_alpha_hier = alpha_hier.log_() + log_y = y.log_() + return torch.column_stack((mu, log_alpha_hier, kappa, log_y, phi, log_tau, ifr_noise)) + +# %% +pdb = posteriordb.PosteriorDatabase() +which_posterior = "ecdc0501-covid19imperial_v3" +posterior = pdb.posterior(which_posterior) +post_model = stan.build(posterior.model.stan_code(), data=posterior.data.values()) +dim = sum(posterior.information["dimensions"].values()) +stan_model = stan_tools.StanModel(post_model, dim) + +# %% +pts = torch.randn((10, stan_model.dim)) +pdfs = stan_model.log_dens_batch(pts, None) +grad_log_pdfs = stan_model.grad_log_dens_batch(pts, None) +grad_log_pdfs_2, pdfs_2 = stan_model.grad_val_log_dens_batch(pts, None) +assert (pdfs - pdfs_2).square_().sum() < 1e-10 +assert (grad_log_pdfs - grad_log_pdfs_2).square_().sum() < 1e-10 + +# %% +GRADIENT_DECAY = 1.0 +N_PARTICLES = 25 +KERNEL_DIAG_INFL = 1e-2 +KERNEL_LENGTHSCALE = 1e-1 +BOUNDS = (-100.0, 100.0) +target_msip_fr = MSIPFredholm(GRADIENT_DECAY, stan_model.grad_val_log_dens_batch) +init_particles = torch.randn((N_PARTICLES, stan_model.dim))#covid_prior_sample(N_PARTICLES) + +msip = MSIP( + stan_model.dim, + N_PARTICLES, + kernel_diag_infl=KERNEL_DIAG_INFL, + kernel_lengthscale=KERNEL_LENGTHSCALE, +) + +# %% +N_STEPS = 1000 +LR = 1e-3 +trajectories_msip_fr = nak_torch.nak( + target_msip_fr, + msip, + N_STEPS, + LR, + init_particles=init_particles, + bounds=BOUNDS, +) +trajectories_pts_msip_fr, trajectories_wts_msip_fr = trajectories_msip_fr + +# %% +# %% +target_svgd = BatchGradLogDensityEvaluator( + stan_model.grad_log_dens_batch, + is_grad=True, + is_batched=True +) + +svgd = SVGD( + stan_model.dim, + N_PARTICLES, + kernel_lengthscale=KERNEL_LENGTHSCALE, + kernel_lengthscale_quantile=0.5 +) + +# %% +N_STEPS = 1000 +LR = 1e-4 +trajectories_pts_svgd = nak_torch.nak( + target_svgd, + svgd, + N_STEPS, + LR, + init_particles=init_particles, + bounds=BOUNDS, +) + + +# %% +crossent = nak_torch.metrics.CrossEntropy(stan_model.log_dens_batch) +msip_crossent = [crossent(p,w) for p,w in tqdm(zip(*trajectories_msip_fr), total=N_STEPS+1)] +svgd_crossent = [crossent(p) for p in tqdm(trajectories_pts_svgd)] + +# %% +ksd = nak_torch.metrics.KernelSteinDiscrepancy(stan_model.grad_log_dens_batch, KERNEL_LENGTHSCALE) +msip_ksd = [ksd(p,w) for p,w in tqdm(zip(*trajectories_msip_fr), total=N_STEPS+1)] +svgd_ksd = [ksd(p) for p in tqdm(trajectories_pts_svgd)] + +# %% +plt.plot(msip_crossent, label="msip") +plt.plot(svgd_crossent, label="svgd") +plt.title("Cross Entropy") +plt.legend() +plt.show() + +# %% +plt.plot(msip_ksd, label="msip") +plt.plot(svgd_ksd, label="svgd") +plt.legend() +plt.title("KSD") +plt.show() \ No newline at end of file From d3a86f4046156e95637c1d2555f3fc7a9acf0353 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Sun, 3 May 2026 11:12:49 +0100 Subject: [PATCH 54/60] Fix metrics test --- tests/test_metrics.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 64ef98b..290d2e8 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -6,16 +6,18 @@ MAX_POW10 = 4 -def normal_logpdf(x: Tensor): +def normal_logpdf(x: Tensor, _ = None): return x.square().sum(-1).neg().div(2) +def grad_normal_logpdf(x: Tensor, _ = None): + return x.neg() + def test_ksd(): - grad_log_normal = torch.neg KSD_KERNEL_ELEM = kernel.inverse_multi_quadric_kernel_elem rng = torch.Generator() rng.manual_seed(321393021) kernel_length_scale = 0.1 - ksd = metrics.KernelSteinDiscrepancy(grad_log_normal, kernel_length_scale, kernel_elem=KSD_KERNEL_ELEM) + ksd = metrics.KernelSteinDiscrepancy(grad_normal_logpdf, kernel_length_scale, kernel_elem=KSD_KERNEL_ELEM) in_sizes = torch.arange(MAX_POW10 + 1) # 15 is empirical, does not really matter. expected_ksd = 15*torch.pow(10.0, -0.5 * in_sizes) From 0ece5a994aa24fe9d941349e99ba9cf7a8de4bd3 Mon Sep 17 00:00:00 2001 From: Daniel <43151183+dannys4@users.noreply.github.com> Date: Mon, 4 May 2026 08:36:27 +0100 Subject: [PATCH 55/60] Specify device and dtype Updated tensor creation to specify device and dtype. --- examples/functions/aristoff_bangerth.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/functions/aristoff_bangerth.py b/examples/functions/aristoff_bangerth.py index afb4ecd..043e455 100644 --- a/examples/functions/aristoff_bangerth.py +++ b/examples/functions/aristoff_bangerth.py @@ -251,7 +251,7 @@ def build_forward_solver_args(N, N_obs, device=None, dtype: Optional[torch.dtype [-1./6, 2./3, -1./6, -1./3], [-1./3, -1./6, 2./3, -1./6], [-1./6, -1./3, -1./6, 2./3] - ], device=device) + ], dtype=dtype, device=device) # Locate boundary labels boundaries = torch.concat(( @@ -323,7 +323,7 @@ def get_patch_idx(idx): A_rows = A_idxs A_cols = A_idxs.permute((0,2,1)) - A_dens = torch.zeros((N_batch, Np1**2, Np1**2)) + A_dens = torch.zeros((N_batch, Np1**2, Np1**2), device=theta.device, dtype=theta.dtype) # Unroll loop A_dens[:,A_rows[:,0,0],A_cols[:,0,0]] += A_locs[:,:,0,0] A_dens[:,A_rows[:,0,1],A_cols[:,0,1]] += A_locs[:,:,0,1] @@ -345,7 +345,6 @@ def get_patch_idx(idx): A_dens[:,boundaries, :] = 0. A_dens[:,:,boundaries] = 0. A_dens[:,boundaries, boundaries] = 1. - # Solve linear equation for coefficients, U, and then # get the Z vector by multiplying by the measurement matrix return torch.linalg.solve(A_dens, b.repeat(N_batch,1)) @@ -456,4 +455,4 @@ def get_file(fname): if len(sys.argv) < 1: raise ValueError("Expected path as first argument to this script") path = sys.argv[1] - verify_against_stored_tests(path, z_hat_true) \ No newline at end of file + verify_against_stored_tests(path, z_hat_true) From af1272c81d20880f07791bcab871f94311b5c024 Mon Sep 17 00:00:00 2001 From: Daniel <43151183+dannys4@users.noreply.github.com> Date: Mon, 4 May 2026 08:37:03 +0100 Subject: [PATCH 56/60] Update type annotations for nak function parameters --- src/nak_torch/algorithms/loop.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/nak_torch/algorithms/loop.py b/src/nak_torch/algorithms/loop.py index a3e1a8a..acda3a8 100644 --- a/src/nak_torch/algorithms/loop.py +++ b/src/nak_torch/algorithms/loop.py @@ -7,20 +7,20 @@ from torch import Tensor from nak_torch.tools.util import initialize_particles -from nak_torch.tools.types import ( - NAKTarget, -) from nak_torch.tools.func import ( + AlgorithmArgsT, GeneralAdaptiveNAKAlgorithm, + NAKTargetT, + WeightT, ) __all__ = ["nak"] def nak( - target: NAKTarget, - algorithm: GeneralAdaptiveNAKAlgorithm, + target: NAKTargetT, + algorithm: GeneralAdaptiveNAKAlgorithm[NAKTargetT, WeightT, AlgorithmArgsT], n_steps: int, lr: float, rng_or_seed: Optional[int | torch.Generator] = None, From 04fb72fbda87228978e71281504a3473748b5da4 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 4 May 2026 09:41:11 +0100 Subject: [PATCH 57/60] Fixing dtype and device issues --- src/nak_torch/algorithms/cbs.py | 9 ++++++++- src/nak_torch/algorithms/eks.py | 20 ++++++++++++++++---- src/nak_torch/algorithms/grad_aldi.py | 7 ++++++- src/nak_torch/algorithms/gradfree_aldi.py | 20 ++++++++++++++++---- src/nak_torch/tools/quadrature.py | 6 +++--- src/nak_torch/tools/util.py | 6 ++++-- 6 files changed, 53 insertions(+), 15 deletions(-) diff --git a/src/nak_torch/algorithms/cbs.py b/src/nak_torch/algorithms/cbs.py index 072beaa..3b707b3 100644 --- a/src/nak_torch/algorithms/cbs.py +++ b/src/nak_torch/algorithms/cbs.py @@ -29,7 +29,14 @@ def cbs_step( drift_term = particles_diff.neg_() noise_sqrt_cov = sym_sqrtm(particles_cov.mul_(motion_scaling_sq)) motion_term = ( - torch.normal(0.0, 1.0, particles.shape, generator=rng, device=rng.device) + torch.normal( + 0.0, + 1.0, + particles.shape, + generator=rng, + dtype=particles.dtype, + device=rng.device, + ) @ noise_sqrt_cov ) return drift_term, motion_term diff --git a/src/nak_torch/algorithms/eks.py b/src/nak_torch/algorithms/eks.py index a13ff7c..f919d16 100644 --- a/src/nak_torch/algorithms/eks.py +++ b/src/nak_torch/algorithms/eks.py @@ -66,7 +66,14 @@ def eks_step( new_particles: BatchPtType = torch.linalg.solve( prior_term_premul, particles - likely_term, left=False ) - noise_tens = torch.randn(particles.shape, generator=rng) + noise_tens = torch.normal( + 0.0, + 1.0, + size=particles.shape, + device=particles.device, + dtype=particles.dtype, + generator=rng, + ) noise_samp = noise_tens @ sqrt_prior_cov return new_particles.add_(noise_samp) @@ -114,9 +121,14 @@ def step( algorithm_args: None, target_args: Any, ) -> tuple[Tensor, None, None]: - forward_model, likelihood_precision, prior_precision, true_obs, prior_mean = ( - astuple(target) - ) + ( + forward_model, + likelihood_precision, + prior_precision, + true_obs, + prior_mean, + _, + ) = astuple(target) forecast_observations = forward_model(particles, target_args) new_particles = eks_step( particles, diff --git a/src/nak_torch/algorithms/grad_aldi.py b/src/nak_torch/algorithms/grad_aldi.py index 8e991ba..b235fed 100644 --- a/src/nak_torch/algorithms/grad_aldi.py +++ b/src/nak_torch/algorithms/grad_aldi.py @@ -31,7 +31,12 @@ def grad_aldi_step( particles_sqrt_cov = sym_sqrtm(2 * particles_cov) # sqrt(2) comes from noise particles_noise_iid = torch.normal( - 0.0, 1.0, size=particles.shape, generator=rng, device=particles.device + 0.0, + 1.0, + size=particles.shape, + generator=rng, + dtype=particles.dtype, + device=particles.device, ) particles_noise = particles_noise_iid @ particles_sqrt_cov drift_term = term1.add_(term2) diff --git a/src/nak_torch/algorithms/gradfree_aldi.py b/src/nak_torch/algorithms/gradfree_aldi.py index 7422c66..5a34ac0 100644 --- a/src/nak_torch/algorithms/gradfree_aldi.py +++ b/src/nak_torch/algorithms/gradfree_aldi.py @@ -55,7 +55,14 @@ def gradfree_aldi_step( prior_term2 = forecast_deviation.mul_((dim + 1) / N_batch) particle_diff = prior_term2.sub_(prior_term1).sub_(likely_term) - noise = torch.normal(0.0, 1.0, particles.shape, generator=rng) + noise = torch.normal( + 0.0, + 1.0, + particles.shape, + dtype=particles.dtype, + device=particles.device, + generator=rng, + ) motion = torch.matmul(noise, sqrt_cov_forecast) return particle_diff, motion @@ -107,9 +114,14 @@ def step( algorithm_args: None, target_args: Any, ) -> tuple[Tensor, None, None]: - forward_model, likelihood_precision, prior_precision, true_obs, prior_mean = ( - astuple(target) - ) + ( + forward_model, + likelihood_precision, + prior_precision, + true_obs, + prior_mean, + _, + ) = astuple(target) forecast_observations = forward_model(particles, target_args) particles_diff, particles_noise = gradfree_aldi_step( particles, diff --git a/src/nak_torch/tools/quadrature.py b/src/nak_torch/tools/quadrature.py index 4cadc2b..c16c3f3 100644 --- a/src/nak_torch/tools/quadrature.py +++ b/src/nak_torch/tools/quadrature.py @@ -79,8 +79,8 @@ def spherical_MC_radial_Laguerre( def gauss_MC( - batch_size: int, N_quad: int, d: int + batch_size: int, N_quad: int, d: int, device=None, dtype=None ) -> tuple[Float[Tensor, "batch N_quad d"], Float[Tensor, "batch N_quad"]]: - pts = torch.randn((batch_size, N_quad, d)) - wts = torch.ones((batch_size, N_quad)).div_(N_quad) + pts = torch.randn((batch_size, N_quad, d), device=device, dtype=dtype) + wts = torch.ones((batch_size, N_quad), device=device, dtype=dtype).div_(N_quad) return pts, wts diff --git a/src/nak_torch/tools/util.py b/src/nak_torch/tools/util.py index ef5aeba..ee0b7dc 100644 --- a/src/nak_torch/tools/util.py +++ b/src/nak_torch/tools/util.py @@ -12,9 +12,11 @@ def sym_sqrtm(A: Float[Tensor, "n n"], use_inv: bool = False): e, v = torch.linalg.eigh(A) if use_inv: - return torch.einsum("ij,j,kj->ik", v, torch.reciprocal_(e.sqrt_()), v) + return torch.einsum("ij,j,kj->ik", v, torch.reciprocal_(e.sqrt_()), v).to( + dtype=A.dtype + ) else: - return torch.einsum("ij,j,kj->ik", v, e.sqrt_(), v) + return torch.einsum("ij,j,kj->ik", v, e.sqrt_(), v).to(dtype=A.dtype) def get_keywords(fcn: Callable): From 3fa28c3f908754423652fb142c970ae8174ea710 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Mon, 4 May 2026 17:01:57 +0100 Subject: [PATCH 58/60] Add generator for MC quad --- src/nak_torch/tools/quadrature.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nak_torch/tools/quadrature.py b/src/nak_torch/tools/quadrature.py index c16c3f3..2d01a90 100644 --- a/src/nak_torch/tools/quadrature.py +++ b/src/nak_torch/tools/quadrature.py @@ -79,8 +79,10 @@ def spherical_MC_radial_Laguerre( def gauss_MC( - batch_size: int, N_quad: int, d: int, device=None, dtype=None + batch_size: int, N_quad: int, d: int, rng: torch.Generator, device=None, dtype=None ) -> tuple[Float[Tensor, "batch N_quad d"], Float[Tensor, "batch N_quad"]]: - pts = torch.randn((batch_size, N_quad, d), device=device, dtype=dtype) + pts = torch.randn( + (batch_size, N_quad, d), device=device, dtype=dtype, generator=rng + ) wts = torch.ones((batch_size, N_quad), device=device, dtype=dtype).div_(N_quad) return pts, wts From 2b6881c393e87545f290d7d642e7c7a9557c8d8e Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 12 May 2026 19:56:43 +0100 Subject: [PATCH 59/60] Himmelblau stuff --- examples/himmelblau.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/himmelblau.py b/examples/himmelblau.py index a591698..b48d99d 100644 --- a/examples/himmelblau.py +++ b/examples/himmelblau.py @@ -8,9 +8,8 @@ from nak_torch import nak from nak_torch.algorithms import MSIP, SVGD from nak_torch.algorithms.msip import MSIPFredholm, MSIPQuadGradientFree -from nak_torch.tools.quadrature import spherical_MC_radial_Laguerre +from nak_torch.tools.quadrature import gauss_MC, spherical_MC_radial_Laguerre from datetime import datetime -from nak_torch.tools.kernel import kernel_optimal_weight_factory, DEFAULT_KERNEL_MATRIX save_gif = False function_name = "himmelblau" @@ -25,8 +24,8 @@ init_particles = torch.randn((n_particles, 2)) + 8.0 params = { "n_steps": 100, - "bounds": (-15, 15), - "kernel_lengthscale": 0.15, + "bounds": (-10, 10), + "kernel_lengthscale": 0.2, "init_particles": init_particles, "n_particles": n_particles, "dim": 2, @@ -37,11 +36,11 @@ # %% msip = MSIP(**params) -svgd = SVGD(**params) +svgd = SVGD(kernel_lengthscale_quantile=0.5, **params) # %% target_msip_fr = MSIPFredholm( - gradient_decay=0.95, + gradient_decay=1.0, log_dens_grad_val=torch.vmap( torch.func.grad_and_value(log_density), in_dims=(0,None) @@ -87,11 +86,14 @@ # %% target_msip_gf = MSIPQuadGradientFree( log_density, - lambda b: spherical_MC_radial_Laguerre(b, N_spherical=5, d=2, N_radial=2) + # lambda b: spherical_MC_radial_Laguerre(b, N_spherical=5, d=2, N_radial=2) + lambda b: gauss_MC(b, 10, 2, torch.default_generator) ) params_gf = params.copy() -params_gf['lr'] = 0.8 +params_gf['lr'] = 0.75 n_particles = 25 +rng = torch.Generator() +rng.manual_seed(12321) trajectories_pts_gf,trajectories_wts_gf = nak(target_msip_gf, msip, **params_gf) # %% @@ -106,7 +108,7 @@ def kernel_elem(x: torch.Tensor, y: torch.Tensor, sigma: float): return torch.reciprocal(1 + (x - y).div(sigma).square().sum()) ksd_eval = nak_torch.metrics.KernelSteinDiscrepancy(batch_grad_log_dens, 0.25, kernel_elem=kernel_elem) -print("KSD", ksd_eval(pts_fr, wts_fr), ksd_eval(pts_svgd), ksd_eval(pts_gf, wts_gf)) +print("KSD", ksd_eval(pts_fr, wts_fr).item(), ksd_eval(pts_svgd).item(), ksd_eval(pts_gf, wts_gf).item()) # %% ress = nak_torch.metrics.RelativeESS(batch_log_dens) @@ -126,7 +128,7 @@ def kernel_elem(x: torch.Tensor, y: torch.Tensor, sigma: float): g_min, g_max = [m(x[i] for x in extrema_pts) for (m,i) in [(min,0), (max,1)]] extrema_wts = [(w.min(), w.max()) for w in wt_list if w is not None] vmin, vmax = [m(x[i] for x in extrema_wts) for (m,i) in [(min,0), (max,1)]] -titles = ["Initialization", "SVGD", "MSIP-1", "MSIP-GF"] +titles = ["Initialization", "SVGD", "MSIP-F", "MSIP-GF"] title_weights = [None, None, 'heavy', 'heavy'] for (ax, title, pt, wt, title_wt) in zip(axs, titles, pt_list, wt_list, title_weights): ax.set_axis_off() From f93bbad6812e3cb628980c103fc8dee61e3cac69 Mon Sep 17 00:00:00 2001 From: Daniel Sharp Date: Tue, 12 May 2026 20:11:51 +0100 Subject: [PATCH 60/60] final himmelblau thing --- examples/himmelblau.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/himmelblau.py b/examples/himmelblau.py index b48d99d..494b3d9 100644 --- a/examples/himmelblau.py +++ b/examples/himmelblau.py @@ -87,16 +87,16 @@ target_msip_gf = MSIPQuadGradientFree( log_density, # lambda b: spherical_MC_radial_Laguerre(b, N_spherical=5, d=2, N_radial=2) - lambda b: gauss_MC(b, 10, 2, torch.default_generator) + lambda b: gauss_MC(b, 5, 2, torch.default_generator) ) params_gf = params.copy() -params_gf['lr'] = 0.75 +params_gf['lr'] = 0.85 +params_gf['n_steps'] = 100 n_particles = 25 rng = torch.Generator() rng.manual_seed(12321) trajectories_pts_gf,trajectories_wts_gf = nak(target_msip_gf, msip, **params_gf) -# %% pts_gf = trajectories_pts_gf[-1] wts_gf = trajectories_wts_gf[-1] plt.contourf(X,Y,Z, levels=20, cmap="Grays")