diff --git a/.gitattributes b/.gitattributes
index 3e5cf230..23dcd25c 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -2,3 +2,4 @@
examples/**/*.npz filter=lfs diff=lfs merge=lfs -text
examples/**/*.npy filter=lfs diff=lfs merge=lfs -text
examples/**/*.dat filter=lfs diff=lfs merge=lfs -text
+
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 00000000..e8d1e862
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,10 @@
+# Global rules (global owner) apply to all files in the repository.
+* @lisa-gm @vincent-maillou
+
+# Specific rules for the statistics-related files.
+/src/dalia/likelihoods/ @lisa-gm
+/src/dalia/prior_hyperparameters/ @lisa-gm
+
+# Specific rules to the computation-related files.
+/src/dalia/solvers/ @vincent-maillou
+/src/dalia/kernels/ @vincent-maillou
diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md
new file mode 100644
index 00000000..834714d5
--- /dev/null
+++ b/.github/CONTRIBUTING.md
@@ -0,0 +1,134 @@
+# Contributing to DALIA
+
+Thank you for your interest in contributing to DALIA! This document provides guidelines and instructions for contributing to the project.
+
+## Before You Start
+
+When modifying the code, ensure that the following are still working:
+- All existing tests pass
+- Code follows the established style guidelines
+- Documentation is updated as needed
+- New features include appropriate tests
+
+## General Coding Guidelines
+
+We follow the NumPy/CuPy coding style guidelines, which are derived from the [PEP8](https://peps.python.org/pep-0008/) style guide.
+
+### Development Environment Setup
+
+1. **Install pre-commit hooks**: To ensure correct formatting, install and use `pre-commit`:
+ ```bash
+ pip install pre-commit
+ pre-commit install
+ ```
+
+2. **Code formatting**: We use automated tools to maintain code quality. The pre-commit hooks will automatically check your code before each commit.
+
+### Contribution workflow
+
+The DALIA repository uses a dual-branch workflow with `main` and `dev` branches:
+
+1. **Development**: New features are developed from the `dev` branch and merged via pull request back to it
+2. **Release**: When a release is ready, changes are merged via pull request from the `dev` branch to the `main` branch
+
+#### How to contribute:
+
+1. **Fork** the DALIA repository
+2. **Create a new branch** from the `dev` branch (not `main`)
+3. **Develop** your feature or fix on your branch
+4. **Create a pull request** to merge your changes into the `dev` branch of the DALIA repository
+
+> **Note**: Always create your feature branches from `dev`, not `main`, to ensure your changes can be properly integrated.
+
+### Guidelines for commit messages
+
+Use descriptive commit messages with one of the following prefixes:
+
+#### Core Development
+- `STATS` : new feature or change related to statistical modeling
+- `SLVR` : new feature or change related to solvers or numerical methods
+- `API`: an (incompatible) API change
+- `DEP`: deprecate something, or remove a deprecated object
+- `ENH`: enhancement
+- `BUG`: bug fix
+- `CI`: continuous integration
+
+#### Documentation and Testing
+- `DOC`: documentation
+- `TST`: addition or modification of tests
+- `EXPL`: changes related to examples or tutorials
+- `DEV`: development tool or utility
+
+#### Code Quality and Maintenance
+- `MAINT`: maintenance commit (refactoring, typos, etc.)
+- `REV`: revert an earlier commit
+- `STY`: style fix (whitespace, PEP8)
+- `TYP`: static typing
+
+#### Build and Release
+- `BLD`: change related to building DALIA or its dependencies
+- `REL`: related to releasing DALIA
+
+#### Work in Progress
+- `WIP`: work in progress, do not merge
+
+**Example**: `ENH: Add new spatial kernel implementation`
+
+
+
+
+## Testing Guidelines
+
+DALIA testing relies on the [pytest](https://pytest.org/) framework.
+
+Since DALIA is designed to be as performant as possible given the available hardware (HW) backends, the testing suite is designed to separate the testing of HW-agnostic code from the testing of HW-specific code.
+
+Tests are located in the `tests/` directory. We support three levels of tests:
+
+- **`unit/`**: These tests are designed to test HW-agnostic code and use pytest's `mock` feature to mock the backends
+- **`component_integration/`**: These tests are designed to test HW-specific code and leverage specific backend features for validation
+- **`integration/`**: These tests are designed to test the full DALIA stack and are full end-to-end tests that should be run on real hardware setups
+
+For more detailed information, see the `tests/README.md` file.
+
+## Documentation Guidelines
+
+All functions and classes should be documented using the [NumPy docstring format](https://numpydoc.readthedocs.io/en/latest/format.html).
+
+### Requirements:
+- **Public functions**: Must have complete docstrings with parameters, returns, and examples
+- **Classes**: Must document the class purpose, attributes, and key methods
+- **Modules**: Should have module-level docstrings explaining their purpose
+
+### Example:
+```python
+def example_function(param1: int, param2: str) -> bool:
+ """
+ Brief description of the function.
+
+ Parameters
+ ----------
+ param1 : int
+ Description of param1.
+ param2 : str
+ Description of param2.
+
+ Returns
+ -------
+ bool
+ Description of return value.
+ """
+ pass
+```
+
+
+## Examples and Tutorials Guidelines
+
+Examples and tutorials should be placed in the `examples/` directory. They should be self-contained and demonstrate clear use cases of DALIA workflows.
+
+### Requirements for each example:
+- **README file**: Must explain the use case, expected output, and how to run the example,
+- **Self-contained**: Should include all necessary data files or data generation scripts,
+- **Clear documentation**: Code should be well-commented and easy to follow,
+- **Tested**: Examples should be verified to work with the current DALIA version.
+
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 00000000..838ff9fe
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,81 @@
+name: "Bug Report"
+description: Create a report to help us improve
+labels: [bug]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thank you for taking the time to file a bug report.
+
+ - type: textarea
+ id: bug_description
+ attributes:
+ label: Describe the Bug
+ description: Please provide a short, but precise, description of the issue.
+ render: text
+ validations:
+ required: true
+
+ - type: textarea
+ id: reproduce
+ attributes:
+ label: Reproducing the Issue
+ description: |
+ Please provide a short self-contained example that reproduces the issue,
+ if not possible, please provide a gist script that reproduces the issue.
+ placeholder: |
+ import dalia
+ << your code here >>
+ render: python
+ validations:
+ required: true
+
+ - type: textarea
+ id: error_message
+ attributes:
+ label: Error Message or Unexpected Behavior
+ description: |
+ Please include the full error message, if there is one. Otherwise, describe what you expected as output as well as what you obtained.
+ validations:
+ required: true
+
+ - type: textarea
+ id: hardware_env
+ attributes:
+ label: Hardware and Environment Details
+ description: |
+ Please specify your hardware (CPU, GPU, RAM, etc.), environment
+ configuration, and relevant package versions.
+ placeholder: |
+ Example:
+ - CPU: Intel i7-12700K
+ - GPU: NVIDIA A100
+ - RAM: 32GB
+ - OS: Ubuntu 22.04
+ render: text
+ validations:
+ required: true
+
+ - type: input
+ id: python_version
+ attributes:
+ label: Python version
+ description: e.g. 3.11
+ validations:
+ required: true
+
+ - type: input
+ id: dalia_version
+ attributes:
+ label: Dalia version
+ description: e.g. 0.1.0
+ validations:
+ required: true
+
+ - type: textarea
+ id: additional_context
+ attributes:
+ label: Additional context
+ description: Please add any other relevant context about the problem here.
+ validations:
+ required: false
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 00000000..64694d98
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,35 @@
+name: "Feature Request"
+description: Suggest a new feature or change in functionality
+labels: [new feature]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ You would like to request a new feature or change in functionality? After
+ checking the open issues, please explain your idea below.
+
+ - type: textarea
+ id: proposed_feature
+ attributes:
+ label: Proposed feature or change
+ description: Please provide a clear description of your suggestion.
+ validations:
+ required: true
+
+ - type: textarea
+ id: how_to_implement
+ attributes:
+ label: How to implement
+ description: |
+ Please provide a clear description of how you would implement
+ this feature or change.
+ validations:
+ required: true
+
+ - type: textarea
+ id: additional_context
+ attributes:
+ label: Additional context
+ description: You can add any other relevant context here.
+ validations:
+ required: false
diff --git a/.github/ISSUE_TEMPLATE/performance_request.yml b/.github/ISSUE_TEMPLATE/performance_request.yml
new file mode 100644
index 00000000..19e44a49
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/performance_request.yml
@@ -0,0 +1,123 @@
+name: "Performance Request"
+description: Report a performance issue or suggest a performance optimization
+labels: [performance]
+body:
+ - type: markdown
+ attributes:
+ value: |
+ Thank you for helping us improve DALIA performances!
+
+ - type: textarea
+ id: performance_issue
+ attributes:
+ label: Describe the performance issue
+ description: |
+ What part of the code or workflow seems slow or could be improved
+ further? Please describe the context and what you observed.
+ placeholder: Describe the operation, function, or workflow...
+ render: text
+ validations:
+ required: true
+
+ - type: textarea
+ id: reproduce
+ attributes:
+ label: Steps to reproduce
+ description: |
+ Please provide a minimal example or steps to reproduce the behavior,
+ if not possible, please provide a gist script that reproduces the issue.
+ placeholder: List the steps or provide a code snippet...
+ render: python
+ validations:
+ required: true
+
+ - type: textarea
+ id: how_to_fix
+ attributes:
+ label: How to fix or optimize
+ description: |
+ If you have suggestions on how to solve this performance issue or
+ optimize the code, please describe them here.
+ placeholder: Describe your suggestions here...
+ render: text
+ validations:
+ required: true
+
+ - type: textarea
+ id: profiling
+ attributes:
+ label: Profiling or timing information (optional)
+ description: |
+ If available, please include any profiling output or timing measurements.
+ placeholder: Paste profiling output or timing results here...
+ render: text
+ validations:
+ required: false
+
+ - type: textarea
+ id: hardware_env
+ attributes:
+ label: Hardware and Environment Details
+ description: |
+ Please specify your hardware (CPU, GPU, RAM, etc.), environment
+ configuration, and relevant package versions.
+ placeholder: |
+ Example:
+ - CPU: Intel i7-12700K
+ - GPU: NVIDIA A100
+ - RAM: 32GB
+ - OS: Ubuntu 22.04
+ render: text
+ validations:
+ required: true
+
+ - type: input
+ id: python_version
+ attributes:
+ label: Python version
+ description: e.g. 3.11
+ validations:
+ required: true
+
+ - type: input
+ id: dalia_version
+ attributes:
+ label: Dalia version
+ description: e.g. 0.1.0
+ validations:
+ required: true
+
+ - type: textarea
+ id: other_packages
+ attributes:
+ label: Other packages
+ description: |
+ List any other relevant packages and their versions that might affect
+ performance.
+ placeholder: e.g. numpy 1.21.0, scipy 1.3.0
+ render: text
+ validations:
+ required: false
+
+ - type: input
+ id: run_command
+ attributes:
+ label: How did you run the code? (command or bash script)
+ description: |
+ Please specify the exact command, script, or notebook cell you used to
+ execute the code.
+ placeholder: |
+ e.g. python run.py, mpiexec -n 4 python run.py, or provide the bash
+ script you used.
+ validations:
+ required: true
+
+ - type: textarea
+ id: additional_context
+ attributes:
+ label: Additional context
+ description: |
+ Add any other relevant context or suggestions for optimization.
+ render: text
+ validations:
+ required: false
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 00000000..faa12557
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,117 @@
+# DALIA Pull Request
+
+
+
+
+## Submission Checklist
+- [ ] I have read the [CONTRIBUTING.md](../CONTRIBUTING.md) file
+- [ ] My code follows the project's coding style (NumPy/CuPy guidelines)
+- [ ] I have run `pre-commit` hooks and fixed any issues
+- [ ] My branch is created from the `dev` branch (not `main`)
+- [ ] My commit messages follow the project's format (e.g., `ENH: description`)
+- [ ] I have added tests for my changes (if applicable)
+- [ ] All existing tests pass
+- [ ] I have updated documentation (if applicable)
+- [ ] I have added examples or tutorials (if adding new features)
+
+## Type of Change
+Please select the type of change this PR introduces:
+
+- [ ] Bug fix (non-breaking change that fixes an issue)
+- [ ] New feature (non-breaking change that adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to change)
+- [ ] Documentation update
+- [ ] Test improvement or addition
+- [ ] Code refactoring (no functional changes)
+- [ ] Performance improvement
+- [ ] Style/formatting changes
+- [ ] Build system or dependency changes
+
+## Summary of Changes
+
+
+## Description
+
+
+## Motivation and Context
+
+
+
+## How Has This Been Tested?
+
+
+
+
+
+### Tests Performed (and passed)
+- [ ] Unit tests (HW-agnostic code with mocked backends)
+- [ ] Component integration tests (HW-specific code)
+- [ ] End-to-end integration tests (full DALIA stack)
+- [ ] Performance/benchmark tests (if applicable)
+- [ ] Examples/tutorials verification (if applicable)
+
+## Performance Impact
+
+- [ ] No performance impact expected
+- [ ] Performance improvement (please quantify)
+- [ ] Potential performance regression (please justify)
+- [ ] Performance impact unknown/untested
+
+**Performance details** (if applicable):
+
+
+## API and Backwards Compatibility
+
+
+- [ ] This change maintains backwards compatibility
+- [ ] This change introduces breaking changes (requires major version bump)
+- [ ] This change adds new public APIs
+- [ ] This change modifies existing public APIs
+
+**Breaking changes** (if any):
+
+
+**New APIs** (if any):
+
+
+## Documentation Updates
+- [ ] Docstrings updated/added for new or modified functions
+- [ ] User guide documentation updated (if applicable)
+- [ ] API reference documentation updated (if applicable)
+- [ ] Examples updated to reflect changes (if applicable)
+- [ ] Changelog entry added
+
+## Copyright and Licensing
+By making a contribution to this project, I certify that:
+
+(a) The contribution was created in whole or in part by me and I
+ have the right to submit it under the BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
+
+(b) The contribution is based upon previous work that, to the best
+ of my knowledge, is covered under an appropriate open source
+ license and I have the right under that license to submit that
+ work with modifications, whether created in whole or in part
+ by me, under the same open source license (unless I am
+ permitted to submit under a different license), as indicated
+ in the file; or
+
+(c) The contribution was provided directly to me by some other
+ person who certified (a), (b) or (c) and I have not modified
+ it.
+
+(d) I understand and agree that this project and the contribution
+ are public and that a record of the contribution (including all
+ personal information I submit with it, including my sign-off) is
+ maintained indefinitely and may be redistributed consistent with
+ this project or the open source license(s) involved.
+
+---
+
+## Additional Resources
+- 📖 [Contributing Guidelines](./CONTRIBUTING.md)
+- 🧪 [Testing Documentation](../tests/README.md)
+- 📋 [Code Owners](./CODEOWNERS)
+- 🏷️ [Commit Message Guidelines](./CONTRIBUTING.md#guidelines-for-commit-messages)
+
+**Thank you for contributing to DALIA!**
+
diff --git a/.gitignore b/.gitignore
index 35e613a9..d19426db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -171,6 +171,16 @@ runs_sc25/*
settings.json
*.out
+# data files that can easily be generated
+examples/brainiac/**/*.npy
+examples/brainiac/**/*.npz
+
+examples/g_ar1/**/*.npy
+examples/g_ar1/**/*.npz
+
+examples/p_ar1/**/*.npy
+examples/p_ar1/**/*.npz
+
# Profiler outputs
*.nsys-rep
*.qdstrm
diff --git a/CITATION.cff b/CITATION.cff
new file mode 100644
index 00000000..d0ef6b98
--- /dev/null
+++ b/CITATION.cff
@@ -0,0 +1,88 @@
+cff-version: 1.2.0
+message: "If you use this software, please cite it using the metadata from this file."
+title: "Accelerated Spatio-Temporal Bayesian Modeling for Multivariate Gaussian Processes"
+authors:
+ - family-names: "Gaedke-Merzhäuser"
+ given-names: "Lisa"
+ - family-names: "Maillou"
+ given-names: "Vincent"
+ - family-names: "Rodriguez Avellaneda"
+ given-names: "Fernando"
+ - family-names: "Schenk"
+ given-names: "Olaf"
+ - family-names: "Moraga"
+ given-names: "Paula"
+ - family-names: "Luisier"
+ given-names: "Mathieu"
+ - family-names: "Ziogas"
+ given-names: "Alexandros Nikolaos"
+ - family-names: "Rue"
+ given-names: "Håvard"
+year: 2025
+doi: "10.1145/3712285.3759832"
+url: "https://doi.org/10.1145/3712285.3759832"
+publisher: "Association for Computing Machinery"
+keywords:
+ - "Large-Scale Bayesian Inference"
+ - "Spatio-Temporal Modeling"
+ - "Distributed Memory Computing"
+references:
+ - type: conference-paper
+ title: "Accelerated Spatio-Temporal Bayesian Modeling for Multivariate Gaussian Processes"
+ authors:
+ - family-names: "Gaedke-Merzhäuser"
+ given-names: "Lisa"
+ - family-names: "Maillou"
+ given-names: "Vincent"
+ - family-names: "Rodriguez Avellaneda"
+ given-names: "Fernando"
+ - family-names: "Schenk"
+ given-names: "Olaf"
+ - family-names: "Moraga"
+ given-names: "Paula"
+ - family-names: "Luisier"
+ given-names: "Mathieu"
+ - family-names: "Ziogas"
+ given-names: "Alexandros Nikolaos"
+ - family-names: "Rue"
+ given-names: "Håvard"
+ year: 2025
+ doi: "10.1145/3712285.3759832"
+ url: "https://doi.org/10.1145/3712285.3759832"
+ publisher: "Association for Computing Machinery"
+ collection-title: "Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis"
+ start: 949
+ end: 972
+ isbn: "9798400714665"
+ - type: conference-paper
+ title: "Parallel Selected Inversion of Block-Tridiagonal with Arrowhead Matrices"
+ authors:
+ - family-names: "Maillou"
+ given-names: "Vincent"
+ - family-names: "Gaedke-Merzhauser"
+ given-names: "Lisa"
+ - family-names: "Ziogas"
+ given-names: "Alexandros Nikolaos"
+ - family-names: "Schenk"
+ given-names: "Olaf"
+ - family-names: "Luisier"
+ given-names: "Mathieu"
+ year: 2025
+ doi: "10.1109/CLUSTER59342.2025.11186484"
+ url: "https://doi.ieeecomputersociety.org/10.1109/CLUSTER59342.2025.11186484"
+ publisher: "IEEE Computer Society"
+ collection-title: "2025 IEEE International Conference on Cluster Computing (CLUSTER)"
+ start: 1
+ end: 12
+ keywords:
+ - "Materials science and technology"
+ - "Temperature distribution"
+ - "Computational modeling"
+ - "Graphics processing units"
+ - "Linear algebra"
+ - "Predictive models"
+ - "Libraries"
+ - "Supercomputers"
+ - "Sparse matrices"
+ - "Parallel algorithms"
+
diff --git a/DEV_README.md b/DEV_README.md
deleted file mode 100644
index 71e7ade8..00000000
--- a/DEV_README.md
+++ /dev/null
@@ -1,154 +0,0 @@
-## CSCS @ daint.alps cluster
-Here are some installation guidelines to install the project on the Piz Daint cluster of the ALPS supercomputer.
-1. Pull and start the necessary `uenv`:
-```bash
-uenv image find
-uenv repo create
-uenv image pull prgenv-gnu/24.11:v1
-uenv start --view=modules prgenv-gnu/24.11:v1
-```
-2. Load the necessary modules:
-```bash
-module load cuda
-module load gcc
-module load meson
-module load ninja
-module load nccl
-module load cray-mpich
-module load cmake
-module load openblas
-module load aws-ofi-nccl
-```
-3. Export library PATH:
-```bash
-export NCCL_ROOT=/user-environment/linux-sles15-neoverse_v2/gcc-13.3.0/nccl-2.22.3-1-4j6h3ffzysukqpqbvriorrzk2lm762dd
-export NCCL_LIB_DIR=$NCCL_ROOT/lib
-export NCCL_INCLUDE_DIR=$NCCL_ROOT/include
-export CUDA_DIR=$CUDA_HOME
-export CUDA_PATH=$CUDA_HOME
-export CPATH=$CUDA_HOME/include:$CPATH
-export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH
-export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
-export CPATH=$NCCL_ROOT/include:$CPATH
-export LIBRARY_PATH=$NCCL_ROOT/lib:$LIBRARY_PATH
-export LD_LIBRARY_PATH=$NCCL_ROOT/lib:$LD_LIBRARY_PATH
-```
-4. Install miniconda:
-```bash
-wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh
-chmod u+x Miniconda3-latest-Linux-aarch64.sh
-./Miniconda3-latest-Linux-aarch64.sh
-```
-5. Create the conda environment and install the required libraries:
-```bash
-conda create -n myenv
-conda activate myenv
-conda install python=3.12
-conda install numpy scipy
-MPICC=$(which mpicc) python -m pip install --no-cache-dir mpi4py
-pip install cupy --no-dependencies --no-cache-dir
-conda install -c conda-forge pytest pytest-mpi pytest-cov coverage black isort ruff just pre-commit matplotlib tabulate numba -y
-# Test the NCCL/CuPy installation
-python -c "from cupy.cuda.nccl import *"
-```
-6. (Optional) Install serinv and run the tests:
-```bash
-git clone https://github.com/vincent-maillou/serinv # https://github.com/vincent-maillou/serinv/tree/dev
-cd /path/to/serinv/
-python -m pip install -e .
-# Run the sequential tests.
-pytest .
-```
-7. Install dalia
-```bash
-cd /path/to/dalia/
-python -m pip install -e .
-```
-
-### Known Installation Issues
-The `sqlite` module might not work properly. Forcing the following version of `sqlite` might help:
-```bash
-conda install conda-forge::sqlite=3.45.3
-```
-
-## Erlangen @ Fau cluster
-Here are some installation guidelines to install the project on the Fau; Alex and Fritz clusters.
-We recommend to test any development in 3 separated environments:
-- Bare: The environment without any MPI or GPU support
-- Fritz: The environment with MPI support, CPU backend
-- Alex: The environment with MPI support, GPU backend (optional: NCCL)
-
-This ensure compatibility no matter the available backend.
-
-```bash
-# --- Alex-env ---
-module load python
-module load openmpi/4.1.6-nvhpc23.7-cuda
-module load cuda/12.6.1
-
-conda create -n alex
-conda activate alex
-
-CFLAGS=-noswitcherror MPICC=$(which mpicc) pip install --no-cache-dir mpi4py
-
-salloc --partition=a40 --nodes=1 --gres=gpu:a40:1 --time 01:00:00
-conda activate alex
-
-conda install -c conda-forge cupy-core
-conda install blas=*=*mkl
-conda install libblas=*=*mkl
-conda install numpy scipy
-conda install -c conda-forge pytest pytest-mpi pytest-cov coverage black isort ruff just pre-commit matplotlib tabulate numba -y
-
-cd /path/to/serinv/
-python -m pip install -e .
-
-cd /path/to/dalia/
-python -m pip install -e .
-```
-
-```bash
-# --- Fritz-env ---
-module load python
-module load openmpi/4.1.2-gcc11.2.0
-
-conda create -n fritz
-conda activate fritz
-
-MPICC=$(which mpicc) pip install --no-cache-dir mpi4py
-
-salloc -N 4 --time 01:00:00
-conda activate fritz
-
-conda install blas=*=*mkl
-conda install libblas=*=*mkl
-conda install numpy scipy
-conda install -c conda-forge pytest pytest-mpi pytest-cov coverage black isort ruff just pre-commit matplotlib tabulate numba -y
-
-cd /path/to/serinv/
-python -m pip install -e .
-
-cd /path/to/dalia/
-python -m pip install -e .
-```
-
-```bash
-# --- Bare-env ---
-module load python
-conda create -n bare
-
-salloc -N 4 --time 01:00:00
-conda activate bare
-
-conda install blas=*=*mkl
-conda install libblas=*=*mkl
-conda install numpy scipy
-conda install -c conda-forge pytest pytest-mpi pytest-cov coverage black isort ruff just pre-commit matplotlib tabulate numba -y
-
-cd /path/to/serinv/
-python -m pip install -e .
-
-cd /path/to/dalia/
-python -m pip install -e .
-```
-
diff --git a/README.md b/README.md
index 817aa4e0..d1684329 100644
--- a/README.md
+++ b/README.md
@@ -1,41 +1,24 @@
-[](https://github.com/psf/black)
+
# DALIA
-Python implementation of the methodology of integrated nested Laplace approximations (INLA), putting the accent on portability, modularity and performance (formerly known as PyINLA).
+[](https://github.com/psf/black)
+
+---
-If you want to help us in the developement of DALIA, you can fill the following `missing features` survey: https://forms.gle/o4CxBDcr1t73pBHbA
+Python implementation of the methodology of Integrated Nested Laplace Approximations (INLA), putting the accent on portability, modularity and performance.
If you want to get involved in the development of DALIA, please feel free to contact us directly.
## Installation
-DALIA is a python package that can be installed from its source code. You will need a working `conda` installation as well as the `Serinv` (https://github.com/vincent-maillou/serinv) solver library for accelerated solution of spatio-temporal models.
-
-You can get a working installation of `conda` on the Miniconda website: https://repo.anaconda.com/miniconda/
+Detailed installation instructions are provided in [install.md](./install.md).
-This package relies on several libraries, some of which enabling high performance computing (HPC) features and GPU acceleration. These libraries (CuPy, MPI4Py, etc.) are not required for the basic functionality of the package, but are required for the advanced features.
-
-Default required packages are:
-```bash
-conda install numpy scipy
-conda install -c conda-forge pytest pytest-mpi pytest-cov coverage black isort ruff just pre-commit matplotlib tabulate numba -y
-```
-You can then optionally install the Serinv solver (required for spatio-temporal models)
-```bash
-cd /path/to/serinv/
-python -m pip install -e .
-```
-
-And finally install the DALIA package:
-```bash
-cd /path/to/dalia/
-python -m pip install -e .
-```
+## Testing
-We then recommend you to run some of the examples provided in the `examples/` directory to test your installation.
-For more detailed installation instructions, especially on clusters, leveraging GPU acceleration through `CuPy` and distributed computing through `MPI4Py`, please refer to the [dev note](DEV_README.md) in the `DEV_README.md` file.
+The testing suite is described in detail in [tests/README.md](./tests/README.md). It relies on `pytest` and can be run either directly or through the provided `runner.sh` script.
## Examples
+
Some examples are provided with running scripts. The examples are being tracked using `git-lfs`, to download them, run the following commands:
```bash
git lfs pull
@@ -47,8 +30,54 @@ You can then navigate in the `examples/` directory and run the given examples. F
python gst_small/run.py
```
-## Known Installation Issues
-The `sqlite` module might not work properly. Forcing the following version of `sqlite` might help:
-```bash
-conda install conda-forge::sqlite=3.45.3
+Additionaly, *slurms* scripts to run the examples on different HPC clusters are provided.
+
+## Benchmarks
+
+... work in progress
+
+
+# Citing DALIA
+
+The main DALIA paper describing its high performance computing strategies is available through the following reference:
+
+``` bibtex
+@inproceedings{10.1145/3712285.3759832,
+ author = {Gaedke-Merzh\"{a}user, Lisa and Maillou, Vincent and Rodriguez Avellaneda, Fernando and Schenk, Olaf and Moraga, Paula and Luisier, Mathieu and Ziogas, Alexandros Nikolaos and Rue, H\r{a}vard},
+ title = {Accelerated Spatio-Temporal Bayesian Modeling for Multivariate Gaussian Processes},
+ year = {2025},
+ isbn = {9798400714665},
+ publisher = {Association for Computing Machinery},
+ address = {New York, NY, USA},
+ url = {https://doi.org/10.1145/3712285.3759832},
+ doi = {10.1145/3712285.3759832},
+ abstract = {Multivariate Gaussian processes (GPs) offer a powerful probabilistic framework to represent complex interdependent phenomena. They pose, however, significant computational challenges in high-dimensional settings, which frequently arise in spatio-temporal applications. We present DALIA, a highly scalable framework for performing Bayesian inference tasks on spatio-temporal multivariate GPs, based on the methodology of integrated nested Laplace approximations. Our approach relies on a sparse inverse covariance matrix formulation of the GP, puts forward a GPU-accelerated block-dense approach, and introduces a hierarchical, triple-layer, distributed-memory parallel scheme. We showcase weak-scaling performance surpassing the state of the art by two orders of magnitude on a model whose parameter space is 8 \texttimes{} larger and measure strong-scaling speedups of three orders of magnitude when running on 496 GH200 superchips on the Alps supercomputer. Applying DALIA to an air pollution study over northern Italy spanning 48 days, we showcase refined spatial resolutions over the aggregated pollutant measurements.},
+ booktitle = {Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis},
+ pages = {949–972},
+ numpages = {24},
+ keywords = {Large-Scale Bayesian Inference, Spatio-Temporal Modeling, Distributed Memory Computing},
+ location = {},
+ series = {SC '25}
+}
+```
+
+If you are using the *Serinv* solver for Spatio-Temporal modeling, please also cite the following reference:
+
+``` bibtex
+@inproceedings{11186484,
+ author = { Maillou, Vincent and Gaedke-Merzhauser, Lisa and Ziogas, Alexandros Nikolaos and Schenk, Olaf and Luisier, Mathieu },
+ booktitle = { 2025 IEEE International Conference on Cluster Computing (CLUSTER) },
+ title = {{ Parallel Selected Inversion of Block-Tridiagonal with Arrowhead Matrices }},
+ year = {2025},
+ volume = {},
+ ISSN = {},
+ pages = {1-12},
+ abstract = { The inversion of structured sparse matrices is a fundamental yet computationally and memory-intensive task in many scientific applications, such as Bayesian statistical modeling and material science. In certain cases, only particular entries of the full inverse are required. This has motivated the development of so-called selected inversion algorithms (SIA), capable of computing only specific elements of the full inverse. Currently, most SIA implementations are restricted to shared-/distributed-memory CPU architectures or to single GPUs. Here, we introduce novel numerical methods to perform the parallel selected inversion and Cholesky decomposition of positive-definite, block-tridiagonal with arrowhead matrices. A distributed memory, GPU-accelerated implementation of our approach is presented and integrated into the structured solver library Serinv. We demonstrate its performance on synthetic and real datasets from statistical air temperature prediction models and achieve CPU (GPU) speedups of up to $2.6 \times(71.4 \times)$ over the SIA of the PARDISO library and up to $14 \times(380.9 \times)$ over the MUMPS library, when scaling to 16 processes. },
+ keywords = {Materials science and technology;Temperature distribution;Computational modeling;Graphics processing units;Linear algebra;Predictive models;Libraries;Supercomputers;Sparse matrices;Parallel algorithms},
+ doi = {10.1109/CLUSTER59342.2025.11186484},
+ url = {https://doi.ieeecomputersociety.org/10.1109/CLUSTER59342.2025.11186484},
+ publisher = {IEEE Computer Society},
+ address = {Los Alamitos, CA, USA},
+ month =sep
+}
```
\ No newline at end of file
diff --git a/envs/dalia_base_aarch64.yml b/envs/dalia_base_aarch64.yml
new file mode 100644
index 00000000..3651af14
--- /dev/null
+++ b/envs/dalia_base_aarch64.yml
@@ -0,0 +1,203 @@
+name: dalia_base
+channels:
+ - conda-forge
+dependencies:
+ - _libgcc_mutex=0.1
+ - _openmp_mutex=5.1
+ - alsa-lib=1.2.14
+ - annotated-types=0.7.0
+ - aom=3.6.0
+ - black=25.1.0
+ - blas=2.138
+ - blas-devel=3.9.0
+ - brotli=1.0.9
+ - brotli-bin=1.0.9
+ - bzip2=1.0.8
+ - ca-certificates=2025.10.5
+ - cairo=1.18.4
+ - cffi=2.0.0
+ - cfgv=3.3.1
+ - click=8.3.0
+ - colorama=0.4.6
+ - contourpy=1.3.1
+ - coverage=7.11.0
+ - cycler=0.12.1
+ - cyrus-sasl=2.1.28
+ - dav1d=1.2.1
+ - dbus=1.16.2
+ - distlib=0.4.0
+ - exceptiongroup=1.3.0
+ - execnet=2.1.1
+ - expat=2.7.1
+ - filelock=3.20.0
+ - fontconfig=2.15.0
+ - fonttools=4.60.1
+ - freetype=2.13.3
+ - fribidi=1.0.10
+ - graphite2=1.3.14
+ - harfbuzz=10.2.0
+ - icu=73.1
+ - identify=2.6.15
+ - importlib-metadata=8.7.0
+ - iniconfig=2.3.0
+ - isort=7.0.0
+ - jpeg=9e
+ - just=1.39.0
+ - kiwisolver=1.4.8
+ - lcms2=2.16
+ - ld_impl_linux-aarch64=2.44
+ - lerc=4.0.0
+ - libabseil=20250127.0
+ - libavif=1.1.1
+ - libblas=3.9.0
+ - libbrotlicommon=1.0.9
+ - libbrotlidec=1.0.9
+ - libbrotlienc=1.0.9
+ - libcblas=3.9.0
+ - libclang13=20.1.8
+ - libcups=2.4.2
+ - libdeflate=1.22
+ - libdrm=2.4.124
+ - libegl=1.7.0
+ - libevent=2.1.12
+ - libffi=3.4.4
+ - libgcc-ng=11.2.0
+ - libgfortran-ng=13.2.0
+ - libgfortran5=13.2.0
+ - libgl=1.7.0
+ - libglib=2.84.4
+ - libglvnd=1.7.0
+ - libglx=1.7.0
+ - libgomp=11.2.0
+ - libiconv=1.17
+ - libkrb5=1.21.3
+ - liblapack=3.9.0
+ - liblapacke=3.9.0
+ - libllvm15=15.0.7
+ - libllvm20=20.1.8
+ - libmpdec=4.0.0
+ - libopenblas=0.3.30
+ - libopenjpeg=2.5.4
+ - libopus=1.3.1
+ - libpciaccess=0.18
+ - libpng=1.6.50
+ - libpq=17.6
+ - libre2-11=2024.07.02
+ - libsodium=1.0.20
+ - libstdcxx-ng=11.2.0
+ - libtiff=4.7.0
+ - libuuid=1.41.5
+ - libvpx=1.13.1
+ - libwebp-base=1.3.2
+ - libxcb=1.17.0
+ - libxkbcommon=1.9.1
+ - libxkbfile=1.1.0
+ - libxml2=2.13.9
+ - libxslt=1.1.43
+ - libzlib=1.3.1
+ - llvmlite=0.45.1
+ - lmdb=0.9.29
+ - lz4-c=1.9.4
+ - matplotlib=3.10.6
+ - matplotlib-base=3.10.6
+ - mesalib=25.1.5
+ - munkres=1.1.4
+ - mypy_extensions=1.1.0
+ - mysql-common=9.3.0
+ - mysql-libs=9.3.0
+ - ncurses=6.5
+ - nodeenv=1.9.1
+ - nomkl=3.0
+ - nspr=4.37
+ - nss=3.117
+ - numba=0.62.1
+ - numpy=2.3.4
+ - numpy-base=2.3.4
+ - openblas=0.3.30
+ - openblas-devel=0.3.30
+ - openjpeg=2.5.4
+ - openldap=2.6.10
+ - openssl=3.0.18
+ - packaging=25.0
+ - pathspec=0.12.1
+ - pcre2=10.46
+ - pillow=11.3.0
+ - pip=25.2
+ - pixman=0.46.4
+ - platformdirs=4.5.0
+ - pluggy=1.6.0
+ - pre-commit=4.3.0
+ - psutil=7.0.0
+ - pthread-stubs=0.3
+ - pycparser=2.22
+ - pydantic=2.12.3
+ - pydantic-core=2.41.4
+ - pygments=2.19.2
+ - pyparsing=3.2.5
+ - pyside6=6.9.2
+ - pytest=8.4.2
+ - pytest-cov=7.0.0
+ - pytest-mpi=0.6
+ - pytest-xdist=3.8.0
+ - python=3.13.9
+ - python-dateutil=2.9.0.post0
+ - python_abi=3.13
+ - pyyaml=6.0.3
+ - qtbase=6.9.2
+ - qtdeclarative=6.9.2
+ - qtshadertools=6.9.2
+ - qtsvg=6.9.2
+ - qttools=6.9.2
+ - qtwebchannel=6.9.2
+ - qtwebengine=6.9.2
+ - qtwebsockets=6.9.2
+ - re2=2024.07.02
+ - readline=8.3
+ - ruff=0.12.0
+ - scipy=1.16.2
+ - setuptools=80.9.0
+ - six=1.17.0
+ - snappy=1.2.1
+ - spirv-tools=2025.1
+ - sqlite=3.50.2
+ - tabulate=0.9.0
+ - tbb=2022.0.0
+ - tk=8.6.15
+ - tomli=2.3.0
+ - tornado=6.5.1
+ - typing-extensions=4.15.0
+ - typing-inspection=0.4.2
+ - typing_extensions=4.15.0
+ - tzdata=2025b
+ - ukkonen=1.0.1
+ - unicodedata2=16.0.0
+ - virtualenv=20.35.4
+ - wheel=0.45.1
+ - xcb-util=0.4.1
+ - xcb-util-cursor=0.1.5
+ - xcb-util-image=0.4.0
+ - xcb-util-keysyms=0.4.1
+ - xcb-util-renderutil=0.3.10
+ - xcb-util-wm=0.4.2
+ - xkeyboard-config=2.44
+ - xorg-libice=1.1.2
+ - xorg-libsm=1.2.6
+ - xorg-libx11=1.8.12
+ - xorg-libxau=1.0.12
+ - xorg-libxcomposite=0.4.6
+ - xorg-libxdamage=1.1.6
+ - xorg-libxdmcp=1.1.5
+ - xorg-libxext=1.3.6
+ - xorg-libxfixes=6.0.1
+ - xorg-libxi=1.8.2
+ - xorg-libxrandr=1.5.4
+ - xorg-libxrender=0.9.12
+ - xorg-libxshmfence=1.3.3
+ - xorg-libxtst=1.2.5
+ - xorg-libxxf86vm=1.1.6
+ - xorg-xorgproto=2024.1
+ - xz=5.6.4
+ - yaml=0.2.5
+ - zipp=3.23.0
+ - zlib=1.3.1
+ - zstd=1.5.7
diff --git a/envs/dalia_base_x86.yml b/envs/dalia_base_x86.yml
new file mode 100644
index 00000000..8d21d319
--- /dev/null
+++ b/envs/dalia_base_x86.yml
@@ -0,0 +1,182 @@
+name: dalia_base
+channels:
+ - conda-forge
+dependencies:
+ - _libgcc_mutex=0.1
+ - _openmp_mutex=4.5
+ - alsa-lib=1.2.14
+ - annotated-types=0.7.0
+ - black=25.1.0
+ - brotli=1.1.0
+ - brotli-bin=1.1.0
+ - bzip2=1.0.8
+ - ca-certificates=2025.10.5
+ - cairo=1.18.4
+ - cffi=2.0.0
+ - cfgv=3.3.1
+ - click=8.3.0
+ - colorama=0.4.6
+ - contourpy=1.3.3
+ - coverage=7.10.7
+ - cycler=0.12.1
+ - cyrus-sasl=2.1.28
+ - dbus=1.16.2
+ - distlib=0.4.0
+ - double-conversion=3.3.1
+ - exceptiongroup=1.3.0
+ - execnet=2.1.1
+ - filelock=3.20.0
+ - font-ttf-dejavu-sans-mono=2.37
+ - font-ttf-inconsolata=3.000
+ - font-ttf-source-code-pro=2.038
+ - font-ttf-ubuntu=0.83
+ - fontconfig=2.15.0
+ - fonts-conda-ecosystem=1
+ - fonts-conda-forge=1
+ - fonttools=4.60.1
+ - freetype=2.14.1
+ - graphite2=1.3.14
+ - harfbuzz=12.1.0
+ - icu=75.1
+ - identify=2.6.15
+ - importlib-metadata=8.7.0
+ - iniconfig=2.0.0
+ - isort=7.0.0
+ - just=1.43.0
+ - keyutils=1.6.3
+ - kiwisolver=1.4.9
+ - krb5=1.21.3
+ - lcms2=2.17
+ - ld_impl_linux-64=2.44
+ - lerc=4.0.0
+ - libblas=3.9.0
+ - libbrotlicommon=1.1.0
+ - libbrotlidec=1.1.0
+ - libbrotlienc=1.1.0
+ - libcblas=3.9.0
+ - libclang-cpp21.1=21.1.3
+ - libclang13=21.1.3
+ - libcups=2.3.3
+ - libdeflate=1.24
+ - libdrm=2.4.125
+ - libedit=3.1.20250104
+ - libegl=1.7.0
+ - libexpat=2.7.1
+ - libffi=3.4.6
+ - libfreetype=2.14.1
+ - libfreetype6=2.14.1
+ - libgcc=15.2.0
+ - libgcc-ng=15.2.0
+ - libgfortran=15.2.0
+ - libgfortran5=15.2.0
+ - libgl=1.7.0
+ - libglib=2.86.0
+ - libglvnd=1.7.0
+ - libglx=1.7.0
+ - libgomp=15.2.0
+ - libiconv=1.18
+ - libjpeg-turbo=3.1.0
+ - liblapack=3.9.0
+ - libllvm21=21.1.3
+ - liblzma=5.8.1
+ - libmpdec=4.0.0
+ - libntlm=1.8
+ - libopenblas=0.3.30
+ - libopengl=1.7.0
+ - libpciaccess=0.18
+ - libpng=1.6.50
+ - libpq=18.0
+ - libsqlite=3.50.4
+ - libstdcxx=15.2.0
+ - libstdcxx-ng=15.2.0
+ - libtiff=4.7.1
+ - libuuid=2.41.2
+ - libvulkan-loader=1.4.328.1
+ - libwebp-base=1.6.0
+ - libxcb=1.17.0
+ - libxcrypt=4.4.36
+ - libxkbcommon=1.12.0
+ - libxml2=2.15.0
+ - libxml2-16=2.15.0
+ - libxslt=1.1.43
+ - libzlib=1.3.1
+ - llvmlite=0.45.1
+ - matplotlib=3.10.6
+ - matplotlib-base=3.10.6
+ - munkres=1.1.4
+ - mypy_extensions=1.1.0
+ - ncurses=6.5
+ - nodeenv=1.9.1
+ - numba=0.62.1
+ - numpy=2.3.3
+ - openjpeg=2.5.4
+ - openldap=2.6.10
+ - openssl=3.5.4
+ - packaging=25.0
+ - pathspec=0.12.1
+ - pcre2=10.46
+ - pillow=11.3.0
+ - pip=25.2
+ - pixman=0.46.4
+ - platformdirs=4.5.0
+ - pluggy=1.6.0
+ - pre-commit=4.3.0
+ - psutil=7.1.0
+ - pthread-stubs=0.4
+ - pycparser=2.22
+ - pydantic=2.12.0
+ - pydantic-core=2.41.1
+ - pygments=2.19.2
+ - pyparsing=3.2.5
+ - pyside6=6.9.3
+ - pytest=8.4.2
+ - pytest-cov=7.0.0
+ - pytest-mpi=0.6
+ - pytest-xdist=3.8.0
+ - python=3.13.7
+ - python-dateutil=2.9.0.post0
+ - python_abi=3.13
+ - pyyaml=6.0.3
+ - qhull=2020.2
+ - qt6-main=6.9.3
+ - readline=8.2
+ - ruff=0.14.0
+ - scipy=1.16.2
+ - setuptools=80.9.0
+ - six=1.17.0
+ - tabulate=0.9.0
+ - tk=8.6.13
+ - tomli=2.3.0
+ - tornado=6.5.2
+ - typing-extensions=4.15.0
+ - typing-inspection=0.4.2
+ - typing_extensions=4.15.0
+ - tzdata=2025b
+ - ukkonen=1.0.1
+ - virtualenv=20.35.3
+ - wayland=1.24.0
+ - xcb-util=0.4.1
+ - xcb-util-cursor=0.1.5
+ - xcb-util-image=0.4.0
+ - xcb-util-keysyms=0.4.1
+ - xcb-util-renderutil=0.3.10
+ - xcb-util-wm=0.4.2
+ - xkeyboard-config=2.46
+ - xorg-libice=1.1.2
+ - xorg-libsm=1.2.6
+ - xorg-libx11=1.8.12
+ - xorg-libxau=1.0.12
+ - xorg-libxcomposite=0.4.6
+ - xorg-libxcursor=1.2.3
+ - xorg-libxdamage=1.1.6
+ - xorg-libxdmcp=1.1.5
+ - xorg-libxext=1.3.6
+ - xorg-libxfixes=6.0.2
+ - xorg-libxi=1.8.2
+ - xorg-libxrandr=1.5.4
+ - xorg-libxrender=0.9.12
+ - xorg-libxtst=1.2.5
+ - xorg-libxxf86vm=1.1.6
+ - yaml=0.2.5
+ - zipp=3.23.0
+ - zstd=1.5.7
\ No newline at end of file
diff --git a/examples/README_template.md b/examples/README_template.md
new file mode 100644
index 00000000..92a227b3
--- /dev/null
+++ b/examples/README_template.md
@@ -0,0 +1,10 @@
+# The Model
+
+### Overview
+
+- Please provide a brief description of the model, i.e.
+
+### Scripts
+
+
+### Usage
\ No newline at end of file
diff --git a/examples/brainiac/README.md b/examples/brainiac/README.md
new file mode 100644
index 00000000..d299f962
--- /dev/null
+++ b/examples/brainiac/README.md
@@ -0,0 +1,28 @@
+# Brainiac Model
+
+The Brainiac model is based on https://www.sciencedirect.com/science/article/pii/S1878929325000647 where it is used for the analysis of fMRI data.
+
+From a computational perspective it stands out as the projection matrix $A$ mapping the latent variables to the observations is dense, resulting in a dense conditional precision matrix $Q_c$ (while $Q_p$ is diagonal). It also has the particularity that the variance in the observations is coupled to a hyperparameter from the model. More concretely we have that
+
+$$
+Q_p([h, \alpha_1, ..., \alpha_d]) = h^2 \ \Phi(\alpha)
+$$
+while $y \sim N(0, 1-h^2 I)$.
+Thus, while fitting the LGM framework, it doesn't cohere with the usual R-INLA setup.
+More details can be found in `generate_data.py` where the dataset (and thus the model) is generated.
+
+## Scripts
+
+- **`generate_data.py`**: Generates a synthetic dataset given the dimensions specified at the top of the file.
+
+- **`run.py`**: Loads the generated dataset (requires the correct dimension to be specified at the top of the file) and runs DALIA.
+
+The GPU-backend is especially suitable for this model.
+
+## Usage
+
+```bash
+export ARRAY_MODULE=cupy
+python generate_data.py
+python run.py
+```
\ No newline at end of file
diff --git a/examples/brainiac/generate_data.py b/examples/brainiac/generate_data.py
deleted file mode 100644
index faf7597f..00000000
--- a/examples/brainiac/generate_data.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import numpy as np
-import scipy.sparse as sp
-from scipy.sparse import diags
-
-if __name__ == "__main__":
-
- np.random.seed(5)
-
- # epsilon N(0, (1-h^2)I), 0 < h^2 < 1,
- # beta prior on h^2
-
- # \beta ~ N(0, h^2 \Phi)
- # dim(Z_i) = (M,1)
- # alpha ~ N(0, \sigma_a^2 I), \alpha \in R^m
- # \sigma_a^2 large & fixed
-
- # dim(\Phi) = (b,b)
-
- no = 1000 # number of observations
- b = 2 # number of latent variables (number of features)
- m = 2 # number of annotations per feature
-
- # generate random Z -> needs to be loaded with the model
- z = np.random.rand(b, m)
- np.save("inputs_brainiac/z.npy", z)
-
- # Generate random h^2 with a Beta prior defined on the interval [0, 1]
- # TODO: how to estimate alpha_beta and beta_beta?
- # alpha_beta = 5.0
- # beta_beta = 1.0
- #h2 = np.random.beta(alpha_beta, beta_beta)
- h2 = 0.95
- print("h2: ", h2)
-
- # \sigma_a^2: large and fixed
- sigma_a2 = 1
- print("sigma_a2: ", sigma_a2)
-
- # sample alpha from N(0, \sigma_a^2 I)
- alpha = np.random.normal(2, np.sqrt(sigma_a2), (m, 1))
- #alpha = np.ones((m, 1))
- # print(alpha)
-
- theta_original = np.concatenate(([h2], alpha.flatten()))
- print(theta_original)
-
- # save original hyperparameters
- np.save("inputs_brainiac/theta_original.npy", theta_original)
-
- # \Phi = 1 / \sum_k=1^B exp(Z^k \alpha) * diag(exp(Z_1 \alpha), exp(Z_2 \alpha), ... )
- print("z : ", z)
- print("alpha : ", alpha)
- exp_Z_alpha = np.exp(z @ alpha)
- # print(exp_Z_alpha)
- sum_exp_Z_alpha = np.sum(exp_Z_alpha)
- print(sum_exp_Z_alpha)
-
- normalized_exp_Z_alpha = exp_Z_alpha / sum_exp_Z_alpha
- print("normalized_exp(Z*alpha): ", normalized_exp_Z_alpha)
-
- h2_phi = h2 * normalized_exp_Z_alpha.flatten()
- Qprior = diags(1 / h2_phi)
- print("Qprior: \n", Qprior.toarray())
-
- # save Qprior as a sparse matrix
- sp.save_npz("inputs_brainiac/Qprior_original.npz", Qprior)
-
- # sample full model: Y = a \beta + \epsilon
- # X random covariates of dimension (no, b)
- a = np.random.rand(no, b)
- # np.save("a.npy", a)
- a_sp = sp.csc_matrix(a)
- sp.save_npz("inputs_brainiac/a.npz", a_sp)
-
- # beta ~ N(0, (h^2 \Phi)^-1)
- var = 1 / Qprior.diagonal()
- print(var)
- beta = np.random.normal(0, np.sqrt(var)).reshape(b, 1)
- np.save("inputs_brainiac/beta_original.npy", beta.flatten())
- # print(beta)
-
- # beta regression parameters with
- eps = np.random.normal(0, np.sqrt(1 - h2), (no, 1))
- y = a @ beta + eps
- np.save("y.npy", y)
-
- # construct Qconditional
- Qconditional = Qprior + 1 / (1 - h2) * a_sp.T @ a_sp
- sp.save_npz("inputs_brainiac/Qconditional_original.npz", Qconditional)
-
- # recover beta
- # beta_initial = beta
- # grad_y = - 1 / (1 - h2) * (a @ beta - y)
- # information_vector = -1 * Qprior @ beta + a_sp.T @ grad_y
-
- beta_initial = np.zeros((b, 1))
- grad_y = - 1 / (1 - h2) * (-y)
- information_vector = a_sp.T @ grad_y
-
- beta_recovered = beta_initial + np.linalg.solve(Qconditional.toarray(), information_vector)
- print("beta recovered: ", beta_recovered.flatten())
- print("beta original : ", beta.flatten())
- print("norm(diff) : ", np.linalg.norm(beta_recovered - beta))
\ No newline at end of file
diff --git a/examples/brainiac/generate_synthetic_dataset.py b/examples/brainiac/generate_synthetic_dataset.py
new file mode 100644
index 00000000..6d476d59
--- /dev/null
+++ b/examples/brainiac/generate_synthetic_dataset.py
@@ -0,0 +1,111 @@
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+from scipy import sparse as sp
+
+np.random.seed(5)
+
+if __name__ == "__main__":
+
+ # Study parameters
+ n_observations: int = 100 # keep: 100
+ n_features: int = 20 # keep: 20
+ n_annotations_per_features: int = 2 # keep: 2
+
+ # Model parameters
+ h2: float = 0.9
+
+ # General parameters
+ model_format: Literal["dense", "sparse"] = "dense"
+ density: float = 0.4 # only used if model_format is "sparse"
+
+ path_script: Path = Path(__file__).parent
+ path_inputs: Path = path_script / "inputs_brainiac"
+ path_reference: Path = path_inputs / "reference"
+
+ path_inputs.mkdir(parents=True, exist_ok=True)
+ path_reference.mkdir(parents=True, exist_ok=True)
+
+ save_precision_matrices: bool = False
+
+ # 1. Generate random Z matrix
+ z = np.random.rand(n_features, n_annotations_per_features)
+
+ # 2. Sample alpha(s) from N(0, 1)
+ alpha = np.random.normal(0, 1, (n_annotations_per_features, 1))
+ print("alpha: ", alpha.flatten())
+
+ # 3. Construct reference hyperparameters vector
+ theta = np.concatenate(([h2], alpha.flatten()))
+
+ # 4. Construct reference prior precision matrix
+ normalized_exp_Z_alpha = np.exp(z @ alpha) / np.sum(np.exp(z @ alpha))
+
+ h2_phi = h2 * normalized_exp_Z_alpha.flatten()
+ Q_prior = sp.diags(1 / h2_phi)
+
+ # 5. Generate random projection matrix "a"
+ if model_format == "dense":
+ a = np.random.rand(n_observations, n_features)
+ elif model_format == "sparse":
+ a = sp.random(
+ n_observations, n_features, density=density, format="csc", dtype=np.float64
+ )
+
+ # 6. Construct reference beta vector
+ var = 1 / Q_prior.diagonal()
+ beta = np.random.normal(0, np.sqrt(var), n_features)
+
+ # 7. Generate observation vector: sample full model: Y = a beta + epsilon
+ if model_format == "dense":
+ y = a @ beta
+ elif model_format == "sparse":
+ y = a.dot(beta)
+ y += np.random.normal(0, np.sqrt(1 - h2), n_observations)
+
+ # 8. Construct reference conditional precision matrix
+ if model_format == "dense":
+ Q_conditional = Q_prior.toarray() + 1 / (1 - h2) * (a.T @ a)
+ elif model_format == "sparse":
+ Q_conditional = Q_prior + 1 / (1 - h2) * (a.T @ a).tocsc()
+
+ # Print Summary
+ print("Generated BRAINIAC synthetic dataset with the following parameters:")
+ print(f" - Number of observations: {n_observations}")
+ print(f" - Number of features (latent variables): {n_features}")
+ print(f" - Number of annotations per feature: {n_annotations_per_features}")
+ print(f" - Model format: {model_format}")
+ if model_format == "sparse":
+ print(f" - Projection matrix density: {density}")
+ print(f" - h2: {h2}")
+ print()
+ print("Saved the following files:")
+ print(f" - BRAINIAC inputs in {path_inputs.resolve()}")
+ print(f" - BRAINIAC reference outputs in {path_reference.resolve()}")
+
+ # Save BRAINIAC inputs
+ np.save(path_inputs / "z.npy", z)
+ if model_format == "dense":
+ np.save(path_inputs / "a.npy", a)
+ elif model_format == "sparse":
+ sp.save_npz(path_inputs / "a.npz", a)
+ np.save(path_script / "y.npy", y)
+ model_params = {
+ "n_observations": n_observations,
+ "n_features": n_features,
+ "n_annotations_per_features": n_annotations_per_features,
+ "h2": h2,
+ }
+ np.save(path_inputs / "model_params.npy", model_params)
+
+ # Save BRAINIAC references
+ np.save(path_reference / "theta.npy", theta)
+ np.save(path_reference / "beta.npy", beta)
+
+ if save_precision_matrices:
+ sp.save_npz(path_reference / "Q_prior.npz", Q_prior)
+ if model_format == "dense":
+ np.save(path_reference / "Q_conditional.npy", Q_conditional)
+ elif model_format == "sparse":
+ sp.save_npz(path_reference / "Q_conditional.npz", Q_conditional)
diff --git a/examples/brainiac/plotting.py b/examples/brainiac/plotting.py
new file mode 100644
index 00000000..1b0e57fa
--- /dev/null
+++ b/examples/brainiac/plotting.py
@@ -0,0 +1,242 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import norm
+
+
+def plot_prior_hp(param_name, theta_interval, prior_hp, log=False):
+ """
+ Plot prior distribution of a hyperparameter.
+
+ All priors are in log-scale. Therefore exponeniate unless log=True.
+
+ Parameters
+ ----------
+ param_name : str
+ Name of the hyperparameter.
+ theta_interval: tuple of float
+ Interval (min, max) for plotting the prior.
+ prior_hp : PriorHyperparameters
+ Prior hyperparameter object.
+ log : bool, optional
+ Whether to plot in log-scale or original scale. Default is False.
+
+
+ Note
+ ----
+ If log is True, the plot will be in log-scale. Otherwise, it will be in the original scale.
+
+
+ Returns
+ -------
+ fig, ax : matplotlib Figure and Axes
+ The figure and axes objects containing the plot.
+ """
+
+ if theta_interval[0] == 0:
+ theta_interval = (1e-6, theta_interval[1])
+
+ theta_vals = np.linspace(theta_interval[0], theta_interval[1], 200)
+ prior_vals = np.array([prior_hp.evaluate_log_prior(theta) for theta in theta_vals])
+
+ if log:
+ xlabel = f"{param_name}"
+ ylabel = "Log Prior Density"
+ else:
+ prior_vals = np.exp(prior_vals)
+ xlabel = f"{param_name}"
+ ylabel = "Prior Density"
+
+ fig, ax = plt.subplots(figsize=(8, 5))
+ ax.plot(theta_vals, prior_vals, "b-", linewidth=2)
+ ax.set_xlabel(xlabel)
+ ax.set_ylabel(ylabel)
+ ax.set_title(f"Prior Distribution of {param_name}")
+ ax.grid(True, alpha=0.3)
+
+ return fig, ax
+
+
+def plot_marginal_distributions_hp(marginals_hp):
+ """Plot marginal distributions of hyperparameters in both internal and external parametrizations."""
+
+ # Get all hyperparameters
+ hyperparams = marginals_hp["hyperparameters"]
+ n_params = len(hyperparams)
+
+ # Create subplot grid: n_params rows, 2 columns (internal left, external right)
+ fig, axes = plt.subplots(n_params, 2, figsize=(15, 5 * n_params))
+
+ # Handle case of single parameter
+ if n_params == 1:
+ axes = axes.reshape(1, -1)
+
+ # Quantile colors and labels
+ colors = ["#DEB887", "#DEB887", "darkred", "#DEB887", "#DEB887"]
+ labels = ["2.5%", "25%", "50%", "75%", "97.5%"]
+
+ for row, (param_name, param_data) in enumerate(hyperparams.items()):
+ # Get internal parameters
+ mean_internal = param_data["mean_internal"]
+ var_internal = param_data["variance_internal"]
+ std_internal = np.sqrt(var_internal)
+
+ # Get external parameters
+ mean_external = param_data["mean_external"]
+ var_external = param_data["variance_external"]
+ theta_external, pdf_external = param_data["pdf_data"]
+
+ # Get quantiles
+ quantile_pairs_internal = param_data["quantiles"]["internal"]["pairs"]
+ quantile_pairs_external = param_data["quantiles"]["external"]["pairs"]
+
+ # ===== LEFT PLOT: INTERNAL PARAMETRIZATION =====
+ ax_left = axes[row, 0]
+
+ # Create internal distribution (Gaussian)
+ x_internal = np.linspace(
+ mean_internal - 4 * std_internal, mean_internal + 4 * std_internal, 100
+ )
+ pdf_internal = norm.pdf(x_internal, loc=mean_internal, scale=std_internal)
+
+ # Plot internal PDF
+ ax_left.plot(
+ x_internal, pdf_internal, "b-", linewidth=2, label="PDF (Internal)"
+ )
+
+ # Mark internal mean
+ ax_left.axvline(
+ mean_internal,
+ color="red",
+ linestyle="--",
+ linewidth=2,
+ label=f"Mean = {mean_internal:.3f}",
+ )
+
+ # Mark internal quantiles
+ for i, (prob, q_val) in enumerate(quantile_pairs_internal):
+ if i < len(labels):
+ ax_left.axvline(
+ q_val,
+ color=colors[i],
+ linestyle=":",
+ linewidth=2,
+ label=f"{labels[i]} = {q_val:.3f}",
+ )
+
+ ax_left.set_xlabel(f"{param_name} (internal scale)")
+ ax_left.set_ylabel("PDF")
+ ax_left.set_title(f"{param_name}: Internal Distribution (Gaussian)")
+ ax_left.legend()
+ ax_left.grid(True, alpha=0.3)
+
+ # ===== RIGHT PLOT: EXTERNAL PARAMETRIZATION =====
+ ax_right = axes[row, 1]
+
+ # Plot external PDF
+ ax_right.plot(theta_external, pdf_external, "b-", linewidth=2, label="PDF")
+
+ # Mark external mean
+ ax_right.axvline(
+ mean_external,
+ color="red",
+ linestyle="--",
+ linewidth=2,
+ label=f"Mean = {mean_external:.3f}",
+ )
+
+ # Mark external quantiles
+ for i, (prob, q_val) in enumerate(quantile_pairs_external):
+ if i < len(labels):
+ ax_right.axvline(
+ q_val,
+ color=colors[i],
+ linestyle=":",
+ linewidth=2,
+ label=f"{labels[i]} = {q_val:.3f}",
+ )
+
+ ax_right.set_xlabel(f"{param_name} ")
+ ax_right.set_ylabel("PDF")
+ ax_right.set_title(f"{param_name}: Marginal Distribution")
+ ax_right.legend()
+ ax_right.grid(True, alpha=0.3)
+
+ # plt.tight_layout()
+ # plt.show()
+
+ return fig, axes
+
+
+def plot_marginal_distributions_hp_external(marginals_hp, true_means=None):
+ """Plot marginal distributions of hyperparameters in both internal and external parametrizations."""
+
+ # Get all hyperparameters
+ hyperparams = marginals_hp["hyperparameters"]
+ n_params = len(hyperparams)
+
+ # Create subplot grid: 1 row, n_params columns
+ fig, axes = plt.subplots(1, n_params, figsize=(5 * n_params, 5))
+
+ # Handle case of single parameter - make it iterable
+ if n_params == 1:
+ axes = [axes]
+
+ # Quantile colors and labels
+ colors = ["#DEB887", "#DEB887", "darkblue", "#DEB887", "#DEB887"]
+ labels = ["2.5%", "25%", "50%", "75%", "97.5%"]
+
+ for col, (param_name, param_data) in enumerate(hyperparams.items()):
+ # Get external parameters
+ mean_external = param_data["mean_external"]
+ var_external = param_data["variance_external"]
+ theta_external, pdf_external = param_data["pdf_data"]
+
+ # Get quantiles
+ quantile_pairs_external = param_data["quantiles"]["external"]["pairs"]
+
+ # ===== PLOT: EXTERNAL PARAMETRIZATION =====
+ ax_right = axes[col]
+
+ # Plot external PDF
+ ax_right.plot(theta_external, pdf_external, "b-", linewidth=2, label="PDF")
+
+ # Mark external mean
+ ax_right.axvline(
+ mean_external,
+ color="darkgreen",
+ linestyle="--",
+ linewidth=2,
+ label=f"Mean = {mean_external:.3f}",
+ )
+
+ # Mark true mean if provided
+ if true_means is not None and col < len(true_means):
+ ax_right.axvline(
+ true_means[col],
+ color="red",
+ linestyle="-",
+ linewidth=2,
+ label=f"True Mean = {true_means[col]:.3f}",
+ )
+
+ # Mark external quantiles
+ for i, (prob, q_val) in enumerate(quantile_pairs_external):
+ if i < len(labels):
+ ax_right.axvline(
+ q_val,
+ color=colors[i],
+ linestyle=":",
+ linewidth=2,
+ label=f"{labels[i]} = {q_val:.3f}",
+ )
+
+ ax_right.set_xlabel(f"{param_name} ")
+ ax_right.set_ylabel("PDF")
+ ax_right.set_title(f"{param_name}: Marginal Distribution")
+ ax_right.legend()
+ ax_right.grid(True, alpha=0.3)
+
+ # plt.tight_layout()
+ # plt.show()
+
+ return fig, axes
diff --git a/examples/brainiac/run.py b/examples/brainiac/run.py
index e2a4a553..1c3e82ec 100644
--- a/examples/brainiac/run.py
+++ b/examples/brainiac/run.py
@@ -1,161 +1,130 @@
-import os
+import time
+from pathlib import Path
+import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
-from dalia.configs import likelihood_config, dalia_config, submodels_config
-from dalia.core.model import Model
+from dalia import xp
+from dalia.configs import dalia_config, likelihood_config, submodels_config
from dalia.core.dalia import DALIA
+from dalia.core.model import Model
from dalia.submodels import BrainiacSubModel
-from dalia.utils import scaled_logit
-
-path = os.path.dirname(__file__)
+from dalia.utils import print_msg
+from plotting import plot_marginal_distributions_hp_external
if __name__ == "__main__":
- base_dir = os.path.dirname(os.path.abspath(__file__))
-
- b = 2 # number of latent variables (number of features)
- m = 2 # number of annotations per feature
- sigma_a2 = 1.0 / 1.0
- precision_mat = sigma_a2 * np.eye(m)
-
- theta_ref = np.load(f"{base_dir}/inputs_brainiac/theta_original.npy")
- x_ref = np.load(f"{base_dir}/inputs_brainiac/beta_original.npy")
-
- initial_h2 = theta_ref[0]
- initial_alpha = theta_ref[1:]
-
- brainiac_dict = {
- "type": "brainiac",
- "input_dir": f"{base_dir}/inputs_brainiac",
- "h2": initial_h2,
- "alpha": initial_alpha,
- "ph_h2": {"type": "beta", "alpha": 5.0, "beta": 1.0},
- "ph_alpha": {
- "type": "gaussian_mvn",
- "mean": np.asarray(theta_ref[1:]),
- "precision": sp.csr_matrix(precision_mat),
- },
- }
- brainiac = BrainiacSubModel(
- config=submodels_config.parse_config(brainiac_dict),
+ print_msg(f"Running BRAINIAC model on synthetic dataset.")
+
+ path_inputs: Path = Path(__file__).parent / "inputs_brainiac"
+ path_reference: Path = path_inputs / "reference"
+
+ # 1. Load model parameters
+ model_params: dict = np.load(
+ path_inputs / "model_params.npy", allow_pickle=True
+ ).item()
+ n_observations: int = model_params["n_observations"]
+ n_features: int = model_params["n_features"]
+ n_annotations_per_features: int = model_params["n_annotations_per_features"]
+ h2: float = model_params["h2"]
+ sigma_a2: float = 5.0
+
+ # 2. Load references
+ theta_reference: np.ndarray = np.load(path_reference / "theta.npy")
+ x_reference: np.ndarray = np.load(path_reference / "beta.npy")
+
+ # 3. Create starting values for DALIA
+ initial_h2: float = theta_reference[0] - 0.1
+ initial_alpha: np.ndarray = theta_reference[1:] + 0.5 * np.random.randn(
+ n_annotations_per_features
)
- print(brainiac)
-
- print("SubModel initialized.")
- likelihood_dict = {"type": "gaussian", "fix_hyperparameters": True}
+ # 4. Initialize the Brainiac submodel and the DALIA model
+ brainiac = BrainiacSubModel(
+ config=submodels_config.parse_config(
+ {
+ "type": "brainiac",
+ "input_dir": str(path_inputs.resolve()),
+ "h2": initial_h2,
+ "alpha": initial_alpha,
+ "ph_h2": {"type": "beta", "alpha": 1.0, "beta": 1.0},
+ "ph_alpha": {
+ "type": "gaussian_mvn",
+ "mean": xp.zeros(n_annotations_per_features),
+ "precision": (1.0 / sigma_a2) * sp.eye(n_annotations_per_features),
+ },
+ }
+ )
+ )
model = Model(
submodels=[brainiac],
- likelihood_config=likelihood_config.parse_config(likelihood_dict),
+ likelihood_config=likelihood_config.parse_config(
+ {"type": "gaussian", "fix_hyperparameters": True}
+ ),
)
+ print_msg(model)
- print(model)
-
- # check dimensions
- if model.submodels[0].z.shape[0] != b or model.submodels[0].z.shape[1] != m:
- raise ValueError("Dimension mismatch in Z matrix.")
-
- print("Model initialized.")
-
- print("model.theta", model.theta)
- print("length(model.theta)", len(model.theta))
- print("model.theta_keys", model.theta_keys)
-
- eta = np.ones((model.n_observations, 1))
-
- model.construct_Q_prior()
- model.construct_Q_conditional(eta)
-
- # compare to reference solution
- Qprior_ref = sp.load_npz(f"{base_dir}/inputs_brainiac/Qprior_original.npz")
- Qcond_ref = sp.load_npz(f"{base_dir}/inputs_brainiac/Qconditional_original.npz")
-
- print("Qcond_ref\n", Qcond_ref.toarray())
- print("Qcond\n", model.Q_conditional.toarray())
-
- print(
- "norm(Qprior_ref - model.Q_prior) = ",
- np.linalg.norm((Qprior_ref - model.Q_prior).toarray()),
- )
- print(
- "norm(Qcond_ref - model.Q_conditional) = ",
- np.linalg.norm((Qcond_ref - model.Q_conditional).toarray()),
- )
-
- # Q_prior_dense = model.Q_prior.todense()
- # print("Q_prior_dense\n", Q_prior_dense)
- # Q_cond_dense = model.Q_conditional.todense()
- # print("Q_cond_dense\n", Q_cond_dense)
-
- # plt.matshow(Q_prior_dense)
- # plt.suptitle("Q_prior from brainiac model")
- # plt.savefig("Q_prior.png")
-
- # plt.matshow(Q_cond_dense)
- # plt.suptitle("Q_conditional from brainiac model")
- # plt.savefig("Q_conditional.png")
-
- dalia_dict = {
- # "solver": {"type": "serinv"},
- "solver": {"type": "dense"},
- "minimize": {
- "max_iter": 10,
- "gtol": 1e-2,
- "disp": True,
- },
- "inner_iteration_max_iter": 50,
- "eps_inner_iteration": 1e-3,
- "eps_gradient_f": 1e-3,
- "simulation_dir": ".",
- }
+ # 5. Initialize DALIA
dalia = DALIA(
model=model,
- config=dalia_config.parse_config(dalia_dict),
+ config=dalia_config.parse_config(
+ {
+ "solver": {"type": "dense"},
+ "minimize": {
+ "max_iter": 50,
+ "gtol": 1e-3,
+ "disp": True,
+ },
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "simulation_dir": ".",
+ }
+ ),
)
- print("x ref: ", x_ref)
- # minimization_result = dalia.minimize()
-
- # output = dalia._evaluate_f(model.theta)
- # x = model.x
- # print("x: ", x)
-
- print("\n------ Compare to reference solution ------\n")
- # load reference solution
- # theta_ref = np.load(f"{base_dir}/inputs_brainiac/theta_original.npy")
-
- # x_ref = np.load(f"{base_dir}/inputs_brainiac/beta_original.npy")
- # x = minimization_result["x"]
- # print("\nx ", x)
- # print("x_ref", x_ref)
- # print("norm(x_ref - x) = ", np.linalg.norm(x_ref - x))
-
- results = dalia.run()
-
- print("theta_ref: ", theta_ref)
- theta = results["theta"]
- # rescale
- theta[0] = scaled_logit(theta[0], direction="backward")
- print("theta: ", theta)
- print(
- "norm(theta_ref - minimization_result['theta']) = ",
- np.linalg.norm(theta_ref - results["theta"]),
+ # 6. Run inference
+ tic = time.perf_counter()
+ result = dalia.run()
+ toc = time.perf_counter()
+ print_msg(f"DALIA finished in {toc - tic:.2f} seconds.")
+
+ # 7. Compare to reference solution
+ print_msg("\n------ Compare to reference solution ------\n")
+ print_msg("theta_reference: ", theta_reference)
+ print_msg("theta dalia:", result["theta"])
+
+ print_msg("norm(x_reference - x) = ", np.linalg.norm(xp.asarray(x_reference) - result["x"]))
+
+ # 8. Check marginal variances of latent parameters
+ var_latent_params = result["marginal_variances_latent"]
+ Q_conditional = dalia.model.construct_Q_conditional(eta=model.a @ model.x)
+
+ if sp.issparse(Q_conditional):
+ Q_inv_ref = xp.linalg.inv(Q_conditional.toarray())
+ else:
+ Q_inv_ref = xp.linalg.inv(Q_conditional)
+ print_msg(
+ f"Norm (marginal variances of latent parameters - reference): {xp.linalg.norm(var_latent_params - xp.diag(Q_inv_ref)):.4e}",
)
- # print("results['theta']: ", results["theta"])
- # print("results['f']: ", results["f"])
- # print("results['grad_f']: ", results["grad_f"])
- print("cov_theta: \n", results["cov_theta"])
- print(
- "marginal standard deviations of the hyperparameters: ",
- np.sqrt(results["cov_theta"].diagonal()),
- )
- print("mean of the latent parameters : ", results["x"])
- print(
- "marginal variances of the latent parameters: ",
- results["marginal_variances_latent"],
+ # 9. Compute marginal distributions of the hyperparameters
+ marginals_hyperparameters = dalia.marginal_distributions_hp()
+
+ # 10. Plot marginal distributions of hyperparameters
+ # fig, axes = plot_marginal_distributions_hp(marginals_hyperparameters)
+ # plt.savefig("marginal_distributions_hyperparameters.png")
+
+ fig, axes = plot_marginal_distributions_hp_external(
+ marginals_hyperparameters, theta_reference
)
+ plt.savefig("marginal_distributions_hyperparameters.png")
+
+ h2 = marginals_hyperparameters["hyperparameters"]["h2"]
+ quantile_pairs = h2["quantiles"]["external"]["pairs"]
+
+ print("Quantile pairs of h2:")
+ for p, q in quantile_pairs:
+ print(f" {p:.3f} quantile: {q:.4f}")
- print("norm(theta - theta_ref): ", np.linalg.norm(results["theta"] - theta_ref))
- print("norm(x - x_ref): ", np.linalg.norm(results["x"] - x_ref))
+ print_msg("\n--- Finished ---")
diff --git a/examples/g_ar1/generate_data.py b/examples/g_ar1/generate_data.py
new file mode 100644
index 00000000..cba3858e
--- /dev/null
+++ b/examples/g_ar1/generate_data.py
@@ -0,0 +1,105 @@
+import os
+from pathlib import Path
+
+import numpy as np
+import scipy.sparse as sp
+from scipy.sparse.linalg import spsolve, spsolve_triangular
+from scipy.sparse import csc_matrix
+from scipy.linalg import cholesky
+
+BASE_DIR: Path = Path(__file__).parent
+
+if __name__ == "__main__":
+
+ np.random.seed(5)
+ n = 1000
+
+ ## define priors
+ s2 = 5
+ tau = 1 / s2
+ phi = 0.9
+ # noise obs
+ obs_noise_prec = 100
+ theta_original = [
+ phi,
+ tau,
+ obs_noise_prec,
+ ]
+
+ denom = s2 * (1 - phi**2)
+
+ diag = [(1 + phi**2) / denom] * n
+ diag[0] = diag[-1] = 1 / denom
+ off_diag = [-phi / denom] * (n - 1)
+
+ Q = sp.diags([diag, off_diag, off_diag], [0, -1, 1])
+
+ # Compute sparse Cholesky factorization: Q = L @ L.T
+ # For tridiagonal matrix, we can use dense Cholesky on small blocks or scipy
+ Q_csc = Q.tocsc()
+
+ print("Q shape:", Q.shape, "Q nnz:", Q.nnz)
+ print("Q sparsity:", 100 * Q.nnz / (Q.shape[0] * Q.shape[1]), "%")
+ print(Q.toarray()[:6, :6])
+
+ # Method 1: Use dense Cholesky (for moderate sizes this is still efficient)
+ Q_dense = Q.toarray()
+ L_dense = cholesky(Q_dense, lower=True)
+ L = csc_matrix(L_dense)
+
+ print("L nnz:", L.nnz, "L sparsity:", 100 * L.nnz / (L.shape[0] * L.shape[1]), "%")
+
+ # Efficient sampling: generate z ~ N(0,I), then solve L @ u = z
+ z = np.random.normal(0, 1, size=n)
+
+ # Solve L @ u = z using sparse triangular solver
+ u = spsolve_triangular(L, z, lower=True)
+
+ # Verify the sampling worked correctly
+ print("Sample u statistics - mean:", np.mean(u), "std:", np.std(u), ". Should be around sqrt(s2) =", np.sqrt(s2))
+
+ intercept = 2
+
+ x = np.concatenate((u, [intercept]))
+ print("x: ", x[:10])
+
+ os.makedirs(BASE_DIR / "reference_outputs", exist_ok=True)
+ np.save(BASE_DIR / "reference_outputs" / "x_original.npy", x)
+ np.save(BASE_DIR / "reference_outputs" / "theta_original.npy", theta_original)
+
+ os.makedirs(BASE_DIR / "inputs_ar1", exist_ok=True)
+ np.save(BASE_DIR / "inputs_ar1" / "x.npy", u)
+
+ a_ar1 = sp.eye(n)
+ sp.save_npz(BASE_DIR / "inputs_ar1" / "a.npz", a_ar1)
+
+ a_regression = sp.csr_matrix(np.ones((n, 1)))
+ os.makedirs(BASE_DIR / "inputs_regression", exist_ok=True)
+ sp.save_npz(BASE_DIR / "inputs_regression" / "a.npz", a_regression)
+
+ eta = a_ar1 @ u + intercept
+
+ print("eta: ", eta[:6])
+ np.save(BASE_DIR / "inputs_ar1" / "x_original.npy", eta)
+
+ noise = np.random.normal(0, np.sqrt(1 / obs_noise_prec), size=eta.shape)
+ print("noise: ", noise[:10])
+ y = eta + noise
+ np.save(BASE_DIR / "y.npy", y)
+
+ print("y: ", y[:10])
+
+ Qprior = sp.block_diag([Q, sp.csr_matrix([[0.001]])])
+
+ a = sp.hstack([a_ar1, a_regression]) # a_ar1 #
+ Qcond = Qprior + obs_noise_prec * a.T @ a
+ print("Qcond: \n", Qcond.toarray()[:6,:6])
+
+ b = obs_noise_prec * a.T @ y
+ print("b: ", b[:10])
+ # x_est = np.linalg.solve(Qcond.toarray(), b)
+ x_est = spsolve(csc_matrix(Qcond), b)
+ print("norm(x - x_est): ", np.linalg.norm(x - x_est))
+
+ print("norm(eta - eta_est): ", np.linalg.norm(a @ x - a @ x_est))
+ print("normalized norm(eta - eta_est): ", np.linalg.norm(a @ x - a @ x_est) / np.linalg.norm(a @ x))
diff --git a/examples/g_ar1/run.py b/examples/g_ar1/run.py
new file mode 100644
index 00000000..ef42c8ea
--- /dev/null
+++ b/examples/g_ar1/run.py
@@ -0,0 +1,181 @@
+from pathlib import Path
+
+import numpy as np
+
+from dalia import xp
+from dalia.configs import likelihood_config, dalia_config, submodels_config
+from dalia.core.model import Model
+from dalia.core.dalia import DALIA
+from dalia.submodels import AR1SubModel, RegressionSubModel
+from dalia.utils import print_msg, plot_marginal_distributions_hp, plot_prior_hp # , extract_diagonal
+
+BASE_DIR: Path = Path(__file__).parent
+
+if __name__ == "__main__":
+
+ np.random.seed(3)
+
+ # load reference output
+ theta_original = np.load(BASE_DIR / "reference_outputs" / "theta_original.npy")
+ print(
+ "theta original: ",
+ theta_original,
+ )
+
+ theta_initial = theta_original #[0.6, 1.0, 3.0]
+ print("theta initial: ", theta_initial)
+
+ x_original = np.load(BASE_DIR / "reference_outputs" / "x_original.npy")
+ print("x original: ", x_original[:10])
+ print("dim(x original): ", x_original.shape)
+
+ ar1_dict = {
+ "type": "ar1",
+ "input_dir": f"{BASE_DIR}/inputs_ar1",
+ "phi": 0.5, # has to be between 0 and 1
+ "ph_phi": {"type": "beta", "alpha": 5.0, "beta": 1.0},
+ # initial guess on the precision
+ "tau": 3, # has to be positive
+ "ph_tau": {"type": "gamma", "alpha": 2.0, "beta": 1.0},
+ # initial guess on the variance
+ # "sigma2": 0.33, # has to be positive
+ # "ph_sigma2": {"type": "invgamma", "alpha": 2.0, "beta": 1.0},
+ }
+ ar1 = AR1SubModel(
+ config=submodels_config.parse_config(ar1_dict),
+ )
+
+ # Configurations of the regression submodel
+ regression_dict = {
+ "type": "regression",
+ "input_dir": f"{BASE_DIR}/inputs_regression",
+ "n_fixed_effects": 1,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression = RegressionSubModel(
+ config=submodels_config.parse_config(regression_dict),
+ )
+
+ likelihood_dict = {
+ "type": "gaussian",
+ "prec_o": 20,
+ # "prior_hyperparameters": {
+ # "type": "penalized_complexity",
+ # "alpha": 0.01,
+ # "u": 5,
+ # },
+ "prior_hyperparameters": {"type": "gaussian", "mean": theta_original[2], "precision": 0.05},
+ }
+
+ model = Model(
+ submodels=[ar1, regression], #
+ likelihood_config=likelihood_config.parse_config(likelihood_dict),
+ )
+ print_msg(model)
+
+ # plot phi
+ # theta_interval = [0, 1]
+ # prior_hp = model.prior_hyperparameters[0]
+ # fig, ax = plot_prior_hp("phi", theta_interval, prior_hp)
+
+ # plot tau
+ theta_interval = [0, 5]
+ prior_hp = model.prior_hyperparameters[1]
+ fig, ax = plot_prior_hp("tau", theta_interval, prior_hp)
+
+ import matplotlib.pyplot as plt
+ plt.show()
+
+ Qprior = model.construct_Q_prior()
+ print("Qprior: \n", Qprior.toarray()[:6, :6])
+ Qinv = xp.linalg.inv(Qprior.toarray())
+ geom_mean = xp.exp(xp.mean(xp.log(Qinv.diagonal())))
+ print("Geometric mean of Qinv diagonal: ", geom_mean)
+
+ # in gaussian case x = 0, thus eta = 0
+ x_i = xp.zeros(model.n_latent_parameters)
+ eta = model.a @ x_i
+ Qcond = model.construct_Q_conditional(eta=eta)
+ print("Qcond: \n", Qcond.toarray()[:6, :6])
+
+ b = model.construct_information_vector(eta=eta, x_i=x_i)
+ print("b: ", b[:10])
+
+ x_est = xp.linalg.solve(Qcond.toarray(), b)
+ #print("x est: ", x_est)
+ print("norm(x_original - x_est): ", xp.linalg.norm(xp.asarray(x_original) - x_est))
+
+ # Configurations of DALIA
+ dalia_dict = {
+ "solver": {"type": "dense"},
+ "minimize": {
+ "max_iter": 100,
+ "gtol": 1e-3,
+ "disp": True,
+ "maxcor": len(model.theta_external),
+ },
+ "f_reduction_tol": 1e-3,
+ "theta_reduction_tol": 1e-4,
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "simulation_dir": ".",
+ "verbosity": 0,
+ }
+
+ dalia = DALIA(
+ model=model,
+ config=dalia_config.parse_config(dalia_dict),
+ )
+
+ print("initial model theta: ", model.theta_external)
+
+ print("\nCalling DALIA.run()")
+ results = dalia.run()
+
+ print_msg("\n--- Results ---")
+
+ theta = results["theta"]
+ print("theta: ", np.round(theta, 4))
+ print("theta original: ", theta_original)
+
+ print_msg("Covariance of theta:\n", results["cov_theta_internal"])
+ print_msg(
+ "Mean of the fixed effects:\n",
+ results["x"][-model.submodels[-1].n_fixed_effects :],
+ )
+
+ # print("x: ", results["x"])
+ # print("x_original: ", x_original)
+
+ # print("eta: ", model.a @ x_original)
+ # print("eta est: ", model.a @ results["x"])
+
+ print_msg("\n--- Comparisons ---")
+ print("norm(eta - eta_est): ", xp.linalg.norm(model.a @ xp.asarray(x_original) - model.a @ results["x"]))
+ print("normalized norm(eta - eta_est): ", xp.linalg.norm(model.a @ xp.asarray(x_original) - model.a @ results["x"]) / xp.linalg.norm(model.a @ xp.asarray(x_original)))
+
+ # Compare marginal variances of latent parameters
+ var_latent_params = results["marginal_variances_latent"]
+ Qconditional = dalia.model.construct_Q_conditional(eta=model.a @ model.x)
+ Qinv_ref = xp.linalg.inv(Qconditional.toarray())
+ print_msg(
+ "Norm (marg var latent - ref): ",
+ f"{xp.linalg.norm(var_latent_params - xp.diag(Qinv_ref)):.4e}",
+ )
+
+ print_msg("\n--- Marginal distributions of the hyperparameters ---")
+ marginals_hp = dalia.marginal_distributions_hp()
+
+ fig, axes = plot_marginal_distributions_hp(marginals_hp)
+ import matplotlib.pyplot as plt
+ plt.show()
+
+ phi = marginals_hp['hyperparameters']['phi']
+ quantile_pairs = phi['quantiles']['external']['pairs']
+
+ print("Quantile pairs of phi:")
+ for p, q in quantile_pairs:
+ print(f" {p:.3f} quantile: {q:.4f}")
+
+ print_msg("\n--- Finished ---")
\ No newline at end of file
diff --git a/examples/gr/preprocessing.py b/examples/gr/generate_data.py
similarity index 84%
rename from examples/gr/preprocessing.py
rename to examples/gr/generate_data.py
index 51330182..daa4cbbf 100644
--- a/examples/gr/preprocessing.py
+++ b/examples/gr/generate_data.py
@@ -17,7 +17,7 @@
path = os.path.dirname(__file__)
if __name__ == "__main__":
- n_observations = 1000
+ n_observations = 20
n_latent_parameters = 6
z = np.random.normal(size=n_latent_parameters)
@@ -28,12 +28,12 @@
x = L_Sigma_prior @ z
a = sparse.random(n_observations, n_latent_parameters, density=0.5)
- theta_observations = np.log(3)
+ theta_observations = 2.0
print(f"theta_observations: {theta_observations}")
theta_likelihood: dict = {"theta_observations": theta_observations}
# generate x from a gaussian distribution of dimensions n_latent_parameters with mean 0 and precision exp(theta_observations)
- variance = 1 / np.exp(theta_observations)
+ variance = 1 / theta_observations
eta = a @ x
y = np.random.normal(eta, scale=np.sqrt(variance), size=n_observations)
print(f"x: {x}")
@@ -48,12 +48,12 @@
os.makedirs(f"{path}/inputs", exist_ok=True)
# save the synthetic data
- np.save(f"{path}/inputs/y.npy", y)
+ np.save(f"{path}/y.npy", y)
# save a as .npz
sparse.save_npz(f"{path}/inputs/a.npz", a)
# save original latent parameters
- np.save(f"{path}/inputs/x_original.npy", x)
+ np.save(f"{path}/reference_outputs/x_ref.npy", x)
# save original hyperparameter theta
- np.save(f"{path}/inputs/theta_original.npy", theta_observations)
+ np.save(f"{path}/reference_outputs/theta_ref.npy", theta_observations)
diff --git a/examples/gr/inputs/y.npy b/examples/gr/inputs/y.npy
new file mode 100644
index 00000000..a5755615
--- /dev/null
+++ b/examples/gr/inputs/y.npy
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f5b972c648de0244b4fe3af2aac7e87210706d8ac879007011e061c2330c0ba
+size 8128
diff --git a/examples/gr/reference_outputs/theta_ref.npy b/examples/gr/reference_outputs/theta_ref.npy
index ccb2de10..01fbab03 100644
--- a/examples/gr/reference_outputs/theta_ref.npy
+++ b/examples/gr/reference_outputs/theta_ref.npy
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:722e1a2b1465e290a283836da486e044a051dc5c7d0b420a1fc0addfd4411c1f
+oid sha256:7b9d65f6717a0911c1cdf1cc680cb167f2ef1395be2db8ca2065cb2bd8a0a8f4
size 136
diff --git a/examples/gr/reference_outputs/x_ref.npy b/examples/gr/reference_outputs/x_ref.npy
index fbf38892..87f3421e 100644
--- a/examples/gr/reference_outputs/x_ref.npy
+++ b/examples/gr/reference_outputs/x_ref.npy
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:86bdf16672dbf80add11a285c0fe68f2532436b2d232a86992144017d02fcd61
+oid sha256:4ddd3a0b4a76c50a77c276c3c7a3a37b506a633b3e31564e8fd3bf859d3e2184
size 176
diff --git a/examples/gr/run.py b/examples/gr/run.py
index ed4df6f8..0716cfee 100644
--- a/examples/gr/run.py
+++ b/examples/gr/run.py
@@ -1,18 +1,18 @@
-import sys
import os
-
-parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-sys.path.append(parent_dir)
+import sys
import numpy as np
from dalia import xp
-from dalia.configs import likelihood_config, dalia_config, submodels_config
-from dalia.core.model import Model
+from dalia.configs import dalia_config, likelihood_config, submodels_config
from dalia.core.dalia import DALIA
+from dalia.core.model import Model
from dalia.submodels import RegressionSubModel
-from dalia.utils import extract_diagonal, get_host, print_msg
-from examples_utils.parser_utils import parse_args
+from dalia.utils import extract_diagonal, get_host, print_msg, plot_marginal_distributions_hp, plot_prior_hp
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(parent_dir)
+from examples_utils.parser_utils import parse_args # noqa: E402
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -32,11 +32,13 @@
regression = RegressionSubModel(
config=submodels_config.parse_config(regression_dict),
)
+
# Likelihood
likelihood_dict = {
"type": "gaussian",
- "prec_o": 1.5,
- "prior_hyperparameters": {"type": "gaussian", "mean": 3.5, "precision": 0.5},
+ "prec_o": 1.0,
+ "prior_hyperparameters": {"type": "gamma", "alpha": 2.0, "beta": 2.0},
+ #"prior_hyperparameters": {"type": "gaussian", "mean": 1.0, "precision": 0.5},
}
# Creation of the first model by combining the Regression submodel and the likelihood
model = Model(
@@ -45,6 +47,14 @@
)
print_msg(model)
+ ## Plot prior of hyperparameter -- identification by [0], [1], ... not amazing but works for now
+ theta_interval = [-5, 7]
+ prior_hp = model.prior_hyperparameters[0]
+
+ fig, ax = plot_prior_hp("prec_o", theta_interval, prior_hp)
+ import matplotlib.pyplot as plt
+ plt.show()
+
# Configurations of DALIA
dalia_dict = {
"solver": {"type": "dense"},
@@ -69,8 +79,9 @@
results = dalia.run()
print_msg("\n--- Results ---")
- print_msg("Theta values:\n", results["theta"])
- print_msg("Covariance of theta:\n", results["cov_theta"])
+ print_msg("Theta values external:\n", results["theta"])
+ print_msg("Theta values internal:\n", results["theta_internal"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg(
"Mean of the fixed effects:\n",
results["x"][-model.submodels[-1].n_fixed_effects :],
@@ -79,16 +90,17 @@
print_msg("\n--- Comparisons ---")
# Compare hyperparameters
theta_ref = xp.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy")
+ print_msg("Reference theta:", theta_ref)
print_msg(
"Norm (theta - theta_ref): ",
- f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
+ f"{xp.linalg.norm(results['theta'] - theta_ref):.4e}",
)
# Compare latent parameters
x_ref = xp.load(f"{BASE_DIR}/reference_outputs/x_ref.npy")
print_msg(
"Norm (x - x_ref): ",
- f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
+ f"{xp.linalg.norm(results['x'] - x_ref):.4e}",
)
# Compare marginal variances of latent parameters
@@ -101,11 +113,28 @@
)
# Compare marginal variances of observations
- var_obs = dalia.get_marginal_variances_observations(theta=theta_ref, x_star=x_ref)
+ var_obs = dalia.get_marginal_variances_observations(
+ theta_external=theta_ref, x_star=x_ref
+ )
+
var_obs_ref = extract_diagonal(model.a @ Qinv_ref @ model.a.T)
print_msg(
"Norm (var_obs - var_obs_ref): ",
f"{xp.linalg.norm(var_obs - var_obs_ref):.4e}",
)
+ print_msg("\n--- Marginal distributions of the hyperparameters ---")
+ marginals_hp = dalia.marginal_distributions_hp()
+
+ fig, axes = plot_marginal_distributions_hp(marginals_hp)
+ import matplotlib.pyplot as plt
+ plt.savefig(f"gr_marginal_distributions_hp.png")
+
+ prec_obs = marginals_hp['hyperparameters']['prec_o']
+ quantile_pairs = prec_obs['quantiles']['external']['pairs']
+
+ print("Quantile pairs of prec_o:")
+ for p, q in quantile_pairs:
+ print(f" {p:.3f} quantile: {q:.4f}")
+
print_msg("\n--- Finished ---")
diff --git a/examples/gs_coreg2_small/run.py b/examples/gs_coreg2_small/run.py
index 1bb6ba29..7ac07622 100644
--- a/examples/gs_coreg2_small/run.py
+++ b/examples/gs_coreg2_small/run.py
@@ -168,7 +168,7 @@
print_msg("results['theta']: ", results["theta"])
- print_msg("cov_theta: \n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg("mean of the fixed effects: ", results["x"][-nb:])
print_msg(
"marginal variances of the fixed effects: ",
diff --git a/examples/gs_coreg3_small/run.py b/examples/gs_coreg3_small/run.py
index 102e2176..9f056ba3 100644
--- a/examples/gs_coreg3_small/run.py
+++ b/examples/gs_coreg3_small/run.py
@@ -1,24 +1,24 @@
-import sys
import os
-
-parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
-sys.path.append(parent_dir)
+import sys
import numpy as np
from dalia import xp
from dalia.configs import (
+ dalia_config,
likelihood_config,
models_config,
- dalia_config,
submodels_config,
)
-from dalia.core.model import Model
from dalia.core.dalia import DALIA
+from dalia.core.model import Model
from dalia.models import CoregionalModel
-from dalia.utils import print_msg, get_host
from dalia.submodels import RegressionSubModel, SpatialSubModel
-from examples_utils.parser_utils import parse_args
+from dalia.utils import get_host, print_msg
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(parent_dir)
+from examples_utils.parser_utils import parse_args # noqa: E402
SEED = 63
np.random.seed(SEED)
@@ -26,7 +26,9 @@
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
if __name__ == "__main__":
- print_msg("--- Example: Gaussian Coregional (3 variates) spatial model with regression ---")
+ print_msg(
+ "--- Example: Gaussian Coregional (3 variates) spatial model with regression ---"
+ )
# Check for parsed parameters
args = parse_args()
@@ -215,7 +217,7 @@
print_msg("results['theta']: ", results["theta"])
- print_msg("cov_theta: \n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg("mean of the fixed effects: ", results["x"][-nb:])
print_msg(
"marginal variances of the fixed effects: ",
@@ -224,15 +226,19 @@
print_msg("\n--- Comparisons ---")
# Compare hyperparameters
- theta_ref = np.load(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy")
+ theta_ref = np.load(
+ f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy"
+ )
print_msg(
"Norm (theta - theta_ref): ",
f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
)
- x_ref = np.load(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy")
+ x_ref = np.load(
+ f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy"
+ )
# Compare latent parameters
- #
+ #
print_msg(
"Norm (x - x_ref): ",
f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
diff --git a/examples/gs_small/run.py b/examples/gs_small/run.py
index e94f20d1..addd1921 100644
--- a/examples/gs_small/run.py
+++ b/examples/gs_small/run.py
@@ -47,25 +47,28 @@
likelihood_dict = {
"type": "gaussian",
"prec_o": 4,
- "prior_hyperparameters": {
- "type": "penalized_complexity",
- "alpha": 0.01,
- "u": 5,
- },
+ # "prior_hyperparameters": {
+ # "type": "penalized_complexity",
+ # "alpha": 0.01,
+ # "u": 5,
+ # },
+ "prior_hyperparameters": {"type": "gamma", "alpha": 2.0, "beta": 2.0},
}
model = Model(
submodels=[spatial, regression],
likelihood_config=likelihood_config.parse_config(likelihood_dict),
)
+ print_msg(model)
+
# Configurations of DALIA
dalia_dict = {
- "solver": {"type": "dense"},
+ "solver": {"type": "scipy"},
"minimize": {
"max_iter": args.max_iter,
"gtol": 1e-3,
"disp": True,
- "maxcor": len(model.theta),
+ "maxcor": len(model.theta_external),
},
"f_reduction_tol": 1e-3,
"theta_reduction_tol": 1e-4,
@@ -84,10 +87,10 @@
print_msg("\n--- Results ---")
print_msg("Theta values:\n", results["theta"])
- print_msg("Covariance of theta:\n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg(
"Mean of the fixed effects:\n",
- results["x"][-model.submodels[-1].n_fixed_effects :],
+ get_host(results["x"][-model.submodels[-1].n_fixed_effects :]),
)
print_msg("\n--- Comparisons ---")
@@ -95,14 +98,14 @@
theta_ref = np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy")
print_msg(
"Norm (theta - theta_ref): ",
- f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['theta_internal']) - get_host(theta_ref)):.4e}",
)
# Compare latent parameters
x_ref = np.load(f"{BASE_DIR}/reference_outputs/x_ref.npy")
print_msg(
"Norm (x - x_ref): ",
- f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['x']) - get_host(x_ref)):.4e}",
)
# Compare marginal variances of latent parameters
diff --git a/examples/gst_coreg2_small/run.py b/examples/gst_coreg2_small/run.py
index 55328b44..79fec600 100644
--- a/examples/gst_coreg2_small/run.py
+++ b/examples/gst_coreg2_small/run.py
@@ -206,7 +206,7 @@
"max_iter": args.max_iter,
"gtol": 1e-3,
"disp": True,
- "maxcor": len(coreg_model.theta),
+ "maxcor": len(coreg_model.theta_external),
},
"f_reduction_tol": 1e-3,
"theta_reduction_tol": 1e-4,
@@ -228,7 +228,7 @@
print_msg("results['theta']: ", results["theta"])
- print_msg("cov_theta: \n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg("mean of the fixed effects: ", results["x"][-nb:])
print_msg(
"marginal variances of the fixed effects: ",
@@ -243,7 +243,7 @@
print_msg(
"Norm (theta - theta_ref): ",
- f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['theta']) - get_host(theta_ref)):.4e}",
)
x_ref = np.load(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy")
@@ -254,7 +254,7 @@
# Compare latent parameters
print_msg(
"Norm (x - x_ref): ",
- f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['x']) - get_host(x_ref)):.4e}",
)
# Compare marginal variances of latent parameters
@@ -265,7 +265,7 @@
Qinv_ref = xp.linalg.inv(Qconditional.toarray())
print_msg(
"Norm (marg var latent - ref): ",
- f"{np.linalg.norm(var_latent_params - xp.diag(Qinv_ref)):.4e}",
+ f"{xp.linalg.norm(var_latent_params - xp.diag(Qinv_ref)):.4e}",
)
# Compare marginal variances of observations
diff --git a/examples/gst_coreg3_small/run.py b/examples/gst_coreg3_small/run.py
index 34e305cf..b17ac251 100644
--- a/examples/gst_coreg3_small/run.py
+++ b/examples/gst_coreg3_small/run.py
@@ -1,34 +1,34 @@
-import sys
import os
-
-parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
-sys.path.append(parent_dir)
+import sys
import numpy as np
-import time
from dalia import xp
from dalia.configs import (
+ dalia_config,
likelihood_config,
models_config,
- dalia_config,
submodels_config,
)
-from dalia.core.model import Model
from dalia.core.dalia import DALIA
+from dalia.core.model import Model
from dalia.models import CoregionalModel
from dalia.submodels import RegressionSubModel, SpatioTemporalSubModel
from dalia.utils import get_host, print_msg
-from examples_utils.parser_utils import parse_args
-from examples_utils.infos_utils import summarize_sparse_matrix
-SEED = 63
-np.random.seed(SEED)
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(parent_dir)
+from examples_utils.parser_utils import parse_args # noqa: E402
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+SEED = 63
+xp.random.seed(SEED)
+
if __name__ == "__main__":
- print_msg("--- Example: Gaussian Coregional (3 variates) spatio-temporal model with regression ---")
+ print_msg(
+ "--- Example: Gaussian Coregional (3 variates) spatio-temporal model with regression ---"
+ )
args = parse_args()
nv = 3
@@ -40,10 +40,9 @@
n = nv * ns * nt + nb
theta_ref_file = (
- f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_interpretS_original_DALIA_perm_{dim_theta}_1.dat"
- # f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_interpretS_original_DALIA_perm_{dim_theta}_1.npy"
+ f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy"
)
- theta_ref = np.loadtxt(theta_ref_file)
+ theta_ref = np.load(theta_ref_file)
perturbation = [
0.18197867,
@@ -80,18 +79,18 @@
"sigma_st": 0.0,
"manifold": "plane",
"ph_s": {
- "type": "gaussian",
- "mean": theta_ref[0],
+ "type": "gaussian",
+ "mean": theta_ref[0],
"precision": 0.5,
},
"ph_t": {
- "type": "gaussian",
- "mean": theta_ref[1],
+ "type": "gaussian",
+ "mean": theta_ref[1],
"precision": 0.5,
},
"ph_st": {
- "type": "gaussian",
- "mean": 0.0,
+ "type": "gaussian",
+ "mean": 0.0,
"precision": 0.5,
},
}
@@ -136,18 +135,18 @@
"sigma_st": 0.0,
"manifold": "plane",
"ph_s": {
- "type": "gaussian",
- "mean": theta_ref[3],
+ "type": "gaussian",
+ "mean": theta_ref[3],
"precision": 0.5,
},
"ph_t": {
- "type": "gaussian",
- "mean": theta_ref[4],
+ "type": "gaussian",
+ "mean": theta_ref[4],
"precision": 0.5,
},
"ph_st": {
- "type": "gaussian",
- "mean": 0.0,
+ "type": "gaussian",
+ "mean": 0.0,
"precision": 0.5,
},
}
@@ -191,18 +190,18 @@
"sigma_st": 0.0,
"manifold": "plane",
"ph_s": {
- "type": "gaussian",
- "mean": theta_ref[6],
+ "type": "gaussian",
+ "mean": theta_ref[6],
"precision": 0.5,
},
"ph_t": {
- "type": "gaussian",
- "mean": theta_ref[7],
+ "type": "gaussian",
+ "mean": theta_ref[7],
"precision": 0.5,
},
"ph_st": {
- "type": "gaussian",
- "mean": 0.0,
+ "type": "gaussian",
+ "mean": 0.0,
"precision": 0.5,
},
}
@@ -258,7 +257,6 @@
)
print_msg(coreg_model)
-
dalia_dict = {
"solver": {
"type": "serinv",
@@ -268,7 +266,7 @@
"max_iter": args.max_iter,
"gtol": 1e-3,
"disp": True,
- "maxcor": len(coreg_model.theta),
+ "maxcor": len(coreg_model.theta_external),
},
"f_reduction_tol": 1e-3,
"theta_reduction_tol": 1e-4,
@@ -290,7 +288,7 @@
print_msg("results['theta']: ", results["theta"])
- print_msg("cov_theta: \n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg("mean of the fixed effects: ", results["x"][-nb:])
print_msg(
"marginal variances of the fixed effects: ",
@@ -298,19 +296,15 @@
)
print_msg("\n--- Comparisons ---")
- # Compare hyperparameters
- #theta_ref = np.load(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy")
- theta_ref = np.loadtxt(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_interpretS_original_DALIA_perm_15_1.dat")
- np.save(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy", theta_ref)
print_msg(
"Norm (theta - theta_ref): ",
f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
)
- #x_ref = np.load(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy")
- x_ref = np.loadtxt(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref_8499_1.dat")
- np.save(f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy", x_ref)
+ x_ref = np.load(
+ f"{BASE_DIR}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy"
+ )
x_ref = x_ref[dalia.model.permutation_latent_variables]
# Compare latent parameters
@@ -338,4 +332,4 @@
# f"{xp.linalg.norm(var_obs - var_obs_ref):.4e}",
# )
- print_msg("\n--- Finished ---")
\ No newline at end of file
+ print_msg("\n--- Finished ---")
diff --git a/examples/gst_large/reference_outputs/theta_ref.npy b/examples/gst_large/reference_outputs/theta_ref.npy
index 983e4c62..113200b7 100644
--- a/examples/gst_large/reference_outputs/theta_ref.npy
+++ b/examples/gst_large/reference_outputs/theta_ref.npy
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:303281b0ded419b21216bd0cefa983b10ed2b7fbea6df2110df2706807f52ff9
+oid sha256:5ba7d2f753dbe60aca3f47b2220a826eb436a0e96cf83eb592eb925e9a96b444
size 160
diff --git a/examples/gst_large/run.py b/examples/gst_large/run.py
index bb4b8ade..b377fec6 100644
--- a/examples/gst_large/run.py
+++ b/examples/gst_large/run.py
@@ -12,7 +12,7 @@
from dalia.core.dalia import DALIA
from dalia.submodels import RegressionSubModel, SpatioTemporalSubModel
from examples_utils.parser_utils import parse_args
-from dalia.utils import print_msg, get_host
+from dalia.utils import print_msg, get_host, plot_marginal_distributions_hp
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -26,9 +26,9 @@
"type": "spatio_temporal",
"input_dir": f"{BASE_DIR}/inputs_spatio_temporal",
"spatial_domain_dimension": 2,
- "r_s": -0.960279229160082,
- "r_t": -0.3068528194400548,
- "sigma_st": -2.112085713764618,
+ "r_s": -0.96,
+ "r_t": -0.31,
+ "sigma_st": -2.11,
"manifold": "sphere",
"ph_s": {"type": "penalized_complexity", "alpha": 0.01, "u": 0.5},
"ph_t": {"type": "penalized_complexity", "alpha": 0.01, "u": 5},
@@ -79,15 +79,11 @@
config=dalia_config.parse_config(dalia_dict),
)
- # print_msg("\n--- References ---")
- theta_ref = xp.array(np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy"))
- x_ref = xp.array(np.load(f"{BASE_DIR}/reference_outputs/x_ref.npy"))
-
results = dalia.run()
print_msg("\n--- Results ---")
print_msg("Theta values:\n", results["theta"])
- print_msg("Covariance of theta:\n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg(
"Mean of the fixed effects:\n",
results["x"][-model.submodels[-1].n_fixed_effects:],
@@ -96,12 +92,14 @@
print_msg("\n--- Comparisons ---")
# Compare hyperparameters
+ theta_ref = np.array(np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy"))
print_msg(
"Norm (theta - theta_ref): ",
f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
)
# Compare latent parameters
+ x_ref = np.array(np.load(f"{BASE_DIR}/reference_outputs/x_ref.npy"))
print_msg(
"Norm (x - x_ref): ",
f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
diff --git a/examples/gst_medium/reference_outputs/theta_ref.npy b/examples/gst_medium/reference_outputs/theta_ref.npy
index 983e4c62..113200b7 100644
--- a/examples/gst_medium/reference_outputs/theta_ref.npy
+++ b/examples/gst_medium/reference_outputs/theta_ref.npy
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:303281b0ded419b21216bd0cefa983b10ed2b7fbea6df2110df2706807f52ff9
+oid sha256:5ba7d2f753dbe60aca3f47b2220a826eb436a0e96cf83eb592eb925e9a96b444
size 160
diff --git a/examples/gst_medium/run.py b/examples/gst_medium/run.py
index 1e1cd1fd..9bf144c5 100644
--- a/examples/gst_medium/run.py
+++ b/examples/gst_medium/run.py
@@ -8,7 +8,7 @@
from dalia.configs import likelihood_config, dalia_config, submodels_config
from dalia.core.model import Model
from dalia.core.dalia import DALIA
-from dalia.utils import print_msg, get_host
+from dalia.utils import print_msg, get_host, plot_marginal_distributions_hp
from dalia.submodels import RegressionSubModel, SpatioTemporalSubModel
from examples_utils.parser_utils import parse_args
from dalia import xp
@@ -27,12 +27,12 @@
"input_dir": f"{BASE_DIR}/inputs_spatio_temporal",
"spatial_domain_dimension": 2,
"r_s": 0.0,
- "r_t": 0.0,
- "sigma_st": 0.0,
- "manifold": "plane",
- "ph_s": {"type": "gaussian", "mean": 0.03972077083991806, "precision": 0.5},
- "ph_t": {"type": "gaussian", "mean": 2.3931471805599456, "precision": 0.5},
- "ph_st": {"type": "gaussian", "mean": 1.4379142862353824, "precision": 0.5},
+ "r_t": 2.2,
+ "sigma_st": 1.3,
+ "manifold": "sphere",
+ "ph_s": {"type": "penalized_complexity", "alpha": 0.01, "u": 0.5},
+ "ph_t": {"type": "penalized_complexity", "alpha": 0.01, "u": 5},
+ "ph_st": {"type": "penalized_complexity", "alpha": 0.01, "u": 3},
}
spatio_temporal = SpatioTemporalSubModel(
config=submodels_config.parse_config(spatio_temporal_dict),
@@ -85,15 +85,11 @@
config=dalia_config.parse_config(dalia_dict),
)
- # print_msg("\n--- References ---")
- theta_ref = xp.array(np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy"))
- x_ref = xp.array(np.load(f"{BASE_DIR}/reference_outputs/x_ref.npy"))
-
results = dalia.run()
print_msg("\n--- Results ---")
print_msg("Theta values:\n", results["theta"])
- print_msg("Covariance of theta:\n", results["cov_theta"])
+ print_msg("Internal Covariance of theta:\n", results["cov_theta_internal"])
print_msg(
"Mean of the fixed effects:\n",
results["x"][-model.submodels[-1].n_fixed_effects:],
@@ -102,15 +98,17 @@
print_msg("\n--- Comparisons ---")
# Compare hyperparameters
+ theta_ref = np.array(np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy"))
print_msg(
"Norm (theta - theta_ref): ",
- f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['theta_internal']) - theta_ref):.4e}",
)
# Compare latent parameters
+ x_ref = np.array(np.load(f"{BASE_DIR}/reference_outputs/x_ref.npy"))
print_msg(
"Norm (x - x_ref): ",
- f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['x']) - x_ref):.4e}",
)
print_msg("\n--- Finished ---")
diff --git a/examples/gst_small/README.MD b/examples/gst_small/README.MD
index 131cd72a..39787aae 100644
--- a/examples/gst_small/README.MD
+++ b/examples/gst_small/README.MD
@@ -1,18 +1,35 @@
-# Gaussian likelihood, Regression + Spatio-Temporal submodel example
+# gst_small: small Gaussian Spatio-Temporal Model
-This example contain a small spatio-temporal model.
+This example demonstrates a small-scale spatio-temporal regression model with a Gaussian likelihood. It is designed to illustrate the structure and inference workflow for models with both spatial and temporal dependencies, as well as fixed effects (regression coefficients).
---- Model ---
- n_hyperparameters: 4
- n_latent_parameters: 466
- n_observations: 4600
- likelihood: gaussian
+## Model Overview
+| Component | Value |
+|--------------------------|--------------|
+| Likelihood | Gaussian |
+| Hyperparameters | 4 |
+| Latent parameters | 466 |
+| Observations | 4600 |
+| Spatial locations (ns) | 92 |
+| Time points (nt) | 5 |
+| Manifold | Plane |
+| Fixed effects | 6 |
+| Fixed effects prior prec.| 0.001 |
---- SpatioTemporalSubModel ---
- ns: 92
- nt: 5
- manifold: plane
+## Expected Output
+Running this example will:
+- Fit the spatio-temporal Gaussian regression model to the provided data
+- Estimate posterior distributions for latent variables and hyperparameters
+- Output summary statistics and diagnostics (e.g., posterior means, variances)
+- Optionally, generate predictions and residuals for the observed data
---- RegressionSubModel ---
- n_fixed_effects: 6
- fixed_effects_prior_precision: 0.001
+## Usage
+1. Ensure all dependencies are installed (see project root README for setup instructions).
+2. Run the example script(s) in this directory to perform inference.
+3. Review the output files and logs for results and diagnostics.
+
+## Files
+- `inputs/` — Input data for the model
+- `outputs/` — Model outputs, including posterior summaries and diagnostics
+- `README.MD` — This file
+
+For more details on the model structure and inference algorithms, see the main project documentation.
diff --git a/examples/gst_small/reference_outputs/theta_ref.npy b/examples/gst_small/reference_outputs/theta_ref.npy
index 983e4c62..113200b7 100644
--- a/examples/gst_small/reference_outputs/theta_ref.npy
+++ b/examples/gst_small/reference_outputs/theta_ref.npy
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
-oid sha256:303281b0ded419b21216bd0cefa983b10ed2b7fbea6df2110df2706807f52ff9
+oid sha256:5ba7d2f753dbe60aca3f47b2220a826eb436a0e96cf83eb592eb925e9a96b444
size 160
diff --git a/examples/gst_small/run.py b/examples/gst_small/run.py
index 2c36f753..348e828b 100644
--- a/examples/gst_small/run.py
+++ b/examples/gst_small/run.py
@@ -11,7 +11,7 @@
from dalia.core.model import Model
from dalia.core.dalia import DALIA
from dalia.submodels import RegressionSubModel, SpatioTemporalSubModel
-from dalia.utils import get_host, print_msg, extract_diagonal
+from dalia.utils import get_host, print_msg, plot_marginal_distributions_hp
from examples_utils.parser_utils import parse_args
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -26,6 +26,7 @@
"type": "spatio_temporal",
"input_dir": f"{BASE_DIR}/inputs_spatio_temporal",
"spatial_domain_dimension": 2,
+ # These hyperparameters are in the internal scale (dalia.py/BFGS)
"r_s": 0,
"r_t": 0,
"sigma_st": 0,
@@ -52,12 +53,13 @@
likelihood_dict = {
"type": "gaussian",
"prec_o": 4,
- # "prior_hyperparameters": {"type": "gaussian", "mean": 1.4, "precision": 0.5},
- "prior_hyperparameters": {
- "type": "penalized_complexity",
- "alpha": 0.01,
- "u": 4,
- },
+ "prior_hyperparameters": {"type": "gamma", "alpha": 2.0, "beta": 2.0},
+ #"prior_hyperparameters": {"type": "gaussian", "mean": 1.4, "precision": 0.5},
+ # "prior_hyperparameters": {
+ # "type": "penalized_complexity",
+ # "alpha": 0.01,
+ # "u": 4,
+ # },
}
# Creation of the model by combining the submodels and the likelihood
@@ -74,7 +76,7 @@
"max_iter": args.max_iter,
"gtol": 1e-3,
"disp": True,
- "maxcor": len(model.theta),
+ "maxcor": len(model.theta_external),
},
"f_reduction_tol": 1e-3,
"theta_reduction_tol": 1e-4,
@@ -91,43 +93,48 @@
results = dalia.run()
print_msg("\n--- Results ---")
+ theta_ref = np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy")
+
print_msg("Theta values:\n", results["theta"])
- print_msg("Covariance of theta:\n", results["cov_theta"])
- print_msg(
- "Mean of the fixed effects:\n",
- results["x"][-model.submodels[-1].n_fixed_effects :],
- )
+ print_msg("Theta values internal:\n", results["theta_internal"])
+ print_msg("Covariance of theta:\n", results["cov_theta_internal"])
print_msg("\n--- Comparisons ---")
# Compare hyperparameters
- theta_ref = np.load(f"{BASE_DIR}/reference_outputs/theta_ref.npy")
print_msg(
"Norm (theta - theta_ref): ",
- f"{np.linalg.norm(results['theta'] - get_host(theta_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results["theta"]) - theta_ref):.4e}",
)
# Compare latent parameters
x_ref = np.load(f"{BASE_DIR}/reference_outputs/x_ref.npy")
print_msg(
"Norm (x - x_ref): ",
- f"{np.linalg.norm(results['x'] - get_host(x_ref)):.4e}",
+ f"{np.linalg.norm(get_host(results['x']) - x_ref):.4e}",
)
# Compare marginal variances of latent parameters
- var_latent_params = results["marginal_variances_latent"]
+ var_latent_params = get_host(results["marginal_variances_latent"])
+ dalia.model.theta_internal = results["theta_internal"]
Qconditional = dalia.model.construct_Q_conditional(eta=model.a @ model.x)
Qinv_ref = xp.linalg.inv(Qconditional.toarray())
print_msg(
"Norm (marg var latent - ref): ",
- f"{np.linalg.norm(var_latent_params - xp.diag(Qinv_ref)):.4e}",
+ f"{np.linalg.norm(var_latent_params - get_host(xp.diag(Qinv_ref))):.4e}",
)
- # Compare marginal variances of observations
- # var_obs = dalia.get_marginal_variances_observations(theta=theta_ref, x_star=x_ref)
- # var_obs_ref = extract_diagonal(model.a @ Qinv_ref @ model.a.T)
- # print_msg(
- # "Norm (var_obs - var_obs_ref): ",
- # f"{xp.linalg.norm(var_obs - var_obs_ref):.4e}",
- # )
+ print_msg("\n--- Marginal distributions of the hyperparameters ---")
+ marginals_hp = dalia.marginal_distributions_hp()
+
+ fig, axes = plot_marginal_distributions_hp(marginals_hp)
+ import matplotlib.pyplot as plt
+ plt.savefig("gst_small_marginal_distributions_hp.png")
+
+ prec_obs = marginals_hp['hyperparameters']['prec_o']
+ quantile_pairs = prec_obs['quantiles']['external']['pairs']
+ print("Quantile pairs of prec_o:")
+ for p, q in quantile_pairs:
+ print(f" {p:.3f} quantile: {q:.4f}")
+
print_msg("\n--- Finished ---")
diff --git a/examples/p_ar1/generate_data.py b/examples/p_ar1/generate_data.py
new file mode 100644
index 00000000..bdc62a92
--- /dev/null
+++ b/examples/p_ar1/generate_data.py
@@ -0,0 +1,92 @@
+import os
+import sys
+
+import numpy as np
+import scipy.sparse as sp
+
+from scipy.stats import multivariate_normal, poisson
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+if __name__ == "__main__":
+
+ n = 1000
+
+ ## define priors
+ s2 = 1 # 0.7
+ tau = 1 / s2
+ ### note: phi between -1 and 1 for discrete timesteps
+ # (doesn't make sense for negative in cts case)
+
+ ## rescale beta prior 2 theta - 1
+ phi = 0.9 # 0.9
+ theta_original = [phi, tau]
+
+ denom = s2 * (1 - phi**2)
+
+ diag = [(1 + phi**2) / denom] * n
+ diag[0] = diag[-1] = 1 / denom
+ off_diag = [-phi / denom] * (n - 1)
+
+ Q = sp.diags([diag, off_diag, off_diag], [0, -1, 1])
+ L = np.linalg.cholesky(Q.toarray())
+ Cov = np.linalg.inv(Q.toarray())
+
+ print(Q.toarray()[:6, :6])
+ print(np.linalg.inv(Q.toarray())[:6, :6])
+ print(np.round(Q.toarray() @ np.linalg.inv(Q.toarray()), 6)[:6, :6])
+ # exit()
+
+ mv = multivariate_normal(mean=np.zeros(n), cov=Cov, seed=3)
+
+ intercept = 2
+ u = mv.rvs()
+ print("u: ", u[:10])
+ eta = u + intercept
+ x = np.concatenate((u, [intercept]))
+
+ os.makedirs("reference_outputs", exist_ok=True)
+ np.save("reference_outputs/x_original.npy", x)
+ np.save("reference_outputs/theta_original.npy", theta_original)
+
+ x_initial = u + np.random.normal(0, 0.3, size=len(u))
+ os.makedirs("inputs_ar1", exist_ok=True)
+ np.save("inputs_ar1/x.npy", u)
+
+ a_ar1 = sp.eye(n)
+ sp.save_npz("inputs_ar1/a.npz", a_ar1)
+
+ a_regression = sp.csr_matrix(np.ones((n, 1)))
+ os.makedirs("inputs_regression", exist_ok=True)
+ sp.save_npz("inputs_regression/a.npz", a_regression)
+
+ print("eta: ", eta[:10])
+ np.save("inputs_ar1/x_original.npy", eta)
+
+ # sample with repitition
+ E = np.random.choice([1, 2, 3], size=n, replace=True)
+ # E = [1] * n
+ np.save("e.npy", E)
+
+ y = poisson.rvs(E * np.exp(eta), random_state=3)
+ np.save("y.npy", y)
+
+ print("y[:10]: ", y[:10])
+
+ Qprior = sp.block_diag([Q, sp.csr_matrix([[0.001]])])
+ # print("Qprior : \n", Qprior.toarray())
+
+ # a = sp.hstack([a_ar1, a_regression])
+ # Qcond = Qprior + obs_noise_prec * a.T @ a
+ # print("Qcond: \n", Qcond.toarray())
+
+ # b = obs_noise_prec * a.T @ y
+ # # -xp.exp(theta) * (eta - y)
+ # # beta_initial + np.linalg.solve(
+ # # Qconditional.toarray(), information_vector
+ # # )
+ # x_est = np.linalg.solve(Qcond.toarray(), b)
+ # print("x_est: ", x_est)
+
+ # print("eta est : ", a @ x_est)
+ # print("eta : ", a @ x)
diff --git a/examples/p_ar1/run.py b/examples/p_ar1/run.py
new file mode 100644
index 00000000..730c437a
--- /dev/null
+++ b/examples/p_ar1/run.py
@@ -0,0 +1,109 @@
+import sys
+import os
+
+parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(parent_dir)
+
+import numpy as np
+
+from dalia import xp
+from dalia.configs import likelihood_config, dalia_config, submodels_config
+from dalia.core.model import Model
+from dalia.core.dalia import DALIA
+from dalia.submodels import AR1SubModel, RegressionSubModel
+from dalia.utils import get_host, print_msg # , extract_diagonal
+
+
+BASE_DIR = os.path.dirname(os.path.abspath(__file__))
+
+if __name__ == "__main__":
+ n = 1000
+
+ # load reference output
+ theta_original = np.load(f"{BASE_DIR}/reference_outputs/theta_original.npy")
+ print("theta original: ", theta_original)
+
+ x_original = np.load(f"{BASE_DIR}/reference_outputs/x_original.npy")
+ print("x original: ", x_original[:10])
+ print("dim(x original): ", x_original.shape)
+
+ ar1_dict = {
+ "type": "ar1",
+ "input_dir": f"{BASE_DIR}/inputs_ar1",
+ "phi": 0.45, # has to be between 0 and 1
+ "tau": 0.5, # precision
+ "ph_phi": {"type": "beta", "alpha": 5.0, "beta": 1.0},
+ "ph_tau": {"type": "gamma", "alpha": 2.0, "beta": 0.5},
+ }
+ ar1 = AR1SubModel(
+ config=submodels_config.parse_config(ar1_dict),
+ )
+
+ # Configurations of the regression submodel
+ regression_dict = {
+ "type": "regression",
+ "input_dir": f"{BASE_DIR}/inputs_regression",
+ "n_fixed_effects": 1,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression = RegressionSubModel(
+ config=submodels_config.parse_config(regression_dict),
+ )
+
+ likelihood_dict = {
+ "type": "poisson",
+ "input_dir": f"{BASE_DIR}",
+ }
+
+ model = Model(
+ submodels=[ar1, regression],
+ likelihood_config=likelihood_config.parse_config(likelihood_dict),
+ )
+ print_msg(model)
+
+ Qprior = model.construct_Q_prior()
+ print("Qprior: \n", Qprior.toarray())
+ Qinv = np.linalg.inv(Qprior.toarray())
+ geom_mean = np.exp(np.mean(np.log(Qinv.diagonal())))
+ print("Geometric mean of Qinv diagonal: ", geom_mean)
+
+ eta = model.a @ model.x
+ Qcond = model.construct_Q_conditional(eta=eta)
+
+ # L = np.linalg.cholesky(Qcond.toarray())
+
+ # plt.spy(Qcond, markersize=2)
+ # plt.title("Sparsity pattern of Qcond")
+ # plt.show()
+
+ # Configurations of DALIA
+ dalia_dict = {
+ "solver": {"type": "dense"},
+ "minimize": {
+ "max_iter": 100,
+ "gtol": 1e-3,
+ "disp": True,
+ "maxcor": len(model.theta_external),
+ },
+ "f_reduction_tol": 1e-3,
+ "theta_reduction_tol": 1e-4,
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "simulation_dir": ".",
+ }
+
+ dalia = DALIA(
+ model=model,
+ config=dalia_config.parse_config(dalia_dict),
+ )
+
+ results = dalia.minimize()
+
+ theta_user = results["theta"]
+ print("theta_ref: ", theta_original)
+ print("theta user: ", theta_user)
+ print("norm(theta_original - theta_user): ", np.linalg.norm(theta_original - get_host(theta_user)))
+ print("norm(x_original - x): ", np.linalg.norm(x_original - get_host(results["x"])))
+ print("normalized norm: ", np.linalg.norm(x_original - get_host(results["x"])) / np.linalg.norm(x_original))
+
diff --git a/examples/pst_small/run.py b/examples/pst_small/run.py
index 97760f29..0584ad29 100644
--- a/examples/pst_small/run.py
+++ b/examples/pst_small/run.py
@@ -65,7 +65,7 @@
# Configurations of DALIA
dalia_dict = {
- "solver": {"type": "dense"},
+ "solver": {"type": "serinv"},
"minimize": {
"max_iter": args.max_iter,
"gtol": 1e-3,
diff --git a/examples/run_example_alex_fau.sh b/examples/run_example_alex_fau.sh
index 7d97765e..6450207a 100644
--- a/examples/run_example_alex_fau.sh
+++ b/examples/run_example_alex_fau.sh
@@ -1,57 +1,81 @@
-#!/bin/bash
-
-#SBATCH --job-name=dalia
+#!/bin/bash -l
+#SBATCH --job-name=dalia_alex
+#SBATCH --output=%x.%j.out
+#SBATCH --error=%x.%j.err
+#SBATCH --time=00:04:00
#SBATCH --nodes=1
-#SBATCH --time=01:00:00
-#SBATCH --gres=gpu:a100:2
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:a100:1
#SBATCH --partition=a100
-#SBATCH --constraint=a100_80
-# ##SBATCH --qos=a100multi
-# ##SBATCH --exclusive
-#SBATCH --error=%x.err #The .error file name
-#SBATCH --output=%x.out #The .output file name
-
-# --- Set Backend ---
-# The backend can be set to either 'cupy' or 'numpy'.
-export ARRAY_MODULE=cupy
+###SBATCH --constraint=a100_80
+###SBATCH --qos=a100multi
+###SBATCH --exclusive
+#SBATCH --export=NONE
-export MPI_CUDA_AWARE=0
-export USE_NCCL=0
+# Change to examples directory
+if [[ "$(basename "$(pwd)")" != "examples" ]]; then
+ echo ""
+ echo "Error: Not in examples directory"
+ echo " Current directory: $(pwd)"
+ echo " Please run this script from the examples/ directory"
+ echo ""
+ exit 1
+fi
-TIMESTAMP=$(date +"%H-%M-%S")
+# Set DALIA environment variables for examples
+source ../scripts/alex_fau_utils.sh && alex_load_modules && alex_activate_conda_env && alex_set_perfenv
+source ../scripts/dalia_job_utils.sh && dalia_set_perfenv && dalia_print_job_config
# --- How to Run ---
# This run script is designed to run on Alex at NHR@FAU
# It uses SLURM for job scheduling and assumes that the user has a working
-# installation of DALIA and its dependencies. By default, DALIA will exploit
-# job parallelism at the parallel function evaluation level.
+# installation of DALIA and its dependencies. By default, DALIA will exploit
+# job parallelism in a cascade, first at the function evaluation level,
+# then at the precision matrix level, finally at the structured solver level.
# --- Parameters ---
# `--solver_min_p` : The minimum number of Processes(/GPUs) to use for the structured
# solver. The default is 1. The maximum number of processes is
# `--max_iter` : The maximum number of iterations of the minimization.
-base_dir=~
+# --- Brainiac Example ---
+srun python ./brainiac/run.py --max_iter 100
+
+# --- Gaussian AR1 Example ---
+# srun python ./g_ar1/run.py --max_iter 100
+
+# --- Gaussian Regression Example ---
+# srun python ./gr/run.py --max_iter 100
+
+# --- Gaussian Spatial Coregional 2 Models (Small) Example ---
+# srun python ./gs_coreg2_small/run.py --max_iter 100
+
+# --- Gaussian Spatial Coregional 3 Models (Small) Example ---
+# srun python ./gs_coreg3_small/run.py --max_iter 100
+
+# --- Gaussian Spatial Model (Small) Example ---
+# srun python ./gs_small/run.py --max_iter 100
+
+# --- Gaussian Spatio-temporal Coregional 2 Models (Small) Example ---
+# srun python ./gst_coreg2_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Coregional 3 Models (Small) Example ---
+# srun python ./gst_coreg3_small/run.py --solver_min_p 1 --max_iter 100
-# --- Run Regression Example ---
-srun python ${base_dir}/DALIA/examples/gr/run.py --max_iter 100
+# --- Gaussian Spatio-temporal Model (Large) Example ---
+# srun python ./gst_large/run.py --solver_min_p 1 --max_iter 100
-# --- Run Spatial Examples ---
-srun python ${base_dir}/DALIA/examples/gs_small/run.py --max_iter 100
+# --- Gaussian Spatio-temporal Model (Medium) Example ---
+# srun python ./gst_medium/run.py --solver_min_p 1 --max_iter 100
-# --- Run Spatio-temporal Examples ---
-srun python ${base_dir}/DALIA/examples/gst_small/run.py --solver_min_p 1 --max_iter 100
-srun python ${base_dir}/DALIA/examples/gst_medium/run.py --solver_min_p 1 --max_iter 100
-srun python ${base_dir}/DALIA/examples/gst_large/run.py --solver_min_p 1 --max_iter 100
+# --- Gaussian Spatio-temporal Model (Small) Example ---
+# srun python ./gst_small/run.py --solver_min_p 1 --max_iter 100
-# --- Run Coregional (Spatial) Examples ---
-srun python ${base_dir}/DALIA/examples/gs_coreg2_small/run.py --max_iter 100
-srun python ${base_dir}/DALIA/examples/gs_coreg3_small/run.py --max_iter 100
+# --- Poisson AR1 Example ---
+# srun python ./p_ar1/run.py --max_iter 100
-# --- Run Coregional (Spatio-temporal) Examples ---
-srun python ${base_dir}/DALIA/examples/gst_coreg2_small/run.py --solver_min_p 1 --max_iter 100
-srun python ${base_dir}/DALIA/examples/gst_coreg3_small/run.py --solver_min_p 1 --max_iter 100
+# --- Poisson Regression Example ---
+# srun python ./pr/run.py --max_iter 100
-# --- Run Poisson Examples ---
-srun python ${base_dir}/DALIA/examples/pr/run.py --max_iter 100
-srun python ${base_dir}/DALIA/examples/pst_small/run.py --max_iter 100
+# --- Poisson Spatio-temporal Model (Small) Example ---
+# srun python ./pst_small/run.py --solver_min_p 1 --max_iter 100
\ No newline at end of file
diff --git a/examples/run_example_daint_alps.sh b/examples/run_example_daint_alps.sh
deleted file mode 100644
index 4d65e2c2..00000000
--- a/examples/run_example_daint_alps.sh
+++ /dev/null
@@ -1,63 +0,0 @@
-#!/bin/bash -l
-#SBATCH --job-name="examples"
-#SBATCH --output=%x.%j.out
-#SBATCH --error=%x.%j.err
-#SBATCH --account=sm96
-#SBATCH --time=00:05:00
-#SBATCH --nodes=1
-#SBATCH --ntasks-per-node=1
-#SBATCH --cpus-per-task=64
-#SBATCH --gpus-per-task=1
-####SBATCH --partition=normal
-#SBATCH --partition=debug
-#SBATCH --constraint=gpu
-#SBATCH --hint=nomultithread
-#SBATCH --uenv=prgenv-gnu/24.11:v1
-#SBATCH --view=modules
-
-set -e -u
-
-export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
-export MPICH_GPU_SUPPORT_ENABLED=1
-
-export NCCL_NET='AWS Libfabric'
-export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
-
-source ~/load_modules.sh
-conda activate allin
-
-export ARRAY_MODULE=cupy
-export MPI_CUDA_AWARE=0
-
-# --- How to Run ---
-# This run script is designed to run on the Daint supercomputer at CSCS.
-# It uses SLURM for job scheduling and assumes that the user has a working
-# installation of DALIA and its dependencies. By default, DALIA will exploit
-# job parallelism at the parallel function evaluation level.
-
-# --- Parameters ---
-# `--solver_min_p` : The minimum number of Processes(/GPUs) to use for the structured
-# solver. The default is 1. The maximum number of processes is
-# `--max_iter` : The maximum number of iterations of the minimization.
-
-base_dir=~
-
-# --- Run Regression Example ---
-#srun python ${base_dir}/DALIA/examples/regression/run.py --max_iter 100
-
-# --- Run Spatial Examples ---
-# srun python ${base_dir}/DALIA/examples/gs_small/run.py --max_iter 100
-
-# --- Run Spatio-temporal Examples ---
-# srun python ${base_dir}/DALIA/examples/gst_small/run.py --solver_min_p 1 --max_iter 100
-# srun python ${base_dir}/DALIA/examples/gst_medium/run.py --solver_min_p 1 --max_iter 100
-# srun python ${base_dir}/DALIA/examples/gst_large/run.py --solver_min_p 1 --max_iter 100
-
-# --- Run Coregional (Spatial) Examples ---
-srun python ${base_dir}/DALIA/examples/gs_coreg2_small/run.py --max_iter 100
-# srun python ${base_dir}/DALIA/examples/gs_coreg3_small/run.py --max_iter 100
-
-# --- Run Coregional (Spatio-temporal) Examples ---
-# srun python ${base_dir}/DALIA/examples/gst_coreg2_small/run.py --solver_min_p 1 --max_iter 100
-# srun python ${base_dir}/DALIA/examples/gst_coreg3_small/run.py --solver_min_p 1 --max_iter 100
-
diff --git a/examples/run_example_daint_cscs.sh b/examples/run_example_daint_cscs.sh
new file mode 100644
index 00000000..bbb47e96
--- /dev/null
+++ b/examples/run_example_daint_cscs.sh
@@ -0,0 +1,83 @@
+#!/bin/bash -l
+#SBATCH --job-name="dalia_daint"
+#SBATCH --output=%x.%j.out
+#SBATCH --error=%x.%j.err
+#SBATCH --account=lp16
+#SBATCH --time=00:05:00
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=64
+#SBATCH --gpus-per-task=1
+#SBATCH --partition=debug
+#SBATCH --constraint=gpu
+#SBATCH --hint=nomultithread
+#SBATCH --uenv=prgenv-gnu/25.6:v2
+#SBATCH --view=modules
+
+# Change to examples directory
+if [[ "$(basename "$(pwd)")" != "examples" ]]; then
+ echo ""
+ echo "Error: Not in examples directory"
+ echo " Current directory: $(pwd)"
+ echo " Please run this script from the examples/ directory"
+ echo ""
+ exit 1
+fi
+
+# Set DALIA environment variables for examples
+source ../scripts/daint_cscs_utils.sh && daint_load_modules && daint_activate_conda_env && daint_set_perfenv
+source ../scripts/dalia_job_utils.sh && dalia_set_perfenv && dalia_print_job_config
+
+# --- How to Run ---
+# This run script is designed to run on the Daint supercomputer at CSCS.
+# It uses SLURM for job scheduling and assumes that the user has a working
+# installation of DALIA and its dependencies. By default, DALIA will exploit
+# job parallelism in a cascade, first at the function evaluation level,
+# then at the precision matrix level, finally at the structured solver level.
+
+# --- Parameters ---
+# `--solver_min_p` : The minimum number of Processes(/GPUs) to use for the structured
+# solver. The default is 1. The maximum number of processes is
+# `--max_iter` : The maximum number of iterations of the minimization.
+
+# --- Brainiac Example ---
+srun python ./brainiac/run.py --max_iter 100
+
+# --- Gaussian AR1 Example ---
+# srun python ./g_ar1/run.py --max_iter 100
+
+# --- Gaussian Regression Example ---
+# srun python ./gr/run.py --max_iter 100
+
+# --- Gaussian Spatial Coregional 2 Models (Small) Example ---
+# srun python ./gs_coreg2_small/run.py --max_iter 100
+
+# --- Gaussian Spatial Coregional 3 Models (Small) Example ---
+# srun python ./gs_coreg3_small/run.py --max_iter 100
+
+# --- Gaussian Spatial Model (Small) Example ---
+# srun python ./gs_small/run.py --max_iter 100
+
+# --- Gaussian Spatio-temporal Coregional 2 Models (Small) Example ---
+# srun python ./gst_coreg2_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Coregional 3 Models (Small) Example ---
+# srun python ./gst_coreg3_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Model (Large) Example ---
+# srun python ./gst_large/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Model (Medium) Example ---
+# srun python ./gst_medium/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Model (Small) Example ---
+# srun python ./gst_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Poisson AR1 Example ---
+# srun python ./p_ar1/run.py --max_iter 100
+
+# --- Poisson Regression Example ---
+# srun python ./pr/run.py --max_iter 100
+
+# --- Poisson Spatio-temporal Model (Small) Example ---
+# srun python ./pst_small/run.py --solver_min_p 1 --max_iter 100
\ No newline at end of file
diff --git a/examples/run_example_fritz_fau.sh b/examples/run_example_fritz_fau.sh
new file mode 100644
index 00000000..fa1d414b
--- /dev/null
+++ b/examples/run_example_fritz_fau.sh
@@ -0,0 +1,79 @@
+#!/bin/bash -l
+#SBATCH --job-name="dalia_fritz"
+#SBATCH --output=%x.%j.out
+#SBATCH --error=%x.%j.err
+#SBATCH --time=00:08:00
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --cpus-per-task=32
+#SBATCH --partition=spr2tb
+#SBATCH --hint=nomultithread
+#SBATCH --export=NONE
+
+# Change to examples directory
+if [[ "$(basename "$(pwd)")" != "examples" ]]; then
+ echo ""
+ echo "Error: Not in examples directory"
+ echo " Current directory: $(pwd)"
+ echo " Please run this script from the examples/ directory"
+ echo ""
+ exit 1
+fi
+
+# Set DALIA environment variables for examples
+source ../scripts/fritz_fau_utils.sh && fritz_load_modules && fritz_activate_conda_env && fritz_set_perfenv
+source ../scripts/dalia_job_utils.sh && dalia_set_perfenv && dalia_print_job_config
+
+# --- How to Run ---
+# This run script is designed to run on Fritz at NHR@FAU
+# It uses SLURM for job scheduling and assumes that the user has a working
+# installation of DALIA and its dependencies. By default, DALIA will exploit
+# job parallelism in a cascade, first at the function evaluation level,
+# then at the precision matrix level, finally at the structured solver level.
+
+# --- Parameters ---
+# `--solver_min_p` : The minimum number of Processes(/GPUs) to use for the structured
+# solver. The default is 1. The maximum number of processes is
+# `--max_iter` : The maximum number of iterations of the minimization.
+
+# --- Brainiac Example ---
+srun python ./brainiac/run.py --max_iter 100
+
+# --- Gaussian AR1 Example ---
+# srun python ./g_ar1/run.py --max_iter 100
+
+# --- Gaussian Regression Example ---
+# srun python ./gr/run.py --max_iter 100
+
+# --- Gaussian Spatial Coregional 2 Models (Small) Example ---
+# srun python ./gs_coreg2_small/run.py --max_iter 100
+
+# --- Gaussian Spatial Coregional 3 Models (Small) Example ---
+# srun python ./gs_coreg3_small/run.py --max_iter 100
+
+# --- Gaussian Spatial Model (Small) Example ---
+# srun python ./gs_small/run.py --max_iter 100
+
+# --- Gaussian Spatio-temporal Coregional 2 Models (Small) Example ---
+# srun python ./gst_coreg2_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Coregional 3 Models (Small) Example ---
+# srun python ./gst_coreg3_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Model (Large) Example ---
+# srun python ./gst_large/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Model (Medium) Example ---
+# srun python ./gst_medium/run.py --solver_min_p 1 --max_iter 100
+
+# --- Gaussian Spatio-temporal Model (Small) Example ---
+# srun python ./gst_small/run.py --solver_min_p 1 --max_iter 100
+
+# --- Poisson AR1 Example ---
+# srun python ./p_ar1/run.py --max_iter 100
+
+# --- Poisson Regression Example ---
+# srun python ./pr/run.py --max_iter 100
+
+# --- Poisson Spatio-temporal Model (Small) Example ---
+# srun python ./pst_small/run.py --solver_min_p 1 --max_iter 100
\ No newline at end of file
diff --git a/examples/todo_utils_scripts/data_preprocessing_coreg.py b/examples/todo_utils_scripts/data_preprocessing_coreg.py
deleted file mode 100644
index 2d4afae8..00000000
--- a/examples/todo_utils_scripts/data_preprocessing_coreg.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# import math
-import os
-
-# import matplotlib.pyplot as plt
-import numpy as np
-from matrix_utilities import ( # read_sym_CSC,
- load_matrices_spatial_model_from_dat,
- load_matrices_spatial_temporal_model_from_dat,
- read_CSC,
-)
-from scipy.sparse import csc_matrix, save_npz # block_diag,
-
-# from dalia import sp, xp
-
-if __name__ == "__main__":
- # get current path
- path = os.path.dirname(__file__)
-
- num_vars = 2
-
- type = "spatio-temporal" # "spatial" #
-
- ns = 354
- nt = 12
-
- # add more shared fixed effects later
- nb = num_vars
- nb_per_var = 1
-
- no1 = 2 * ns * nt
- no2 = 2 * ns * nt
- no3 = 0 # 2 * ns * nt
-
- dim_theta = 9
-
- no_list = [no1, no2, no3]
- total_obs = sum(no_list)
-
- n = num_vars * (ns * nt + nb_per_var)
-
- data_dir = f"../../../repositories/application/coregionalization_models/data/nv{num_vars}_ns{ns}_nt{nt}_nb{nb}"
-
- # load submatrices
- if type == "spatial":
- c0, g1, g2 = load_matrices_spatial_model_from_dat(ns, data_dir)
-
- elif type == "spatio-temporal":
- c0, g1, g2, g3, M0, M1, M2 = load_matrices_spatial_temporal_model_from_dat(
- ns, nt, data_dir
- )
- else:
- raise ValueError("Invalid model type")
-
- # load observation vectors
- y1_file = f"{data_dir}/y1_{no1}_1.dat"
- y1 = np.loadtxt(y1_file)
-
- y2_file = f"{data_dir}/y2_{no2}_1.dat"
- y2 = np.loadtxt(y2_file)
-
- # load projection matrices
- a1_file = f"{data_dir}/A1_{no1}_{ns*nt}.dat"
- a1 = read_CSC(a1_file)
- print("nnz(A1) = ", a1.nnz)
-
- # split a into random and fixed effects
- a1_random = a1[:, : ns * nt]
- a1_fixed = csc_matrix(np.ones((no_list[0], nb_per_var)))
-
- a2_file = f"{data_dir}/A2_{no1}_{ns*nt}.dat"
- a2 = read_CSC(a2_file)
- print("nnz(A2) = ", a1.nnz)
-
- a2_random = a2[:, : ns * nt]
- a2_fixed = csc_matrix(np.ones((no_list[1], nb_per_var)))
-
- if num_vars == 3:
- y3_file = f"{data_dir}/y3_{no3}_1.dat"
- y3 = np.loadtxt(y3_file)
-
- a3_file = f"{data_dir}/A3_{no3}_{ns*nt}.dat"
- a3 = read_CSC(a3_file)
-
- a3_random = a3[:, : ns * nt]
- a3_fixed = csc_matrix(np.ones((no_list[2], nb_per_var)))
-
- # set path new data directory and create necessary folders
- new_data_dir = f"{path}/inputs_nv{num_vars}_ns{ns}_nt{nt}_nb{nb}"
- os.makedirs(new_data_dir, exist_ok=True)
- print(f"Created directory {new_data_dir}")
-
- # save matrices
- if type == "spatial":
- for i in range(num_vars):
- os.makedirs(f"{new_data_dir}/model_{i+1}/inputs_spatial", exist_ok=True)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatial/c0.npz", c0)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatial/g1.npz", g1)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatial/g2.npz", g2)
-
- os.makedirs(f"{new_data_dir}/model_{i+1}/inputs_regression", exist_ok=True)
-
- save_npz(f"{new_data_dir}/model_1/inputs_spatial/a.npz", a1_random)
- save_npz(f"{new_data_dir}/model_2/inputs_spatial/a.npz", a2_random)
-
- save_npz(f"{new_data_dir}/model_1/inputs_regression/a.npz", a1_fixed)
- save_npz(f"{new_data_dir}/model_2/inputs_regression/a.npz", a2_fixed)
-
- if num_vars == 3:
- save_npz(f"{new_data_dir}/model_3/inputs_spatial/a.npz", a3_random)
- save_npz(f"{new_data_dir}/model_3/inputs_regression/a.npz", a3_fixed)
-
- elif type == "spatio-temporal":
- for i in range(num_vars):
- os.makedirs(
- f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal", exist_ok=True
- )
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/c0.npz", c0)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/g1.npz", g1)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/g2.npz", g2)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/g3.npz", g3)
-
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/m0.npz", M0)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/m1.npz", M1)
- save_npz(f"{new_data_dir}/model_{i+1}/inputs_spatio_temporal/m2.npz", M2)
-
- os.makedirs(f"{new_data_dir}/model_{i+1}/inputs_regression/", exist_ok=True)
-
- save_npz(f"{new_data_dir}/model_1/inputs_spatio_temporal/a.npz", a1_random)
- save_npz(f"{new_data_dir}/model_2/inputs_spatio_temporal/a.npz", a2_random)
-
- save_npz(f"{new_data_dir}/model_1/inputs_regression/a.npz", a1_fixed)
- save_npz(f"{new_data_dir}/model_2/inputs_regression/a.npz", a2_fixed)
-
- if num_vars == 3:
- save_npz(f"{new_data_dir}/model_3/inputs_spatio_temporal/a.npz", a3_random)
- save_npz(f"{new_data_dir}/model_3/inputs_regression/a.npz", a3_fixed)
-
- np.save(f"{new_data_dir}/model_1/y.npy", y1)
- np.save(f"{new_data_dir}/model_2/y.npy", y2)
-
- if num_vars == 3:
- np.save(f"{new_data_dir}/model_3/y.npy", y3)
-
- # reference output
- theta_ref_file = f"{data_dir}/theta_interpretS_original_{dim_theta}_1.dat"
- theta_ref = np.loadtxt(theta_ref_file)
- print(f"Reference theta: {theta_ref}")
-
- theta_ref_py_file = (
- f"{data_dir}/theta_interpretS_original_DALIA_perm_{dim_theta}_1.dat"
- )
- theta_ref_py = np.loadtxt(theta_ref_py_file)
- print(f"Reference theta (python order): {theta_ref_py}")
-
- # load reference x
- x_ref_file = f"{data_dir}/x_original_{n}_1.dat"
- x_ref = np.loadtxt(x_ref_file)
- print("x_ref[:10]: ", x_ref[:10])
-
- os.makedirs(f"{new_data_dir}/reference_outputs", exist_ok=True)
-
- np.save(f"{new_data_dir}/reference_outputs/theta_ref.npy", theta_ref_py)
- # np.savetxt(
- # f"{new_data_dir}/reference_outputs/theta_interpretS_original_DALIA_perm_{dim_theta}_1.dat",
- # theta_ref_py,
- # )
-
- # save reference x
- # np.savetxt(f"{new_data_dir}/reference_outputs/x_original_{n}_1.dat", x_ref)
- np.save(f"{new_data_dir}/reference_outputs/x_ref.npy", x_ref)
-
- print("num vars: ", num_vars)
- print("type: ", type)
-
- # # reorder theta to python format
- # if type == "spatial" and num_vars == 3:
- # # old order ['sigma_1' , 'r_s1', 'sigma_2', 'r_s2', 'sigma_3', 'r_s3', 'lambda_0_1', 'lambda_0_2', 'lambda_1_2' 'prec_o1', 'prec_o2', 'prec_o3']
- # # new order ['r_s1', 'prec_o1', 'r_s2', 'prec_o2', 'r_s3', 'prec_o3', 'sigma_1', 'sigma_2', 'sigma_3', 'lambda_0_1', 'lambda_0_2', 'lambda_1_2']
- # theta_py = np.array([
- # theta_ref[1], # r_s1
- # theta_ref[9], # prec_o1
- # theta_ref[3], # r_s2
- # theta_ref[10], # prec_o2
- # theta_ref[5], # r_s3
- # theta_ref[11], # prec_o3
- # theta_ref[0], # sigma_1 (was sigma_0 in the previous mapping)
- # theta_ref[2], # sigma_2 (was sigma_1 in the previous mapping)
- # theta_ref[4], # sigma_3
- # theta_ref[6], # lambda_0_1
- # theta_ref[7], # lambda_0_2
- # theta_ref[8], # lambda_1_2
- # ])
-
- # elif num_vars == 2 and type == "spatio-temporal":
- # # old order ['sigma_1' , 'r_s1', 'r_t1', 'sigma_2', 'r_s2', 'r_t2', 'prec_o1', 'prec_o2', 'lambda_0_1']
- # # new order ['r_s1', 'r_t1', 'prec_o1', 'r_s2', 'r_t2', 'prec_o2', 'sigma_1', 'sigma_2', 'lambda_0_1']
- # theta_py = np.array([
- # theta_ref[1], # r_s1
- # theta_ref[2], # r_t1
- # theta_ref[6], # prec_o1
- # theta_ref[4], # r_s2
- # theta_ref[5], # r_t2
- # theta_ref[7], # prec_o2
- # theta_ref[0], # sigma_1
- # theta_ref[3], # sigma_2
- # theta_ref[8], # lambda_0_1
- # ])
-
- # else:
- # raise ValueError("Invalid model type")
-
- # print(f"theta R: {theta_ref}")
- # print(f"Reordered theta: {theta_py}")
-
- # save reordered theta
- # theta_py_file = f"{new_data_dir}/reference_outputs/theta_interpretS_original_DALIA_perm_{len(theta_py)}_1.dat"
- # np.savetxt(theta_py_file, theta_py)
- # print(f"Saved reordered theta to {theta_py_file}")
diff --git a/install.md b/install.md
new file mode 100644
index 00000000..89205907
--- /dev/null
+++ b/install.md
@@ -0,0 +1,339 @@
+# How to install DALIA on one of the default clusters or your personal machine
+DALIA has been developed on several (super)computing infrastructures. Since the beginning of the project, our goal has been to provide a seamless experience for users regardless of the underlying hardware and software stack.
+
+Because full generalization is not always possible, we provide a simplified procedure for multiple supercomputing infrastructures (with different hardware) to make installing DALIA on a new cluster as easy as possible.
+
+## Purpose
+Base `conda` environments for DALIA are provided in the `dalia/envs` directory. These environments contain all the Python packages on which DALIA relies. Hardware-specific packages (e.g., `cupy` for GPU compute, `mpi4py` for multiprocessing) are not included; they can be installed as extensions of the base environment. We provide base environments for both `x86` and `aarch64` architectures.
+
+These environments are used in the CI/CD pipelines for the respective clusters and are also available for users to quickly set up DALIA on these clusters or any other machine.
+
+## Supported Clusters
+We currently use three different clusters for the testing and development of DALIA:
+
+| Organization | Cluster | Arch | Description | Memory | Nodes |
+| ------------ | ------- | --------------------------------- | -------------------------------------------------------------- | -------------------------- | ---------------------------- |
+| [FAU](https://doc.nhr.fau.de/) | [Fritz](https://doc.nhr.fau.de/clusters/fritz/) | x86 (Sapphire Rapids) | 2x Intel Xeon Platinum 8470 per node
(2x52-cores @ 2.0 GHz) | Up to 2TB DDR5 | 64x 8470, 992x 8360Y |
+| [FAU](https://doc.nhr.fau.de/) | [Alex](https://doc.nhr.fau.de/clusters/alex/) | x86 (AMD EPYC 7713)
Ampere
| 8x Nvidia A100 per node
(2x 64-cores CPU + 8 GPUs) | Up to 80GB HBM2 | 18x A100 80GB, 20x A100 40GB |
+| [CSCS](https://www.cscs.ch/) | [Daint](https://docs.cscs.ch/clusters/daint/) | ARM (Grace)
Hopper | 4x Nvidia GH200 per node
(72-cores CPU + 1 GPU) | 128GB LPDDR5X
96GB HBM3 | 1022 |
+
+
+## Environments Configuration Matrix
+DALIA is designed to work across a wide variety of hardware and software stacks (e.g., GPU acceleration, distributed memory, vendor-specific libraries). The table below summarizes the supported configurations and the corresponding pre-configured `conda` environments.
+
+| | No Comm | Host MPI | GPU-Aware MPI | xCCL |
+| :----------- | :------------------- | :----------------- | :----------------- | :----------------- |
+| CPU (x86) | *dalia_base_fritz* | *dalia_hmpi_fritz* | NA | NA |
+| GPU (NVIDIA) | *dalia_base_alex* | NA | NA | *dalia_xccl_alex* |
+| GPU (NVIDIA) | *dalia_base_daint* | NA | NA | *dalia_xccl_daint* |
+| CPU (ARM) | x | x | x | x |
+| GPU (AMD) | x | x | x | x |
+
+The `dalia_base` environment contains all the dependencies needed to run DALIA on a single node without any communication library (e.g., MPI, xCCL) and without GPU support. This environment includes only hardware-independent Python dependencies. In `dalia/scripts/` we provide interactive installers that can create these `dalia_base` environments for you and, when applicable, extend them to support multi-node communication with MPI or xCCL and/or GPU acceleration.
+
+**Notes:**
+- On the Daint and Alex clusters, CuPy is installed using wheels that include NCCL. Therefore, NCCL is available whenever CuPy is installed. For this reason, no standalone GPU-aware MPI environments are provided.
+- On Daint, the `dalia_base_daint` environment allows one to run on the ARM CPU. However, we do not provide any conda environment specifically targeted for a general ARM CPU-cluster.
+
+# Detailed Instructions
+## On Fritz@FAU
+### a) Installation
+1. Clone the repositories to your workspace:
+ ```bash
+ mv /my/install/path
+ git clone https://github.com/dalia-project/DALIA
+ git clone https://github.com/vincent-maillou/serinv # Optional, recommended for ST-modeling
+ ```
+2. Source the `fritz_fau_utils.sh` script to access the install utilities:
+ ```bash
+ cd DALIA/
+ source scripts/fritz_fau_utils.sh
+ ```
+3. Load the required environment modules:
+ ```bash
+ fritz_load_modules
+ ```
+4. Create the conda environment:
+ ```bash
+ fritz_create_conda_env --dalia-path=path/to/DALIA --serinv-path=path/to/serinv --install-mpi4py --dev-mode
+ ```
+ Notes:
+ - This example installs all optional dependencies and uses developer mode. You can also run the installer in interactive mode by simply running `fritz_create_conda_env`.
+ - Developer mode (`--dev-mode`) keeps the most performant conda environment available, as well as all environments created along the way. This ensures that during development DALIA can be tested against all supported configurations.
+ - The created environment is activated automatically at the end of the installation.
+5. Activate the conda environment:
+ ```bash
+ fritz_activate_conda_env
+ ```
+ Note: This function will try to activate the most performant environment available on the cluster. You can also activate a specific environment by providing the `--env="desired_environment"` argument to the function.
+
+### b) Usage
+Right after installation, you can use DALIA using the loaded modules and activated conda environment.
+
+However, in general, in future shell sessions you will need to:
+1. Source the `fritz_fau_utils.sh` script:
+ ```bash
+ source /path/to/DALIA/scripts/fritz_fau_utils.sh
+ ```
+2. Load the required environment modules:
+ ```bash
+ fritz_load_modules
+ ```
+3. Activate the conda environment:
+ ```bash
+ fritz_activate_conda_env
+ ```
+
+You will then be ready to use DALIA.
+
+### c) Verify Installation
+There is currently two ways to verify that DALIA has been installed correctly:
+1. Run the provided test suite. In an interactive session: `salloc -N 1 --time=00:30:00` (you might need to wait to get the allocation), you can (after activated the correct conda environment) run the test suite: `./path/to/DALIA/tests/runner.sh`.
+2. Run one of the provided examples, you can check the `/path/to/DALIA/examples/run_example_fritz_fau.sh` script for an example on how to submit a job on Fritz.
+
+## On Alex@FAU
+### a) Installation
+
+1. Clone the repositories to your workspace:
+ ```bash
+ mv /my/install/path
+ git clone https://github.com/dalia-project/DALIA
+ git clone https://github.com/vincent-maillou/serinv # Optional, recommended for ST-modeling
+ ```
+2. Source the `alex_fau_utils.sh` script to access the install utilities:
+ ```bash
+ cd DALIA/
+ source scripts/alex_fau_utils.sh
+ ```
+3. Load the required environment modules:
+ ```bash
+ alex_load_modules
+ ```
+4. Create the conda environment:
+ ```bash
+ alex_create_conda_env --dalia-path=path/to/DALIA --serinv-path=path/to/serinv --install-mpi4py --dev-mode
+ ```
+ Notes:
+ - This example installs all optional dependencies and uses developer mode. You can also run the installer in interactive mode by simply running `alex_create_conda_env`.
+ - Developer mode (`--dev-mode`) keeps the most performant conda environment available, as well as all environments created along the way. This ensures that during development DALIA can be tested against all supported configurations.
+ - The created environment is activated automatically at the end of the installation.
+5. Activate the conda environment:
+ ```bash
+ alex_activate_conda_env
+ ```
+ Note: This function will try to activate the most performant environment available on the cluster. You can also activate a specific environment by providing the `--env` argument to the function.
+
+### b) Usage
+Right after installation, you can use DALIA using the loaded modules and activated conda environment.
+
+However, in general, in future shell sessions you will need to:
+1. Source the `alex_fau_utils.sh` script:
+ ```bash
+ source /path/to/DALIA/scripts/alex_fau_utils.sh
+ ```
+2. Load the required environment modules:
+ ```bash
+ alex_load_modules
+ ```
+3. Activate the conda environment:
+ ```bash
+ alex_activate_conda_env
+ ```
+
+You will then be ready to use DALIA.
+
+### c) Verify Installation
+There is currently two ways to verify that DALIA has been installed correctly:
+1. Run the provided test suite. In an interactive session: `salloc --gres=gpu:a100:1 --time=0:30:00` (you might need to wait to get the allocation), you can (after activated the correct conda environment) run the test suite: `./path/to/DALIA/tests/runner.sh`.
+2. Run one of the provided examples, you can check the `/path/to/DALIA/examples/run_example_alex_fau.sh` script for an example on how to submit a job on Alex.
+
+
+## On Daint@CSCS
+### Preamble
+
+In order to successfully install DALIA or even clone the repository on Daint, you will need to install `git-lfs` (Git Large File Storage) in your user space (see "Notes on `git-lfs`" for details). We provide a gist utility script alongside the following installation instructions:
+
+```bash
+mv /my/install/path
+git clone https://gist.github.com/vincent-maillou/d2d38937f7aafbf0cee98c65cf5cfbca # Get the installation script
+cd d2d38937f7aafbf0cee98c65cf5cfbca/ # Move into the script directory
+
+chmod u+x daint_install_git_lfs.sh # Render the script executable
+./daint_install_git_lfs.sh # Run the installation script
+export PATH="$HOME/.local/bin:$PATH" # Add git-lfs to your PATH
+
+which git-lfs # Verify the installation
+```
+
+Git-lfs will now be installed and available in your user space. Its path have been added to your `.bashrc` file, so it will be available in future shell sessions.
+
+
+### a) Installation
+1. Clone the repositories to your workspace:
+ ```bash
+ mv /my/install/path
+ git clone https://github.com/dalia-project/DALIA
+ git clone https://github.com/vincent-maillou/serinv # Optional, recommended for ST-modeling
+ ```
+
+2. Source the `daint_cscs_utils.sh` script to access the install utilities:
+ ```bash
+ cd DALIA/
+ source scripts/daint_cscs_utils.sh
+ ```
+
+3. Use the `daint_install_conda` utility to install Miniconda in your user space (if not already installed):
+ ```bash
+ daint_install_conda --yes
+ ```
+ **Notes:**
+ - This step is only required if you do not already have `conda` installed.
+ - The `--yes` option automatically confirms the installation prompts.
+ - Additional information about the installer options can be found by running `daint_install_conda --help`.
+
+4. Source the installed conda environment:
+ ```bash
+ source ~/miniconda3/etc/profile.d/conda.sh
+ ```
+ **Notes:**
+ - This is needed only once; in subsequent shell sessions conda will be initialized automatically through your `.bashrc` file.
+
+5. Install the programming environment (`uenv`):
+ ```bash
+ daint_install_uenv
+ ```
+
+6. Start the programming environment:
+ ```bash
+ daint_start_uenv
+ ```
+ **Notes:**
+ - This script assumes that you have successfully installed the programming environment in the previous step. If a programming environment is already active, it will be stopped and replaced with a new one.
+
+7. Source the `daint_cscs_utils.sh` script again to access the install utilities:
+ ```bash
+ source scripts/daint_cscs_utils.sh
+ ```
+ **Notes:**
+ - This is needed because starting the programming environment spawns a new shell session in which previously sourced scripts are no longer available.
+
+8. Load the required environment modules:
+ ```bash
+ daint_load_modules
+ ```
+
+9. Create the conda environment:
+ ```bash
+ daint_create_conda_env --dalia-path=path/to/DALIA --serinv-path=path/to/serinv --install-mpi4py --dev-mode
+ ```
+
+10. Activate the conda environment:
+ ```bash
+ daint_activate_conda_env
+ ```
+ Note: This function tries to activate the most performant environment available on the cluster. You can also activate a specific environment by providing the `--env` argument.
+
+### b) Usage
+Right after installation, you can use DALIA using the loaded modules and activated conda environment.
+
+However, in general, in future shell sessions you will need to:
+1. Source the `daint_cscs_utils.sh` script:
+ ```bash
+ source /path/to/DALIA/scripts/daint_cscs_utils.sh
+ ```
+2. Start the programming environment:
+ ```bash
+ daint_start_uenv
+ ```
+3. Source the `daint_cscs_utils.sh` script again (`uenv` is starting a new shell) to access the install utilities:
+ ```bash
+ source /path/to/DALIA/scripts/daint_cscs_utils.sh
+ ```
+4. Load the required environment modules:
+ ```bash
+ daint_load_modules
+ ```
+5. Activate the conda environment:
+ ```bash
+ daint_activate_conda_env
+ ```
+
+You will then be ready to use DALIA.
+
+### c) Verify Installation
+There is currently two ways to verify that DALIA has been installed correctly:
+1. Run the provided test suite. Directly from the front node (after activating the correct conda environment) you can run the test suite: `./path/to/DALIA/tests/runner.sh`.
+2. Run one of the provided examples, you can check the `/path/to/DALIA/examples/run_example_daint_cscs.sh` script for an example on how to submit a job on Daint.
+
+## On a Personal Machine
+### a) Installation
+Given a working installation of `git` and `conda`, you can install DALIA on your personal machine as follows:
+
+1. Clone the repositories to your workspace:
+ ```bash
+ mv /my/install/path
+ git clone https://github.com/dalia-project/DALIA
+ git clone https://github.com/vincent-maillou/serinv # Optional, recommended for ST-modeling
+ ```
+
+2. Based on your micro-architecture (x86 or aarch64), create the conda environment using the provided installer scripts:
+- For x86 architecture:
+ ```bash
+ conda env create --name "env_name" -f path/to/DALIA/envs/dalia_base_x86.yml
+ ```
+- For aarch64 architecture:
+ ```bash
+ conda env create --name "env_name" -f path/to/DALIA/envs/dalia_base_aarch64.yml
+ ```
+
+3. Activate the conda environment:
+ ```bash
+ conda activate env_name
+ ```
+
+4. Install DALIA in editable mode:
+ ```bash
+ cd path/to/DALIA
+ pip install -e .
+ ```
+
+5. Install Serinv (optional, recommended for ST-modeling) in editable mode:
+ ```bash
+ cd path/to/serinv
+ pip install -e .
+ ```
+
+### b) Verify Installation
+They are currently two ways to verify that DALIA has been installed correctly:
+1. Run the provided test suite: `./path/to/DALIA/tests/runner.sh`.
+2. Run one of the provided examples, for example: `python /path/to/DALIA/examples/gr/run.py`
+
+## Notes on `git-lfs`
+
+### Overview
+DALIA uses `git-lfs` (Git Large File Storage) to manage large files such as example datasets. Files tracked by `git-lfs` are stored as pointers in the repository and downloaded on-demand rather than during the initial clone. This is particularly beneficial for large binary files (images, audio files, datasets) that don't compress well.
+
+### Availability by Cluster
+
+#### Fritz and Alex (FAU)
+- `git-lfs` is available by default
+- Location: `/usr/bin/git-lfs`
+- No additional installation required
+
+#### Daint (CSCS)
+- `git-lfs` is **not** provided by the system
+- Must be installed manually in your user space
+- Installation script available at: https://gist.github.com/vincent-maillou/d2d38937f7aafbf0cee98c65cf5cfbca
+
+### Important Note on Repository Cloning
+Without `git-lfs` installed on Daint, `git clone` will fail because it automatically triggers `git-lfs checkout` during the checkout process. This is why the Daint installation instructions include a `git-lfs` installation step in the preamble (see section "On Daint@CSCS > Preamble" for detailed instructions).
+
+## Other Informations
+
+- After installing packages with conda, it is recommended to run `conda clean --all` to free up disk space.
+- These installation procedures and `conda` environments have been tested on Linux systems based on both `x86` and `aarch64` architectures.
+- If the `sqlite` module does not work properly, forcing the following version might help:
+ ```bash
+ conda install conda-forge::sqlite=3.45.3
+ ```
+- CuPy switched NCCL to lazy import, which means NCCL must be imported explicitly with `from cupy.cuda import nccl` before it is available in `cupy.cuda.nccl`.
diff --git a/justfile b/justfile
deleted file mode 100644
index e9d669c2..00000000
--- a/justfile
+++ /dev/null
@@ -1,21 +0,0 @@
-# Cleans the repo.
-clean:
- @find . | grep -E "(__pycache__|\.pyc|\.pyo|build|generated$)" | xargs rm -rf
- @rm -rf src/*.egg-info/ build/ dist/ .coverage .pytest_cache/
-
-# Applies formatting to all files.
-format:
- isort --profile black .
- black .
- blacken-docs
-
-# Lints all files.
-lint:
- ruff check
-
-# Runs all non-MPI tests and determines coverage.
-test-cov workers="4":
- pytest -n {{workers}} --cov=src/dalia --cov-report=term --cov-report=xml tests/
-
-# Runs all tests.
-test: test-cov
diff --git a/pyproject.toml b/pyproject.toml
index 0abe06e2..b3b774e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,22 +7,32 @@ name = "DALIA"
authors = [
{ name = "Lisa Gaedke-Merzhaeuser", email = "lisa.gaedke.merzhaeuser@usi.ch" },
{ name = "Vincent Maillou", email = "vmaillou@iis.ee.ethz.ch" },
+ { name = "Alexandros Nikolaos Ziogas", email = "alziogas@iis.ee.ethz.ch" },
]
-description = "Python implementation of the method of integrated nested Laplace approximations (INLA)"
+
+description = "Python implementation of the method of integrated nested Laplace approximations (INLA), putting the accent on portability, modularity and performance."
+
license = {text = "BSD-3-Clause"}
+
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
+ "Topic :: Scientific/Engineering",
]
+
keywords = ["INLA", "Bayesian", "statistics"]
+
dynamic = ["version"]
-requires-python = ">=3.11"
+
+requires-python = ">=3.13"
+
dependencies = [
"numpy>=1.23.2",
"scipy",
- "matplotlib",
+ "numba",
"pydantic",
+ "matplotlib",
"pytest",
"pytest-cov",
"pytest-mpi",
@@ -33,14 +43,22 @@ dependencies = [
"ruff",
"just",
"pre-commit",
+ "tabulate",
+ "psutil",
]
[project.optional-dependencies]
-mpi4py = ["mpi4py"]
-cupy = ["cupy==13.2.0"]
-serinv = ["serinv @ git+https://github.com/vincent-maillou/serinv"]
+gpu = [
+ "cupy==13.2.0",
+]
+multiprocessing = [
+ "mpi4py",
+]
+solvers = [
+ "serinv @ git+https://github.com/vincent-maillou/serinv"
+]
[project.urls]
-Code = "https://github.com/lisa-gm/DALIA"
+Code = "https://github.com/dalia-project/DALIA"
[tool.setuptools.dynamic]
version = { attr = "dalia.__about__.__version__" }
diff --git a/scripts/alex_fau_utils.sh b/scripts/alex_fau_utils.sh
new file mode 100644
index 00000000..62a96b30
--- /dev/null
+++ b/scripts/alex_fau_utils.sh
@@ -0,0 +1,592 @@
+#!/bin/bash
+
+alex_load_modules() {
+ echo "alex_load_modules: loading Alex system modules."
+
+ # Check if module command is available
+ if ! command -v module &> /dev/null; then
+ echo " Error: 'module' command not found. Please ensure you are on a system with environment modules."
+ return 1
+ fi
+
+ # Purge any existing modules
+ echo " Purging existing modules..."
+ module purge 2>/dev/null || {
+ echo " Warning: Failed to purge modules (this may be normal on some systems)."
+ }
+
+ # Load required modules
+ echo " Loading required modules: mkl/2023.2.0 gcc/12.1.0 cuda/12.9.0 openmpi/4.1.6-nvhpc23.7-cuda12 python"
+ module load mkl/2023.2.0 gcc/12.1.0 openmpi/4.1.6-nvhpc23.7-cuda12 cuda/12.9.0 python || {
+ echo " Error: Failed to load required modules."
+ echo " Available modules:"
+ module avail 2>&1 | head -20
+ echo " (output truncated - use 'module avail' for full list)"
+ return 1
+ }
+
+ echo " Successfully loaded all required modules."
+ return 0
+}
+
+alex_check_modules() {
+ echo "alex_check_modules: checking if required modules are loaded."
+
+ local required_modules=("mkl/2023.2.0" "gcc/12.1.0" "cuda/12.9.0" "openmpi/4.1.6-nvhpc23.7-cuda12" "python")
+ local missing_modules=()
+
+ # Get list of currently loaded modules
+ local loaded_modules=$(module list 2>&1 | grep -E "mkl|gcc|openmpi|cuda|nvhpc|python")
+
+ # Check each required module
+ for module in "${required_modules[@]}"; do
+ if ! echo "$loaded_modules" | grep -q "$module"; then
+ missing_modules+=("$module")
+ fi
+ done
+
+ if [ ${#missing_modules[@]} -eq 0 ]; then
+ echo " All required modules are loaded."
+ return 0
+ else
+ echo " Error: Missing required modules: ${missing_modules[*]}"
+ echo " Please run 'alex_load_modules' first."
+ return 1
+ fi
+}
+
+alex_create_conda_env_help() {
+ echo "alex_create_conda_env: Create DALIA conda environment for Alex supercomputer"
+ echo ""
+ echo "Usage:"
+ echo " alex_create_conda_env [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --dalia-path=PATH Path to DALIA repository root directory"
+ echo " --serinv-path=PATH Path to serinv repository (automatically installs serinv)"
+ echo " --install-mpi4py Install mpi4py and create enhanced environment"
+ echo " --dev-mode Keep intermediate environments (use with --install-mpi4py)"
+ echo ""
+ echo "Examples:"
+ echo " # Interactive mode (GPU support included by default)"
+ echo " alex_create_conda_env"
+ echo ""
+ echo " # Non-interactive mode with all options"
+ echo " alex_create_conda_env --dalia-path=/path/to/dalia --serinv-path=/path/to/serinv --install-mpi4py --dev-mode"
+ echo ""
+ echo " # Install base with GPU support only"
+ echo " alex_create_conda_env --dalia-path=/path/to/dalia"
+ echo ""
+ echo "Note: GPU support via cupy is installed by default for Alex cluster."
+ echo " If parameters are not provided, the function will prompt interactively."
+}
+
+
+alex_create_conda_env() {
+ echo "alex_create_conda_env: creating DALIA conda environment for Alex."
+
+ # Parse command line arguments
+ local dalia_path=""
+ local serinv_path=""
+ local install_mpi4py_flag=""
+ local dev_mode_flag=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --dalia-path=*)
+ dalia_path="${1#*=}"
+ shift
+ ;;
+ --serinv-path=*)
+ serinv_path="${1#*=}"
+ shift
+ ;;
+ --install-mpi4py)
+ install_mpi4py_flag="y"
+ shift
+ ;;
+ --dev-mode)
+ dev_mode_flag="y"
+ shift
+ ;;
+ --help|-h)
+ alex_create_conda_env_help
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # 0. Deactivate any currently active conda environment
+ echo " Deactivating any currently active conda environment..."
+ conda deactivate 2>/dev/null || true
+
+ # 1. Check that the needed modules are loaded
+ if ! alex_check_modules; then
+ return 1
+ fi
+
+ # 2. Get DALIA repository path
+ if [[ -z "$dalia_path" ]]; then
+ echo " Please enter the path to the DALIA root repository:"
+ read -r dalia_path
+ else
+ echo " Using provided DALIA path: ${dalia_path}"
+ fi
+
+ # Expand tilde and remove trailing slash
+ dalia_path=$(eval echo "${dalia_path}")
+ dalia_path="${dalia_path%/}"
+
+ # 3. Validate DALIA repository path
+ if [[ ! -d "$dalia_path" ]]; then
+ echo " Error: Directory '${dalia_path}' does not exist."
+ return 1
+ fi
+
+ if [[ ! -f "${dalia_path}/pyproject.toml" ]]; then
+ echo " Error: '${dalia_path}' does not appear to be a DALIA repository (missing pyproject.toml)."
+ return 1
+ fi
+
+ if [[ ! -d "${dalia_path}/src/dalia" ]]; then
+ echo " Error: '${dalia_path}' does not contain the DALIA source code (missing src/dalia/)."
+ return 1
+ fi
+
+ # 4. Locate the conda environment file
+ local env_file="${dalia_path}/envs/dalia_base_x86.yml"
+ if [[ ! -f "$env_file" ]]; then
+ echo " Error: Conda environment file not found at '${env_file}'."
+ return 1
+ fi
+
+ echo " Found DALIA repository at: ${dalia_path}"
+ echo " Found conda environment file at: ${env_file}"
+
+ # Check if environment already exists
+ local env_name="dalia_base_alex"
+ if conda env list | grep -q "^${env_name} "; then
+ echo " Warning: Conda environment '${env_name}' already exists."
+ echo " Do you want to remove and recreate it? (y/N): "
+ read -r response
+ if [[ "$response" =~ ^[Yy]$ ]]; then
+ echo " Removing existing environment..."
+ conda env remove -n "$env_name" -y || {
+ echo " Error: Failed to remove existing environment."
+ return 1
+ }
+ else
+ echo " Skipping environment creation. Using existing environment."
+ fi
+ fi
+
+ # Create conda environment from YAML file (only if it doesn't exist or was removed)
+ if ! conda env list | grep -q "^${env_name} "; then
+ echo " Creating conda environment from '${env_file}'..."
+ conda env create --name "$env_name" -f "$env_file" || {
+ echo " Error: Failed to create conda environment from '${env_file}'."
+ return 1
+ }
+ fi
+
+ # 5. Activate the conda environment
+ echo " Activating conda environment..."
+ if ! alex_activate_conda_env --env="$env_name"; then
+ return 1
+ fi
+
+ # 6. Install DALIA from the repository
+ echo " Installing DALIA in development mode..."
+ cd "$dalia_path" || {
+ echo " Error: Failed to change directory to '${dalia_path}'."
+ return 1
+ }
+
+ python -m pip install --no-deps --editable . || {
+ echo " Error: Failed to install DALIA in development mode."
+ return 1
+ }
+
+ # 7. Optional: Install serinv structured sparse solver
+ echo ""
+ local install_serinv=""
+
+ if [[ -n "$serinv_path" ]]; then
+ # Serinv path provided via command line - install automatically
+ echo " Serinv path provided via --serinv-path. Installing serinv automatically..."
+ install_serinv="y"
+ else
+ # Interactive mode - ask user
+ echo " Spatio-temporal problems in DALIA can leverage the 'serinv' structured sparse solver for improved performance."
+ echo " Do you want to install the serinv structured sparse solver? (y/N): "
+ read -r install_serinv
+
+ if [[ "$install_serinv" =~ ^[Yy]$ ]]; then
+ echo " Please enter the path to the serinv root repository:"
+ read -r serinv_path
+ fi
+ fi
+
+ if [[ "$install_serinv" =~ ^[Yy]$ ]]; then
+ # Expand tilde and remove trailing slash
+ serinv_path=$(eval echo "${serinv_path}")
+ serinv_path="${serinv_path%/}"
+
+ # Validate serinv repository path
+ if [[ ! -d "$serinv_path" ]]; then
+ echo " Error: Directory '${serinv_path}' does not exist."
+ echo " Skipping serinv installation."
+ elif [[ ! -f "${serinv_path}/pyproject.toml" ]] && [[ ! -f "${serinv_path}/setup.py" ]]; then
+ echo " Error: '${serinv_path}' does not appear to be a valid Python package (missing pyproject.toml or setup.py)."
+ echo " Skipping serinv installation."
+ else
+ echo " Installing serinv from '${serinv_path}' in development mode..."
+ cd "$serinv_path" || {
+ echo " Error: Failed to change directory to '${serinv_path}'."
+ echo " Skipping serinv installation."
+ }
+
+ if python -m pip install --no-deps --editable .; then
+ echo " Successfully installed serinv in development mode."
+ else
+ echo " Warning: Failed to install serinv. You may need to install it manually later."
+ fi
+
+ # Return to DALIA directory
+ cd "$dalia_path" || true
+ fi
+ else
+ echo " Skipping serinv installation."
+ fi
+
+ # 8. Install cupy with GPU support (default for Alex cluster)
+ echo ""
+ echo " Installing cupy with GPU support using SLURM job (default for Alex cluster)..."
+
+ # Install cupy using SLURM job on GPU partition
+ echo " Submitting SLURM job to install cupy on GPU partition..."
+ local job_script=$(mktemp)
+ cat > "$job_script" <<'SLURM_EOF'
+#!/bin/bash -l
+export http_proxy=http://proxy.nhr.fau.de:80
+export https_proxy=http://proxy.nhr.fau.de:80
+conda activate ENV_NAME_PLACEHOLDER
+python -m pip install cupy-cuda12x --no-cache-dir
+SLURM_EOF
+
+ # Replace the environment name placeholder
+ sed -i "s/ENV_NAME_PLACEHOLDER/${env_name}/g" "$job_script"
+
+ # Submit the job and capture job ID
+ local job_output=$(sbatch --partition=a40 --nodes=1 --gres=gpu:a40:1 --time=00:05:00 --job-name="cupy_install_${env_name}" "$job_script")
+ local job_id=$(echo "$job_output" | grep -oE '[0-9]+')
+
+ # Clean up temporary script
+ rm -f "$job_script"
+
+ if [[ -n "$job_id" ]]; then
+ echo " SLURM job submitted successfully with ID: ${job_id}"
+ echo " Waiting for cupy installation to complete..."
+
+ # Wait for job completion
+ local job_status=""
+ local wait_count=0
+ local max_wait=120 # Maximum wait time in 5-second intervals (10 minutes)
+
+ while [[ "$wait_count" -lt "$max_wait" ]]; do
+ job_status=$(squeue -j "$job_id" -h -o "%T" 2>/dev/null || echo "COMPLETED")
+
+ case "$job_status" in
+ "RUNNING")
+ echo " Job ${job_id} is running... (${wait_count}/${max_wait})"
+ ;;
+ "PENDING")
+ echo " Job ${job_id} is pending... (${wait_count}/${max_wait})"
+ ;;
+ "COMPLETED"|"")
+ echo " Job ${job_id} completed successfully."
+ break
+ ;;
+ "FAILED"|"CANCELLED"|"TIMEOUT")
+ echo " Error: Job ${job_id} failed with status: ${job_status}"
+ break
+ ;;
+ *)
+ echo " Job ${job_id} status: ${job_status} (${wait_count}/${max_wait})"
+ ;;
+ esac
+
+ sleep 5
+ ((wait_count++))
+ done
+
+ # Check final job status
+ if [[ "$wait_count" -ge "$max_wait" ]]; then
+ echo " Warning: Timeout waiting for job completion. Please check job status manually: squeue -j ${job_id}"
+ echo " You can also check the job output with: scontrol show job ${job_id}"
+ # Clean up SLURM job output files after timeout
+ echo " Cleaning up SLURM job files..."
+ rm -f "cupy_install_${env_name}.o${job_id}" "cupy_install_${env_name}.e${job_id}" 2>/dev/null || true
+ elif [[ "$job_status" == "COMPLETED" || -z "$job_status" ]]; then
+ # Removed cupy instalaltion check as on front node there is no Cuda capable device.
+ echo " CuPy installation completed."
+
+ # Clean up SLURM job output files
+ echo " Cleaning up SLURM job files..."
+ rm -f "cupy_install_${env_name}.o${job_id}" "cupy_install_${env_name}.e${job_id}" 2>/dev/null || true
+ else
+ echo " Error: CuPy installation job failed. Please check SLURM logs."
+ echo " Continuing with base environment without GPU support..."
+ echo " You can install cupy manually later by submitting a GPU job."
+ # Clean up SLURM job output files after failure
+ echo " Cleaning up SLURM job files..."
+ rm -f "cupy_install_${env_name}.o${job_id}" "cupy_install_${env_name}.e${job_id}" 2>/dev/null || true
+ fi
+ else
+ echo " Error: Failed to submit SLURM job for cupy installation."
+ echo " Continuing with base environment without GPU support..."
+ echo " You can install cupy manually later by submitting a GPU job."
+ fi
+
+ # 9. Optional: Install mpi4py and create enhanced environment
+ echo ""
+ local install_mpi4py=""
+
+ if [[ -n "$install_mpi4py_flag" ]]; then
+ # MPI4py installation requested via command line
+ echo " MPI4py installation requested via --install-mpi4py. Installing automatically..."
+ install_mpi4py="y"
+ else
+ # Interactive mode - ask user
+ echo " MPI support can be added through mpi4py for improved parallel performance."
+ echo " Do you want to install mpi4py and create an enhanced environment? (y/N): "
+ read -r install_mpi4py
+ fi
+
+ if [[ "$install_mpi4py" =~ ^[Yy]$ ]]; then
+ echo " Creating enhanced environment with mpi4py support..."
+
+ # Determine the enhanced environment name
+ local mpi_enhanced_env_name="dalia_xccl_alex" # Always GPU + MPI for Alex
+ echo " Creating environment with GPU and mpi4py support..."
+
+ # Deactivate current environment
+ echo " Deactivating current environment to create MPI-enhanced version..."
+ conda deactivate 2>/dev/null || true
+
+ # Remove enhanced environment if it already exists
+ if conda env list | grep -q "^${mpi_enhanced_env_name} "; then
+ echo " Removing existing MPI-enhanced environment..."
+ conda env remove -n "$mpi_enhanced_env_name" -y || {
+ echo " Warning: Failed to remove existing MPI-enhanced environment."
+ }
+ fi
+
+ # Clone the current environment
+ if conda create --name "$mpi_enhanced_env_name" --clone "$env_name" -y; then
+ echo " Successfully created MPI-enhanced environment."
+
+ # Activate the enhanced environment
+ echo " Activating MPI-enhanced environment '${mpi_enhanced_env_name}'..."
+ if alex_activate_conda_env --env="$mpi_enhanced_env_name"; then
+ # Install mpi4py in the enhanced environment
+ echo " Installing mpi4py with OpenMPI support in enhanced environment..."
+ cd "$dalia_path" || true
+ if MPICC=$(which mpicc) pip install --no-cache-dir mpi4py; then
+ echo " Successfully installed mpi4py in MPI-enhanced environment."
+
+ # Determine whether to keep base environment
+ local keep_base=""
+ if [[ -n "$install_mpi4py_flag" ]]; then
+ # Command line mode - use dev_mode_flag to decide
+ if [[ -n "$dev_mode_flag" ]]; then
+ keep_base="y"
+ echo " Developer mode enabled via --dev-mode. Keeping base environment."
+ else
+ keep_base="n"
+ echo " Default mode: removing base environment to keep only enhanced version."
+ fi
+ else
+ # Interactive mode - ask user
+ echo ""
+ echo " Do you want to keep the base environment 'dalia_base_alex' for development without MPI? (y/N): "
+ read -r keep_base
+ fi
+
+ if [[ ! "$keep_base" =~ ^[Yy]$ ]]; then
+ echo " Removing base environment 'dalia_base_alex'..."
+ conda env remove -n "dalia_base_alex" -y || {
+ echo " Warning: Failed to remove base environment."
+ }
+ else
+ echo " Keeping base environment for development without MPI."
+ fi
+
+ env_name="$mpi_enhanced_env_name" # Update env_name for final message
+ else
+ echo " Error: Failed to install mpi4py in MPI-enhanced environment."
+ echo " Removing broken MPI-enhanced environment and reverting to base environment..."
+
+ # Deactivate the enhanced environment
+ conda deactivate 2>/dev/null || true
+
+ # Remove the broken enhanced environment
+ conda env remove -n "$mpi_enhanced_env_name" -y || {
+ echo " Warning: Failed to remove broken MPI-enhanced environment."
+ }
+
+ # Reactivate the base environment
+ if alex_activate_conda_env --env="dalia_base_alex"; then
+ echo " Reverted to base environment 'dalia_base_alex'."
+ env_name="dalia_base_alex" # Reset env_name to base environment
+ echo " You can install mpi4py manually later with: MPICC=\$(which mpicc) pip install --no-cache-dir mpi4py"
+ else
+ echo " Warning: Failed to reactivate base environment."
+ fi
+ fi
+ else
+ echo " Warning: Failed to activate MPI-enhanced environment."
+ fi
+ else
+ echo " Error: Failed to create MPI-enhanced environment."
+ echo " Continuing with base environment..."
+ fi
+ else
+ echo " Skipping mpi4py installation."
+ fi
+
+ # 10. Final success message
+ echo ""
+ echo " Success! DALIA conda environment '${env_name}' has been created and configured."
+ echo " Repository path: ${dalia_path}"
+ if [[ "$install_serinv" =~ ^[Yy]$ ]] && [[ -d "$serinv_path" ]]; then
+ echo " Serinv path: ${serinv_path}"
+ fi
+ echo " Base environment includes GPU support via cupy (default for Alex cluster)."
+ if [[ "$install_mpi4py" =~ ^[Yy]$ ]]; then
+ if [[ "$env_name" == *"xccl"* ]]; then
+ echo " Enhanced environment with mpi4py and NCCL support created."
+ fi
+ fi
+ echo " To use a specific environment in the future, run: alex_activate_conda_env --env=\"alex_env_name\""
+ echo " To use the most performant environment you have available, just run: alex_activate_conda_env"
+
+ return 0
+}
+
+
+
+alex_activate_conda_env() {
+ echo "alex_activate_conda_env: activating DALIA conda environment."
+
+ # Parse command line arguments
+ local specified_env=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --env=*)
+ specified_env="${1#*=}"
+ shift
+ ;;
+ --help|-h)
+ echo "alex_activate_conda_env: Activate DALIA conda environment"
+ echo ""
+ echo "Usage:"
+ echo " alex_activate_conda_env [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --env=NAME Specific environment name to activate"
+ echo " --help/-h Show this help message"
+ echo ""
+ echo "Examples:"
+ echo " alex_activate_conda_env # Auto-select best available"
+ echo " alex_activate_conda_env --env=dalia_base_alex # Activate specific environment"
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # Check if conda is available
+ if ! command -v conda &> /dev/null; then
+ echo " Error: conda command not found. Please ensure conda is installed and in PATH."
+ return 1
+ fi
+
+ # Safely deactivate current environment
+ conda deactivate 2>/dev/null || true
+
+ # Define available environments in order of preference (most performant first)
+ local env_priorities=("dalia_xccl_alex" "dalia_base_alex")
+ local env_name=""
+
+ # If environment name is provided as argument, use it directly
+ if [[ -n "$specified_env" ]]; then
+ env_name="$specified_env"
+ echo " Using specified conda environment '${env_name}'..."
+
+ # Validate that the specified environment exists
+ if ! conda env list | grep -q "^${env_name} "; then
+ echo " Error: Conda environment '${env_name}' does not exist."
+ echo " Available environments:"
+ conda env list
+ return 1
+ fi
+ else
+ # Check which environments are available and select the most performant one
+ echo " Checking available DALIA conda environments..."
+ local available_envs=$(conda env list 2>/dev/null | grep -E "^(dalia_xccl_alex|dalia_base_alex) " | awk '{print $1}')
+
+ for preferred_env in "${env_priorities[@]}"; do
+ if echo "$available_envs" | grep -q "^${preferred_env}$"; then
+ env_name="$preferred_env"
+ echo " Selected most performant available environment: '${env_name}'"
+ break
+ fi
+ done
+
+ # Fallback if no DALIA environments found
+ if [[ -z "$env_name" ]]; then
+ echo " Warning: No DALIA-specific environments found. Available environments:"
+ conda env list
+ echo " Falling back to 'base' environment..."
+ env_name="base"
+ fi
+ fi
+
+ echo " Activating conda environment '${env_name}'..."
+ conda activate ${env_name} || {
+ echo " Error: Failed to activate conda environment '${env_name}'"
+ echo " Please check that the environment exists and conda is properly configured."
+ return 1
+ }
+
+ # Verify activation was successful
+ local current_env=$(conda info --envs | grep '\*' | awk '{print $1}')
+ if [[ "$current_env" == "$env_name" ]]; then
+ echo " Successfully activated conda environment '${env_name}'"
+ else
+ echo " Warning: Environment activation may not have been successful."
+ echo " Expected: ${env_name}, Current: ${current_env}"
+ fi
+
+ return 0
+}
+
+alex_set_perfenv() {
+ echo "alex_set_perfenv: setting performance environment variables for Alex."
+
+ unset SLURM_EXPORT_ENV
+
+ export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
+ export SRUN_CPUS_PER_TASK=$SLURM_CPUS_PER_TASK
+}
\ No newline at end of file
diff --git a/scripts/daint_cscs_utils.sh b/scripts/daint_cscs_utils.sh
new file mode 100644
index 00000000..aeb10b1b
--- /dev/null
+++ b/scripts/daint_cscs_utils.sh
@@ -0,0 +1,1075 @@
+#!/bin/bash
+
+daint_install_conda_help() {
+ echo "daint_install_conda: Install Miniconda3 for the user"
+ echo ""
+ echo "Usage:"
+ echo " daint_install_conda [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --install-path=PATH Custom installation path (default: \$HOME/miniconda3)"
+ echo " --no-init Skip shell initialization"
+ echo " --init-all Initialize for all detected shells (bash, zsh)"
+ echo " --yes Skip confirmation prompts"
+ echo " --help, -h Show this help message"
+ echo ""
+ echo "Examples:"
+ echo " daint_install_conda # Install to \$HOME/miniconda3"
+ echo " daint_install_conda --install-path=/custom/path # Install to custom location"
+ echo " daint_install_conda --no-init # Install without shell init"
+ echo " daint_install_conda --yes # Skip all prompts"
+ echo ""
+ echo "Return codes:"
+ echo " 0 - Success"
+ echo " 1 - Existing installation detected"
+ echo " 2 - Download failed"
+ echo " 3 - Installation failed"
+ echo " 4 - Insufficient disk space"
+ echo " 5 - Network connectivity issues"
+}
+
+daint_install_conda() {
+ echo "daint_install_conda: Installing Miniconda3 for the user."
+
+ # Parse command line arguments
+ local install_path="$HOME/miniconda3"
+ local no_init=""
+ local init_all=""
+ local auto_confirm=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --install-path=*)
+ install_path="${1#*=}"
+ shift
+ ;;
+ --no-init)
+ no_init="y"
+ shift
+ ;;
+ --init-all)
+ init_all="y"
+ shift
+ ;;
+ --yes)
+ auto_confirm="y"
+ shift
+ ;;
+ --help|-h)
+ daint_install_conda_help
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # Expand tilde and remove trailing slash
+ install_path="${install_path/#\~/$HOME}"
+ install_path="${install_path%/}"
+
+ echo " Target installation path: ${install_path}"
+ echo ""
+
+ # =============================================================================
+ # 1. PRE-INSTALLATION CHECKS
+ # =============================================================================
+
+ echo " Step 1/7: Checking for existing Miniconda/Conda installations..."
+ local existing_installation=""
+
+ # Check if installation directory exists
+ if [[ -d "$install_path" ]]; then
+ echo " Error: Directory '${install_path}' already exists."
+ existing_installation="directory"
+ fi
+
+ # Check if conda is in PATH
+ if command -v conda &> /dev/null; then
+ local conda_location=$(which conda 2>/dev/null)
+ echo " Error: conda command found in PATH at: ${conda_location}"
+ existing_installation="path"
+ fi
+
+ # Check .bashrc for conda initialization
+ if [[ -f "$HOME/.bashrc" ]] && grep -q "# >>> conda initialize >>>" "$HOME/.bashrc"; then
+ echo " Error: Conda initialization block found in ~/.bashrc"
+ existing_installation="bashrc"
+ fi
+
+ # Check .bash_profile for conda references
+ if [[ -f "$HOME/.bash_profile" ]] && grep -qi "conda\|miniconda\|anaconda" "$HOME/.bash_profile"; then
+ echo " Error: Conda references found in ~/.bash_profile"
+ existing_installation="bash_profile"
+ fi
+
+ # Check .zshrc if it exists
+ if [[ -f "$HOME/.zshrc" ]] && grep -q "# >>> conda initialize >>>" "$HOME/.zshrc"; then
+ echo " Error: Conda initialization block found in ~/.zshrc"
+ existing_installation="zshrc"
+ fi
+
+ # If any existing installation found, exit with instructions
+ if [[ -n "$existing_installation" ]]; then
+ echo ""
+ echo " =========================================="
+ echo " EXISTING CONDA INSTALLATION DETECTED"
+ echo " =========================================="
+ echo " An existing conda/miniconda installation was detected on your system."
+ echo " a) If this installation is working, you may continue..."
+ echo " b) If this installation is broken, please manually uninstall it and re-running this script:"
+ echo " 1. Remove the installation directory:"
+ if [[ -d "$install_path" ]]; then
+ echo " rm -rf ${install_path}"
+ fi
+ echo ""
+ echo " 2. Remove conda initialization from shell configuration files:"
+ if [[ -f "$HOME/.bashrc" ]] && grep -q "# >>> conda initialize >>>" "$HOME/.bashrc"; then
+ echo " Edit ~/.bashrc and remove the section between:"
+ echo " '# >>> conda initialize >>>' and '# <<< conda initialize <<<'"
+ fi
+ if [[ -f "$HOME/.bash_profile" ]] && grep -qi "conda\|miniconda\|anaconda" "$HOME/.bash_profile"; then
+ echo " Edit ~/.bash_profile and remove conda-related lines"
+ fi
+ if [[ -f "$HOME/.zshrc" ]] && grep -q "# >>> conda initialize >>>" "$HOME/.zshrc"; then
+ echo " Edit ~/.zshrc and remove the section between:"
+ echo " '# >>> conda initialize >>>' and '# <<< conda initialize <<<'"
+ fi
+ echo ""
+ echo " 3. Start a new terminal session or run: source ~/.bashrc"
+ echo ""
+ echo " 4. Re-run this installer: daint_install_conda"
+ echo ""
+ return 1
+ fi
+
+ echo " No existing conda installation detected."
+ echo ""
+
+ # =============================================================================
+ # 2. VALIDATE INSTALLATION PATH
+ # =============================================================================
+
+ echo " Step 2/7: Validating installation path..."
+
+ # Get parent directory
+ local parent_dir=$(dirname "$install_path")
+
+ # Check if parent directory exists
+ if [[ ! -d "$parent_dir" ]]; then
+ echo " Error: Parent directory '${parent_dir}' does not exist."
+ return 3
+ fi
+
+ # Check write permissions
+ if [[ ! -w "$parent_dir" ]]; then
+ echo " Error: No write permission for parent directory '${parent_dir}'."
+ return 3
+ fi
+
+ echo " Installation path is valid."
+ echo ""
+
+ # =============================================================================
+ # 3. CHECK DISK SPACE
+ # =============================================================================
+
+ echo " Step 3/7: Checking available disk space..."
+
+ # Get available space in KB
+ local available_space=$(df -k "$parent_dir" | tail -1 | awk '{print $4}')
+ local required_space=5242880 # 5 GB in KB
+
+ if [[ $available_space -lt $required_space ]]; then
+ local available_gb=$((available_space / 1024 / 1024))
+ echo " Error: Insufficient disk space. Available: ${available_gb}GB, Required: 5GB"
+ return 4
+ fi
+
+ local available_gb=$((available_space / 1024 / 1024))
+ echo " Sufficient disk space available: ${available_gb}GB"
+ echo ""
+
+ # =============================================================================
+ # 4. CHECK NETWORK CONNECTIVITY AND DOWNLOAD MINICONDA
+ # =============================================================================
+
+ echo " Step 4/7: Downloading Miniconda installer..."
+
+ # Detect system architecture
+ local arch=$(uname -m)
+ local installer_name=""
+
+ case "$arch" in
+ x86_64)
+ installer_name="Miniconda3-latest-Linux-x86_64.sh"
+ ;;
+ aarch64|arm64)
+ installer_name="Miniconda3-latest-Linux-aarch64.sh"
+ ;;
+ *)
+ echo " Error: Unsupported architecture: ${arch}"
+ echo " Supported architectures: x86_64, aarch64"
+ return 3
+ ;;
+ esac
+
+ echo " Detected architecture: ${arch}"
+ echo " Installer: ${installer_name}"
+
+ # Set download URL
+ local download_url="https://repo.anaconda.com/miniconda/${installer_name}"
+ local installer_path="/tmp/${installer_name}"
+
+ # Test network connectivity
+ echo " Testing connectivity to repo.anaconda.com..."
+ if ! curl -s --connect-timeout 10 --max-time 15 "https://repo.anaconda.com" > /dev/null; then
+ if ! wget -q --timeout=10 --tries=1 --spider "https://repo.anaconda.com" 2>/dev/null; then
+ echo " Error: Cannot reach repo.anaconda.com. Please check your network connection."
+ echo " If you're behind a proxy, set http_proxy and https_proxy environment variables."
+ return 5
+ fi
+ fi
+ echo " Network connectivity OK"
+
+ # Remove existing partial download if present
+ if [[ -f "$installer_path" ]]; then
+ echo " Removing existing installer file..."
+ rm -f "$installer_path"
+ fi
+
+ # Download installer
+ echo " Downloading from: ${download_url}"
+ echo " This may take a few minutes..."
+
+ # Try curl first, then wget
+ local download_success=0
+ if command -v curl &> /dev/null; then
+ if curl -L -o "$installer_path" "$download_url" 2>&1 | grep -v "^#"; then
+ download_success=1
+ fi
+ elif command -v wget &> /dev/null; then
+ if wget -O "$installer_path" "$download_url"; then
+ download_success=1
+ fi
+ else
+ echo " Error: Neither curl nor wget is available for downloading."
+ return 2
+ fi
+
+ if [[ $download_success -eq 0 ]]; then
+ echo " Error: Failed to download Miniconda installer."
+ rm -f "$installer_path"
+ return 2
+ fi
+
+ # Verify download
+ if [[ ! -f "$installer_path" ]]; then
+ echo " Error: Installer file not found after download."
+ return 2
+ fi
+
+ local file_size=$(stat -c%s "$installer_path" 2>/dev/null || stat -f%z "$installer_path" 2>/dev/null)
+ if [[ $file_size -lt 50000000 ]]; then # Less than 50MB indicates a problem
+ echo " Error: Downloaded file is too small (${file_size} bytes). Download may be incomplete."
+ rm -f "$installer_path"
+ return 2
+ fi
+
+ echo " Successfully downloaded installer ($(($file_size / 1024 / 1024))MB)"
+ echo ""
+
+ # =============================================================================
+ # 5. CONFIRM INSTALLATION
+ # =============================================================================
+
+ if [[ -z "$auto_confirm" ]]; then
+ echo " =========================================="
+ echo " Ready to install Miniconda3"
+ echo " =========================================="
+ echo " Installation path: ${install_path}"
+ echo " Architecture: ${arch}"
+ echo " Installer size: $(($file_size / 1024 / 1024))MB"
+ echo ""
+ echo -n " Proceed with installation? [y/N]: "
+ read -r response
+ if [[ ! "$response" =~ ^[Yy]$ ]]; then
+ echo " Installation cancelled by user."
+ rm -f "$installer_path"
+ return 0
+ fi
+ fi
+
+ # =============================================================================
+ # 6. RUN INSTALLER
+ # =============================================================================
+
+ echo ""
+ echo " Step 5/7: Running Miniconda installer..."
+ echo " This may take several minutes..."
+
+ # Run installer in batch mode
+ if bash "$installer_path" -b -p "$install_path"; then
+ echo " Successfully installed Miniconda3 to: ${install_path}"
+ else
+ echo " Error: Miniconda installation failed."
+ rm -f "$installer_path"
+ return 3
+ fi
+
+ echo ""
+
+ # =============================================================================
+ # 7. POST-INSTALLATION CONFIGURATION
+ # =============================================================================
+
+ echo " Step 6/7: Configuring conda..."
+
+ # Verify installation
+ if [[ ! -f "${install_path}/bin/conda" ]]; then
+ echo " Error: conda executable not found at ${install_path}/bin/conda"
+ rm -f "$installer_path"
+ return 3
+ fi
+
+ # Get conda version
+ local conda_version=$("${install_path}/bin/conda" --version 2>&1)
+ echo " Installed: ${conda_version}"
+
+ # Initialize conda for shell(s)
+ if [[ -z "$no_init" ]]; then
+ echo " Initializing conda for shell(s)..."
+
+ # Always initialize bash
+ if "${install_path}/bin/conda" init bash > /dev/null 2>&1; then
+ echo " Initialized conda for bash"
+ else
+ echo " Warning: Failed to initialize conda for bash"
+ fi
+
+ # Initialize other shells if requested
+ if [[ -n "$init_all" ]]; then
+ # Initialize zsh if .zshrc exists
+ if [[ -f "$HOME/.zshrc" ]]; then
+ if "${install_path}/bin/conda" init zsh > /dev/null 2>&1; then
+ echo " Initialized conda for zsh"
+ else
+ echo " Warning: Failed to initialize conda for zsh"
+ fi
+ fi
+ fi
+ else
+ echo " Skipping shell initialization (--no-init flag specified)"
+ fi
+
+ echo ""
+
+ # =============================================================================
+ # 8. CLEANUP
+ # =============================================================================
+
+ echo " Step 7/7: Cleaning up..."
+ rm -f "$installer_path"
+ echo " Removed temporary installer file"
+ echo ""
+
+ # =============================================================================
+ # 9. FINAL INSTRUCTIONS
+ # =============================================================================
+
+ echo " =========================================="
+ echo " SUCCESS: Miniconda3 Installed!"
+ echo " =========================================="
+ echo " Installation path: ${install_path}"
+ echo " Conda version: ${conda_version}"
+ echo ""
+
+ if [[ -z "$no_init" ]]; then
+ echo " To activate conda, run ONE of the following:"
+ echo " 1. Start a new terminal session, OR"
+ echo " 2. Run: source ~/.bashrc"
+ echo ""
+ echo " After activation, verify the installation with:"
+ echo " conda --version"
+ echo " conda info"
+ else
+ echo " Shell initialization was skipped."
+ echo " To use conda, you can:"
+ echo " 1. Manually initialize: ${install_path}/bin/conda init bash"
+ echo " 2. Or activate manually: eval \"\$(${install_path}/bin/conda shell.bash hook)\""
+ fi
+ echo ""
+
+ return 0
+}
+
+daint_install_git_lfs() {
+ # This script will install and configure git-lfs on Daint@ALPS as it is not
+ # available by default in the system.
+
+ # 1. Move to your home directory
+ cd $HOME || { echo "Could not change to home directory"; exit 1; }
+
+ # 2. Download git-lfs for ARM64
+ wget https://github.com/git-lfs/git-lfs/releases/download/v3.7.1/git-lfs-linux-arm64-v3.7.1.tar.gz || { echo "Could not download git-lfs"; exit 1; }
+
+ # 3. Extract the downloaded tarball
+ tar -xvf git-lfs-linux-arm64-v3.7.1.tar.gz || { echo "Could not extract git-lfs tarball"; exit 1; }
+
+ # 4. Move in the extracted directory and make the installer executable
+ cd git-lfs-3.7.1 || { echo "Could not change to git-lfs directory"; exit 1; }
+ chmod +x install.sh || { echo "Could not make installer executable"; exit 1; }
+
+ # 5. Change the installer prefix to your home directory
+ sed -i 's|^prefix="/usr/local"$|prefix="$HOME/.local"|' install.sh || { echo "Could not modify installer prefix"; exit 1; }
+
+ # 6. Make the .local/bin directory if it does not exist
+ mkdir -p "$HOME/.local/bin" || { echo "Could not create .local/bin directory"; exit 1; }
+
+ # 7. Run the installer
+ ./install.sh || { echo "Could not install git-lfs"; exit 1; }
+
+ # 8. Add .local/bin to your PATH if not already present
+ if [[ ":$PATH:" != *":$HOME/.local/bin:"* ]]; then
+ export PATH="$HOME/.local/bin:$PATH"
+ fi
+
+ # 9 Add .local/bin to your PATH in .bashrc for future sessions
+ if ! grep -q 'export PATH="$HOME/.local/bin:$PATH"' "$HOME/.bashrc"; then
+ echo 'export PATH="$HOME/.local/bin:$PATH"' >> "$HOME/.bashrc"
+ fi
+
+ # 10. Verify the installation
+ if command -v git-lfs &> /dev/null; then
+ echo "git-lfs installed successfully!"
+ else
+ echo "git-lfs installation failed"
+ exit 1
+ fi
+}
+
+daint_install_uenv() {
+ # Documentation: https://docs.cscs.ch/software/uenv/
+ echo "daint_install_uenv: setting up Daint.Alps virtual environment."
+
+ if uenv image ls | grep -q "prgenv-gnu/25.6:v2"; then
+ echo " prgenv-gnu/25.6:v2 image already exists, skipping pull"
+ else
+ echo " Pulling prgenv-gnu/25.6:v2 image..."
+ uenv image pull prgenv-gnu/25.6:v2 || {
+ echo " Error: Failed to pull prgenv-gnu/25.6:v2 image"
+ return 1
+ }
+ fi
+
+ return 0
+}
+
+daint_start_uenv() {
+ # Documentation: https://docs.cscs.ch/software/uenv/
+ echo "daint_start_uenv: starting Daint.Alps uenv environment."
+
+ # Stop any existing uenv session
+ uenv stop 2>/dev/null || echo " (No existing uenv to stop)"
+
+ # Start new uenv session
+ echo " Starting uenv with prgenv-gnu/25.6:v2..."
+ echo " WARNING: This is gonna start a new shell session, if you want to"
+ echo " use other functions from this script, you need to source it again."
+ uenv start --view=modules prgenv-gnu/25.6:v2
+}
+
+daint_load_modules() {
+ echo "daint_load_modules: loading Daint system modules."
+
+ # Check if we're already in a uenv session
+ if ! uenv status &>/dev/null; then
+ echo "Error: Not in a uenv session. Modules can only be loaded within a uenv session."
+ return 1
+ fi
+
+ # Check if module command is available
+ if ! command -v module &> /dev/null; then
+ echo " Error: 'module' command not found. Please ensure you are on a system with environment modules."
+ return 1
+ fi
+
+ # Purge any existing modules
+ echo " Purging existing modules..."
+ module purge 2>/dev/null || {
+ echo " Warning: Failed to purge modules (this may be normal on some systems)."
+ }
+
+ # Load required modules
+ echo " Loading required modules: cuda gcc meson ninja nccl cray-mpich cmake openblas aws-ofi-nccl netlib-scalapack"
+ module load cuda gcc meson ninja nccl cray-mpich cmake openblas aws-ofi-nccl netlib-scalapack || {
+ echo " Error: Failed to load required modules."
+ echo " Available modules:"
+ module avail 2>&1 | head -20
+ echo " (output truncated - use 'module avail' for full list)"
+ return 1
+ }
+
+ # Check for CUDA_HOME environment variable
+ if [[ -z "$CUDA_HOME" ]]; then
+ echo " Error: CUDA_HOME not set, please ensure the CUDA module properly sets CUDA_HOME or manually set CUDA_HOME to your CUDA installation directory"
+ return 1
+ fi
+
+ # Set CUDA environment variables
+ export CUDA_DIR=$CUDA_HOME
+ export CUDA_PATH=$CUDA_HOME
+ export CPATH=$CUDA_HOME/include:$CPATH
+ export LIBRARY_PATH=$CUDA_HOME/lib64:$LIBRARY_PATH
+ export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
+
+ # Set NCCL environment variables (dynamically find the nccl installation)
+ export NCCL_ROOT=$(ls -d /user-environment/linux-neoverse_v2/nccl-* 2>/dev/null | head -1)
+ if [[ -z "$NCCL_ROOT" ]]; then
+ echo " Error: NCCL not found in /user-environment/linux-neoverse_v2/"
+ return 1
+ fi
+ export NCCL_LIB_DIR=$NCCL_ROOT/lib
+ export NCCL_INCLUDE_DIR=$NCCL_ROOT/include
+
+ export CPATH=$NCCL_ROOT/include:$CPATH
+ export CFLAGS="-I$NCCL_INCLUDE_DIR":$CFLAGS
+ export LDFLAGS="-L$NCCL_LIB_DIR":$LDFLAGS
+ export LIBRARY_PATH=$NCCL_LIB_DIR:$LIBRARY_PATH
+ export LD_LIBRARY_PATH=$NCCL_LIB_DIR:$LD_LIBRARY_PATH
+
+ echo " Successfully loaded all required modules."
+ return 0
+}
+
+daint_check_modules() {
+ echo "daint_check_modules: checking if required modules are loaded."
+
+ local required_modules=("cuda" "gcc" "meson" "ninja" "nccl" "cray-mpich" "cmake" "openblas" "aws-ofi-nccl" "netlib-scalapack")
+ local missing_modules=()
+
+ # Get list of currently loaded modules
+ local loaded_modules=$(module list 2>&1 | grep -E "cuda|gcc|meson|ninja|nccl|cray-mpich|cmake|openblas|aws-ofi-nccl|netlib-scalapack")
+
+ # Check each required module
+ for module in "${required_modules[@]}"; do
+ if ! echo "$loaded_modules" | grep -q "$module"; then
+ missing_modules+=("$module")
+ fi
+ done
+
+ if [ ${#missing_modules[@]} -eq 0 ]; then
+ echo " All required modules are loaded."
+ return 0
+ else
+ echo " Error: Missing required modules: ${missing_modules[*]}"
+ echo " Please run 'daint_load_modules' first."
+ return 1
+ fi
+}
+
+daint_create_conda_env_help() {
+ echo "daint_create_conda_env: Create DALIA conda environment for Daint supercomputer"
+ echo ""
+ echo "Usage:"
+ echo " daint_create_conda_env [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --dalia-path=PATH Path to DALIA repository root directory"
+ echo " --serinv-path=PATH Path to serinv repository (automatically installs serinv)"
+ echo " --install-mpi4py Install mpi4py and create enhanced environment"
+ echo " --dev-mode Keep intermediate environments (use with --install-mpi4py)"
+ echo ""
+ echo "Examples:"
+ echo " # Interactive mode (GPU support included by default)"
+ echo " daint_create_conda_env"
+ echo ""
+ echo " # Non-interactive mode with all options"
+ echo " daint_create_conda_env --dalia-path=/path/to/dalia --serinv-path=/path/to/serinv --install-mpi4py --dev-mode"
+ echo ""
+ echo " # Install base with GPU support only"
+ echo " daint_create_conda_env --dalia-path=/path/to/dalia"
+ echo ""
+ echo "Note: GPU support via cupy is installed by default for Daint cluster."
+ echo " If parameters are not provided, the function will prompt interactively."
+}
+
+daint_create_conda_env() {
+ echo "daint_create_conda_env: creating DALIA conda environment for Daint."
+
+ # Parse command line arguments
+ local dalia_path=""
+ local serinv_path=""
+ local install_mpi4py_flag=""
+ local dev_mode_flag=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --dalia-path=*)
+ dalia_path="${1#*=}"
+ shift
+ ;;
+ --serinv-path=*)
+ serinv_path="${1#*=}"
+ shift
+ ;;
+ --install-mpi4py)
+ install_mpi4py_flag="y"
+ shift
+ ;;
+ --dev-mode)
+ dev_mode_flag="y"
+ shift
+ ;;
+ --help|-h)
+ daint_create_conda_env_help
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # 0. Deactivate any currently active conda environment
+ echo " Deactivating any currently active conda environment..."
+ conda deactivate 2>/dev/null || true
+
+ # 1. Check that the needed modules are loaded
+ if ! daint_check_modules; then
+ return 1
+ fi
+
+ # 2. Get DALIA repository path
+ if [[ -z "$dalia_path" ]]; then
+ echo " Please enter the path to the DALIA root repository:"
+ read -r dalia_path
+ else
+ echo " Using provided DALIA path: ${dalia_path}"
+ fi
+
+ # Expand tilde and remove trailing slash
+ dalia_path="${dalia_path/#~/$HOME}"
+ dalia_path="${dalia_path%/}"
+
+ # 3. Validate DALIA repository path
+ if [[ ! -d "$dalia_path" ]]; then
+ echo " Error: Directory '${dalia_path}' does not exist."
+ return 1
+ fi
+
+ if [[ ! -f "${dalia_path}/pyproject.toml" ]]; then
+ echo " Error: '${dalia_path}' does not appear to be a DALIA repository (missing pyproject.toml)."
+ return 1
+ fi
+
+ if [[ ! -d "${dalia_path}/src/dalia" ]]; then
+ echo " Error: '${dalia_path}' does not contain the DALIA source code (missing src/dalia/)."
+ return 1
+ fi
+
+ # 4. Locate the conda environment file
+ local env_file="${dalia_path}/envs/dalia_base_aarch64.yml"
+ if [[ ! -f "$env_file" ]]; then
+ echo " Error: Conda environment file not found at '${env_file}'."
+ return 1
+ fi
+
+ echo " Found DALIA repository at: ${dalia_path}"
+ echo " Found conda environment file at: ${env_file}"
+
+ # Check if environment already exists
+ local env_name="dalia_base_daint"
+ if conda env list | grep -q "^${env_name} "; then
+ echo " Warning: Conda environment '${env_name}' already exists."
+ echo " Do you want to remove and recreate it? (y/N): "
+ read -r response
+ if [[ "$response" =~ ^[Yy]$ ]]; then
+ echo " Removing existing environment..."
+ conda env remove -n "$env_name" -y || {
+ echo " Error: Failed to remove existing environment."
+ return 1
+ }
+ else
+ echo " Skipping environment creation. Using existing environment."
+ fi
+ fi
+
+ # Create conda environment from YAML file (only if it doesn't exist or was removed)
+ if ! conda env list | grep -q "^${env_name} "; then
+ echo " Creating conda environment from '${env_file}'..."
+ conda env create --name "$env_name" -f "$env_file" || {
+ echo " Error: Failed to create conda environment from '${env_file}'."
+ return 1
+ }
+ fi
+
+ # 5. Activate the conda environment
+ echo " Activating conda environment..."
+ if ! daint_activate_conda_env --env="$env_name"; then
+ return 1
+ fi
+
+ # 6. Install DALIA from the repository
+ echo " Installing DALIA in development mode..."
+ cd "$dalia_path" || {
+ echo " Error: Failed to change directory to '${dalia_path}'."
+ return 1
+ }
+
+ python -m pip install --no-deps --editable . || {
+ echo " Error: Failed to install DALIA in development mode."
+ return 1
+ }
+
+ # 7. Optional: Install serinv structured sparse solver
+ echo ""
+ local install_serinv=""
+
+ if [[ -n "$serinv_path" ]]; then
+ # Serinv path provided via command line - install automatically
+ echo " Serinv path provided via --serinv-path. Installing serinv automatically..."
+ install_serinv="y"
+ else
+ # Interactive mode - ask user
+ echo " Spatio-temporal problems in DALIA can leverage the 'serinv' structured sparse solver for improved performance."
+ echo " Do you want to install the serinv structured sparse solver? (y/N): "
+ read -r install_serinv
+
+ if [[ "$install_serinv" =~ ^[Yy]$ ]]; then
+ echo " Please enter the path to the serinv root repository:"
+ read -r serinv_path
+ fi
+ fi
+
+ if [[ "$install_serinv" =~ ^[Yy]$ ]]; then
+ # Expand tilde and remove trailing slash
+ serinv_path="${serinv_path/#~/$HOME}"
+ serinv_path="${serinv_path%/}"
+
+ # Validate serinv repository path
+ if [[ ! -d "$serinv_path" ]]; then
+ echo " Error: Directory '${serinv_path}' does not exist."
+ echo " Skipping serinv installation."
+ elif [[ ! -f "${serinv_path}/pyproject.toml" ]] && [[ ! -f "${serinv_path}/setup.py" ]]; then
+ echo " Error: '${serinv_path}' does not appear to be a valid Python package (missing pyproject.toml or setup.py)."
+ echo " Skipping serinv installation."
+ else
+ echo " Installing serinv from '${serinv_path}' in development mode..."
+ cd "$serinv_path" || {
+ echo " Error: Failed to change directory to '${serinv_path}'."
+ echo " Skipping serinv installation."
+ }
+
+ if python -m pip install --no-deps --editable .; then
+ echo " Successfully installed serinv in development mode."
+ else
+ echo " Warning: Failed to install serinv. You may need to install it manually later."
+ fi
+
+ # Return to DALIA directory
+ cd "$dalia_path" || true
+ fi
+ else
+ echo " Skipping serinv installation."
+ fi
+
+ # 8. Install cupy with GPU support (default for Daint cluster)
+ echo " Installing cupy for GPU support..."
+
+ if python -m pip install cupy-cuda12x --no-cache-dir; then
+ # Check if CuPy is working
+ if python -c "import cupy; A = cupy.random.rand(10,10)" &> /dev/null; then
+ echo " Successfully installed cupy for GPU support."
+ echo " CuPy configuration:"
+ python -c "import cupy; cupy.show_config()"
+ else
+ echo " Warning: Could not verify CuPy installation. Please test it manually."
+ fi
+ else
+ echo " Warning: Failed to install cupy. You may need to install it manually later."
+ fi
+
+ # 9. Optional: Install mpi4py and create enhanced environment
+ echo ""
+ local install_mpi4py=""
+
+ if [[ -n "$install_mpi4py_flag" ]]; then
+ # MPI4py installation requested via command line
+ echo " MPI4py installation requested via --install-mpi4py. Installing automatically..."
+ install_mpi4py="y"
+ else
+ # Interactive mode - ask user
+ echo " MPI support can be added through mpi4py for improved parallel performance."
+ echo " Do you want to install mpi4py and create an enhanced environment? (y/N): "
+ read -r install_mpi4py
+ fi
+
+ if [[ "$install_mpi4py" =~ ^[Yy]$ ]]; then
+ echo " Creating enhanced environment with mpi4py support..."
+
+ # Determine the enhanced environment name
+ local mpi_enhanced_env_name="dalia_xccl_daint" # Always GPU + MPI for Daint, NCCL is installed with CuPy.
+ echo " Creating environment with GPU and mpi4py support..."
+
+ # Deactivate current environment
+ echo " Deactivating current environment to create MPI-enhanced version..."
+ conda deactivate 2>/dev/null || true
+
+ # Remove enhanced environment if it already exists
+ if conda env list | grep -q "^${mpi_enhanced_env_name} "; then
+ echo " Removing existing MPI-enhanced environment..."
+ conda env remove -n "$mpi_enhanced_env_name" -y || {
+ echo " Warning: Failed to remove existing MPI-enhanced environment."
+ }
+ fi
+
+ # Clone the current environment
+ if conda create --name "$mpi_enhanced_env_name" --clone "$env_name" -y; then
+ echo " Successfully created MPI-enhanced environment."
+
+ # Activate the enhanced environment
+ echo " Activating MPI-enhanced environment '${mpi_enhanced_env_name}'..."
+ if daint_activate_conda_env --env="$mpi_enhanced_env_name"; then
+ # Install mpi4py in the enhanced environment
+ echo " Installing mpi4py with OpenMPI support in enhanced environment..."
+ cd "$dalia_path" || true
+ if MPICC=$(which mpicc) python -m pip install --no-cache-dir --no-binary=mpi4py mpi4py; then
+ echo " Successfully installed mpi4py in MPI-enhanced environment."
+
+ # Update the libcxx given Daint NCCL modules requirements
+ conda update libstdcxx-ng
+ conda install -c conda-forge libstdcxx-ng
+
+ # Try to import nccl to ensure it's available
+ echo " Verifying NCCL installation in MPI-enhanced environment..."
+ if python -c "from cupy.cuda import nccl; nccl.get_unique_id()"; then
+ echo " NCCL is available in MPI-enhanced environment."
+ else
+ echo " Warning: NCCL does not seem to be available in MPI-enhanced environment."
+ fi
+
+ # Determine whether to keep base environment
+ local keep_base=""
+ if [[ -n "$install_mpi4py_flag" ]]; then
+ # Command line mode - use dev_mode_flag to decide
+ if [[ -n "$dev_mode_flag" ]]; then
+ keep_base="y"
+ echo " Developer mode enabled via --dev-mode. Keeping base environment."
+ else
+ keep_base="n"
+ echo " Default mode: removing base environment to keep only enhanced version."
+ fi
+ else
+ # Interactive mode - ask user
+ echo ""
+ echo " Do you want to keep the base environment 'dalia_base_daint' for development without MPI? (y/N): "
+ read -r keep_base
+ fi
+
+ if [[ ! "$keep_base" =~ ^[Yy]$ ]]; then
+ echo " Removing base environment 'dalia_base_daint'..."
+ conda env remove -n "dalia_base_daint" -y || {
+ echo " Warning: Failed to remove base environment."
+ }
+ else
+ echo " Keeping base environment for development without MPI."
+ fi
+
+ env_name="$mpi_enhanced_env_name" # Update env_name for final message
+ else
+ echo " Error: Failed to install mpi4py in MPI-enhanced environment."
+ echo " Removing broken MPI-enhanced environment and reverting to base environment..."
+
+ # Deactivate the enhanced environment
+ conda deactivate 2>/dev/null || true
+
+ # Remove the broken enhanced environment
+ conda env remove -n "$mpi_enhanced_env_name" -y || {
+ echo " Warning: Failed to remove broken MPI-enhanced environment."
+ }
+
+ # Reactivate the base environment
+ if daint_activate_conda_env --env="dalia_base_daint"; then
+ echo " Reverted to base environment 'dalia_base_daint'."
+ env_name="dalia_base_daint" # Reset env_name to base environment
+ echo " You can install mpi4py manually later with: MPICC=\$(which mpicc) pip install --no-cache-dir mpi4py"
+ else
+ echo " Warning: Failed to reactivate base environment."
+ fi
+ fi
+ else
+ echo " Warning: Failed to activate MPI-enhanced environment."
+ fi
+ else
+ echo " Error: Failed to create MPI-enhanced environment."
+ echo " Continuing with base environment..."
+ fi
+ else
+ echo " Skipping mpi4py installation."
+ fi
+
+ # 10. Final success message
+ echo ""
+ echo " Success! DALIA conda environment '${env_name}' has been created and configured."
+ echo " Repository path: ${dalia_path}"
+ if [[ "$install_serinv" =~ ^[Yy]$ ]] && [[ -d "$serinv_path" ]]; then
+ echo " Serinv path: ${serinv_path}"
+ fi
+ echo " Base environment includes GPU support via cupy (default for Daint cluster)."
+ if [[ "$install_mpi4py" =~ ^[Yy]$ ]]; then
+ if [[ "$env_name" == *"xccl"* ]]; then
+ echo " Enhanced environment with mpi4py and NCCL support created."
+ fi
+ fi
+ echo " To use a specific environment in the future, run: daint_activate_conda_env --env=\"daint_env_name\""
+ echo " To use the most performant environment you have available, just run: daint_activate_conda_env"
+
+ return 0
+}
+
+daint_activate_conda_env() {
+ echo "daint_activate_conda_env: activating DALIA conda environment."
+
+ # Parse command line arguments
+ local specified_env=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --env=*)
+ specified_env="${1#*=}"
+ shift
+ ;;
+ --help|-h)
+ echo "daint_activate_conda_env: Activate DALIA conda environment"
+ echo ""
+ echo "Usage:"
+ echo " daint_activate_conda_env [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --env=NAME Specific environment name to activate"
+ echo " --help/-h Show this help message"
+ echo ""
+ echo "Examples:"
+ echo " daint_activate_conda_env # Auto-select best available"
+ echo " daint_activate_conda_env --env=dalia_base_daint # Activate specific environment"
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # Check if conda is available
+ if ! command -v conda &> /dev/null; then
+ echo " Error: conda command not found. Please ensure conda is installed and in PATH."
+ return 1
+ fi
+
+ # Safely deactivate current environment
+ conda deactivate 2>/dev/null || true
+
+ # Define available environments in order of preference (most performant first)
+ local env_priorities=("dalia_xccl_daint" "dalia_base_daint")
+ local env_name=""
+
+ # If environment name is provided as argument, use it directly
+ if [[ -n "$specified_env" ]]; then
+ env_name="$specified_env"
+ echo " Using specified conda environment '${env_name}'..."
+
+ # Validate that the specified environment exists
+ if ! conda env list | grep -q "^${env_name} "; then
+ echo " Error: Conda environment '${env_name}' does not exist."
+ echo " Available environments:"
+ conda env list
+ return 1
+ fi
+ else
+ # Check which environments are available and select the most performant one
+ echo " Checking available DALIA conda environments..."
+ local available_envs=$(conda env list 2>/dev/null | grep -E "^(dalia_xccl_daint|dalia_base_daint) " | awk '{print $1}')
+
+ for preferred_env in "${env_priorities[@]}"; do
+ if echo "$available_envs" | grep -q "^${preferred_env}$"; then
+ env_name="$preferred_env"
+ echo " Selected most performant available environment: '${env_name}'"
+ break
+ fi
+ done
+
+ # Fallback if no DALIA environments found
+ if [[ -z "$env_name" ]]; then
+ echo " Warning: No DALIA-specific environments found. Available environments:"
+ conda env list
+ echo " Falling back to 'base' environment..."
+ env_name="base"
+ fi
+ fi
+
+ echo " Activating conda environment '${env_name}'..."
+ conda activate ${env_name} || {
+ echo " Error: Failed to activate conda environment '${env_name}'"
+ echo " Please check that the environment exists and conda is properly configured."
+ return 1
+ }
+
+ # Verify activation was successful
+ local current_env=$(conda info --envs | grep '\*' | grep 'dalia*' | awk '{print $1}')
+ if [[ "$current_env" == "$env_name" ]]; then
+ echo " Successfully activated conda environment '${env_name}'"
+ else
+ echo " Warning: Environment activation may not have been successful."
+ echo " Expected: ${env_name}, Current: ${current_env}"
+ fi
+
+ return 0
+}
+
+daint_set_perfenv() {
+ echo "daint_set_perfenv: setting performance environment variables for Daint."
+
+ set -e
+
+ export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
+ export CUDA_VISIBLE_DEVICES=$SLURM_LOCALID
+ export MPICH_GPU_SUPPORT_ENABLED=0
+
+ # NCCL Performance Configuration
+ # More can be found: https://docs.cscs.ch/software/communication/nccl/#using-nccl
+ # This forces NCCL to use the libfabric plugin, enabling full use of the
+ # Slingshot network. If the plugin can not be found, applications will fail to
+ # start. With the default value, applications would instead fall back to e.g.
+ # TCP, which would be significantly slower than with the plugin. More information
+ # about `NCCL_NET` can be found at:
+ # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-net
+ export NCCL_NET="AWS Libfabric"
+ # Use GPU Direct RDMA when GPU and NIC are on the same NUMA node. More
+ # information about `NCCL_NET_GDR_LEVEL` can be found at:
+ # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-net-gdr-level-formerly-nccl-ib-gdr-level
+ export NCCL_NET_GDR_LEVEL=PHB
+ export NCCL_CROSS_NIC=1
+ # Starting with nccl 2.27 a new protocol (LL128) was enabled by default, which
+ # typically performs worse on Slingshot. The following disables that protocol.
+ export NCCL_PROTO=^LL128
+ # These `FI` (libfabric) environment variables have been found to give the best
+ # performance on the Alps network across a wide range of applications. Specific
+ # applications may perform better with other values.
+ export FI_CXI_DEFAULT_CQ_SIZE=131072
+ export FI_CXI_DEFAULT_TX_SIZE=16384
+ export FI_CXI_DISABLE_HOST_REGISTER=1
+ export FI_CXI_RX_MATCH_MODE=software
+ export FI_MR_CACHE_MONITOR=userfaultfd
+
+ export FI_CXI_RDZV_GET_MIN=0
+ export FI_CXI_RDZV_THRESHOLD=0
+ export FI_CXI_RDZV_EAGER_SIZE=0
+}
\ No newline at end of file
diff --git a/scripts/dalia_job_utils.sh b/scripts/dalia_job_utils.sh
new file mode 100644
index 00000000..070e5543
--- /dev/null
+++ b/scripts/dalia_job_utils.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+dalia_set_perfenv() {
+ echo "dalia_set_perfenv: setting up DALIA performance environment variables."
+ export ARRAY_MODULE=numpy # numpy or cupy
+ export MPI_CUDA_AWARE=0 # 0 or 1
+ export USE_NCCL=0 # 0 or 1
+ export MPICH_GPU_SUPPORT_ENABLED=0 # 0 or 1
+ echo "DALIA Environment Configuration:"
+ echo " - ARRAY_MODULE: ${ARRAY_MODULE}"
+ echo " - MPI_CUDA_AWARE: ${MPI_CUDA_AWARE}"
+ echo " - USE_NCCL: ${USE_NCCL}"
+ echo " - MPICH_GPU_SUPPORT_ENABLED: ${MPICH_GPU_SUPPORT_ENABLED}"
+ echo ""
+}
+
+dalia_print_job_config() {
+ echo "dalia_print_job_config: printing job configuration."
+ echo "SLURM Job Configuration:"
+ echo " - Job Name: ${SLURM_JOB_NAME}"
+ echo " - Job ID: ${SLURM_JOB_ID}"
+ echo " - Nodes: ${SLURM_NNODES}"
+ echo " - Tasks per node: ${SLURM_NTASKS_PER_NODE}"
+ echo " - Total tasks: ${SLURM_NTASKS}"
+ echo " - CPUs per task: ${SLURM_CPUS_PER_TASK}"
+ if nvidia-smi &> /dev/null; then
+ echo " - GPUs per task: 1"
+ else
+ echo " - GPUs per task: 0 (seemingly no GPU available)"
+ fi
+ echo " - Time limit: ${SLURM_TIMELIMIT}"
+ echo " - Partition: ${SLURM_JOB_PARTITION}"
+ echo " - Account: ${SLURM_JOB_ACCOUNT}"
+}
diff --git a/scripts/fritz_fau_utils.sh b/scripts/fritz_fau_utils.sh
new file mode 100644
index 00000000..a53cc1fd
--- /dev/null
+++ b/scripts/fritz_fau_utils.sh
@@ -0,0 +1,502 @@
+#!/bin/bash
+
+fritz_load_modules() {
+ echo "fritz_load_modules: loading Fritz system modules."
+
+ # Check if module command is available
+ if ! command -v module &> /dev/null; then
+ echo " Error: 'module' command not found. Please ensure you are on a system with environment modules."
+ return 1
+ fi
+
+ # Purge any existing modules
+ echo " Purging existing modules..."
+ module purge 2>/dev/null || {
+ echo " Warning: Failed to purge modules (this may be normal on some systems)."
+ }
+
+ # Load required modules
+ echo " Loading required modules: intelmpi/2021.10.0 mkl/2023.2.0 gcc/12.1.0 python"
+ module load intelmpi/2021.10.0 mkl/2023.2.0 gcc/12.1.0 python || {
+ echo " Error: Failed to load required modules."
+ echo " Available modules:"
+ module avail 2>&1 | head -20
+ echo " (output truncated - use 'module avail' for full list)"
+ return 1
+ }
+
+ echo " Successfully loaded all required modules."
+ return 0
+}
+
+fritz_check_modules() {
+ echo "fritz_check_modules: checking if required modules are loaded."
+
+ local required_modules=("intelmpi/2021.10.0" "mkl/2023.2.0" "gcc/12.1.0" "python")
+ local missing_modules=()
+
+ # Get list of currently loaded modules
+ local loaded_modules=$(module list 2>&1 | grep -E "intelmpi|mkl|gcc|python")
+
+ # Check each required module
+ for module in "${required_modules[@]}"; do
+ if ! echo "$loaded_modules" | grep -q "$module"; then
+ missing_modules+=("$module")
+ fi
+ done
+
+ if [ ${#missing_modules[@]} -eq 0 ]; then
+ echo " All required modules are loaded."
+ return 0
+ else
+ echo " Error: Missing required modules: ${missing_modules[*]}"
+ echo " Please run 'fritz_load_modules' first."
+ return 1
+ fi
+}
+
+fritz_create_conda_env_help() {
+ echo "fritz_create_conda_env: Create DALIA conda environment for Fritz supercomputer"
+ echo ""
+ echo "Usage:"
+ echo " fritz_create_conda_env [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --dalia-path=PATH Path to DALIA repository root directory"
+ echo " --serinv-path=PATH Path to serinv repository (automatically installs serinv)"
+ echo " --install-mpi4py Install mpi4py and create enhanced environment"
+ echo " --dev-mode Keep both base and enhanced environments (use with --install-mpi4py)"
+ echo ""
+ echo "Examples:"
+ echo " # Interactive mode"
+ echo " fritz_create_conda_env"
+ echo ""
+ echo " # Non-interactive mode with all options"
+ echo " fritz_create_conda_env --dalia-path=/path/to/dalia --serinv-path=/path/to/serinv --install-mpi4py --dev-mode"
+ echo ""
+ echo " # Install only DALIA and mpi4py (removes base environment)"
+ echo " fritz_create_conda_env --dalia-path=/path/to/dalia --install-mpi4py"
+ echo ""
+ echo "Note: If parameters are not provided, the function will prompt interactively."
+}
+
+fritz_create_conda_env() {
+ echo "fritz_create_conda_env: creating DALIA conda environment for Fritz."
+
+ # Parse command line arguments
+ local dalia_path=""
+ local serinv_path=""
+ local install_mpi4py_flag=""
+ local dev_mode_flag=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --dalia-path=*)
+ dalia_path="${1#*=}"
+ shift
+ ;;
+ --serinv-path=*)
+ serinv_path="${1#*=}"
+ shift
+ ;;
+ --install-mpi4py)
+ install_mpi4py_flag="y"
+ shift
+ ;;
+ --dev-mode)
+ dev_mode_flag="y"
+ shift
+ ;;
+ --help|-h)
+ fritz_create_conda_env_help
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # 0. Deactivate any currently active conda environment
+ echo " Deactivating any currently active conda environment..."
+ conda deactivate 2>/dev/null || true
+
+ # 1. Check that the needed modules are loaded
+ if ! fritz_check_modules; then
+ return 1
+ fi
+
+ # 2. Get DALIA repository path
+ if [[ -z "$dalia_path" ]]; then
+ echo " Please enter the path to the DALIA root repository:"
+ read -r dalia_path
+ else
+ echo " Using provided DALIA path: ${dalia_path}"
+ fi
+
+ # Expand tilde and remove trailing slash
+ dalia_path="${dalia_path/#\~/$HOME}"
+ dalia_path="${dalia_path%/}"
+
+ # 3. Validate DALIA repository path
+ if [[ ! -d "$dalia_path" ]]; then
+ echo " Error: Directory '${dalia_path}' does not exist."
+ return 1
+ fi
+
+ if [[ ! -f "${dalia_path}/pyproject.toml" ]]; then
+ echo " Error: '${dalia_path}' does not appear to be a DALIA repository (missing pyproject.toml)."
+ return 1
+ fi
+
+ if [[ ! -d "${dalia_path}/src/dalia" ]]; then
+ echo " Error: '${dalia_path}' does not contain the DALIA source code (missing src/dalia/)."
+ return 1
+ fi
+
+ # 4. Locate the conda environment file
+ local env_file="${dalia_path}/envs/dalia_base_x86.yml"
+ if [[ ! -f "$env_file" ]]; then
+ echo " Error: Conda environment file not found at '${env_file}'."
+ return 1
+ fi
+
+ echo " Found DALIA repository at: ${dalia_path}"
+ echo " Found conda environment file at: ${env_file}"
+
+ # Check if environment already exists
+ local env_name="dalia_base_fritz"
+ if conda env list | grep -q "^${env_name} "; then
+ echo " Warning: Conda environment '${env_name}' already exists."
+ echo " Do you want to remove and recreate it? (y/N): "
+ read -r response
+ if [[ "$response" =~ ^[Yy]$ ]]; then
+ echo " Removing existing environment..."
+ conda env remove -n "$env_name" -y || {
+ echo " Error: Failed to remove existing environment."
+ return 1
+ }
+ else
+ echo " Skipping environment creation. Using existing environment."
+ fi
+ fi
+
+ # Create conda environment from YAML file (only if it doesn't exist or was removed)
+ if ! conda env list | grep -q "^${env_name} "; then
+ echo " Creating conda environment from '${env_file}'..."
+ conda env create --name "$env_name" -f "$env_file" || {
+ echo " Error: Failed to create conda environment from '${env_file}'."
+ return 1
+ }
+ fi
+
+ # 5. Activate the conda environment
+ echo " Activating conda environment..."
+ if ! fritz_activate_conda_env --env="$env_name"; then
+ return 1
+ fi
+
+ # 6. Install DALIA from the repository
+ echo " Installing DALIA in development mode..."
+ cd "$dalia_path" || {
+ echo " Error: Failed to change directory to '${dalia_path}'."
+ return 1
+ }
+
+ python -m pip install --no-deps --editable . || {
+ echo " Error: Failed to install DALIA in development mode."
+ return 1
+ }
+
+ # 7. Optional: Install serinv structured sparse solver
+ echo ""
+ local install_serinv=""
+
+ if [[ -n "$serinv_path" ]]; then
+ # Serinv path provided via command line - install automatically
+ echo " Serinv path provided via --serinv-path. Installing serinv automatically..."
+ install_serinv="y"
+ else
+ # Interactive mode - ask user
+ echo " Spatio-temporal problems in DALIA can leverage the 'serinv' structured sparse solver for improved performance."
+ echo " Do you want to install the serinv structured sparse solver? (y/N): "
+ read -r install_serinv
+
+ if [[ "$install_serinv" =~ ^[Yy]$ ]]; then
+ echo " Please enter the path to the serinv root repository:"
+ read -r serinv_path
+ fi
+ fi
+
+ if [[ "$install_serinv" =~ ^[Yy]$ ]]; then
+ # Expand tilde (if present) and remove trailing slash
+ serinv_path="${serinv_path/#~/$HOME}"
+ serinv_path="${serinv_path%/}"
+
+ # Validate serinv repository path
+ if [[ ! -d "$serinv_path" ]]; then
+ echo " Error: Directory '${serinv_path}' does not exist."
+ echo " Skipping serinv installation."
+ elif [[ ! -f "${serinv_path}/pyproject.toml" ]] && [[ ! -f "${serinv_path}/setup.py" ]]; then
+ echo " Error: '${serinv_path}' does not appear to be a valid Python package (missing pyproject.toml or setup.py)."
+ echo " Skipping serinv installation."
+ else
+ echo " Installing serinv from '${serinv_path}' in development mode..."
+ cd "$serinv_path" || {
+ echo " Error: Failed to change directory to '${serinv_path}'."
+ echo " Skipping serinv installation."
+ }
+
+ if python -m pip install --no-deps --editable .; then
+ echo " Successfully installed serinv in development mode."
+ else
+ echo " Warning: Failed to install serinv. You may need to install it manually later."
+ fi
+
+ # Return to DALIA directory
+ cd "$dalia_path" || true
+ fi
+ else
+ echo " Skipping serinv installation."
+ fi
+
+ # 8. Optional: Install mpi4py and create enhanced environment
+ echo ""
+ local install_mpi4py=""
+
+ if [[ -n "$install_mpi4py_flag" ]]; then
+ # MPI4py installation requested via command line
+ echo " MPI4py installation requested via --install-mpi4py. Installing automatically..."
+ install_mpi4py="y"
+ else
+ # Interactive mode - ask user
+ echo " MPI support can be added through mpi4py for improved parallel performance."
+ echo " Do you want to install mpi4py and create an enhanced environment? (y/N): "
+ read -r install_mpi4py
+ fi
+
+ if [[ "$install_mpi4py" =~ ^[Yy]$ ]]; then
+ echo " Creating enhanced environment with mpi4py support..."
+
+ # Deactivate current environment
+ echo " Deactivating current environment to create enhanced version..."
+ conda deactivate 2>/dev/null || true
+
+ # Create new environment with enhanced name
+ local enhanced_env_name="dalia_hmpi_fritz"
+ echo " Creating enhanced environment '${enhanced_env_name}' from '${env_name}'..."
+
+ # Remove enhanced environment if it already exists
+ if conda env list | grep -q "^${enhanced_env_name} "; then
+ echo " Removing existing enhanced environment..."
+ conda env remove -n "$enhanced_env_name" -y || {
+ echo " Warning: Failed to remove existing enhanced environment."
+ }
+ fi
+
+ # Clone the current environment
+ if conda create --name "$enhanced_env_name" --clone "$env_name" -y; then
+ echo " Successfully created enhanced environment."
+
+ # Activate the enhanced environment
+ echo " Activating enhanced environment '${enhanced_env_name}'..."
+ if fritz_activate_conda_env --env="$enhanced_env_name"; then
+ # Install mpi4py in the enhanced environment
+ echo " Installing mpi4py with Intel MPI support in enhanced environment..."
+ cd "$dalia_path" || true
+ if MPICC=$(which mpicc) pip install --no-cache-dir mpi4py; then
+ echo " Successfully installed mpi4py in enhanced environment."
+ env_name="$enhanced_env_name" # Update env_name for final message
+
+ # Determine whether to keep base environment
+ local keep_base=""
+ if [[ -n "$install_mpi4py_flag" ]]; then
+ # Command line mode - use dev_mode_flag to decide
+ if [[ -n "$dev_mode_flag" ]]; then
+ keep_base="y"
+ echo " Developer mode enabled via --dev-mode. Keeping base environment."
+ else
+ keep_base="n"
+ echo " Default mode: removing base environment to keep only enhanced version."
+ fi
+ else
+ # Interactive mode - ask user
+ echo ""
+ echo " Do you want to keep the base environment 'dalia_base_fritz' for development without MPI? (y/N): "
+ read -r keep_base
+ fi
+
+ if [[ ! "$keep_base" =~ ^[Yy]$ ]]; then
+ echo " Removing base environment 'dalia_base_fritz'..."
+ conda env remove -n "dalia_base_fritz" -y || {
+ echo " Warning: Failed to remove base environment."
+ }
+ else
+ echo " Keeping base environment for development without MPI."
+ fi
+ else
+ echo " Error: Failed to install mpi4py in enhanced environment."
+ echo " Removing broken enhanced environment and reverting to base environment..."
+
+ # Deactivate the enhanced environment
+ conda deactivate 2>/dev/null || true
+
+ # Remove the broken enhanced environment
+ conda env remove -n "$enhanced_env_name" -y || {
+ echo " Warning: Failed to remove broken enhanced environment."
+ }
+
+ # Reactivate the base environment
+ if fritz_activate_conda_env --env="dalia_base_fritz"; then
+ echo " Reverted to base environment 'dalia_base_fritz'."
+ env_name="dalia_base_fritz" # Reset env_name to base environment
+ echo " You can install mpi4py manually later with: MPICC=\$(which mpicc) pip install --no-cache-dir mpi4py"
+ else
+ echo " Warning: Failed to reactivate base environment."
+ fi
+ fi
+ else
+ echo " Warning: Failed to activate enhanced environment."
+ fi
+ else
+ echo " Error: Failed to create enhanced environment."
+ echo " Continuing with base environment..."
+ fi
+ else
+ echo " Skipping mpi4py installation."
+ fi
+
+ echo ""
+ echo " Success! DALIA conda environment '${env_name}' has been created and configured."
+ echo " Repository path: ${dalia_path}"
+ if [[ "$install_serinv" =~ ^[Yy]$ ]] && [[ -d "$serinv_path" ]]; then
+ echo " Serinv path: ${serinv_path}"
+ fi
+ if [[ "$install_mpi4py" =~ ^[Yy]$ ]]; then
+ if [[ "$env_name" == *"hmpi"* ]]; then
+ echo " Enhanced environment with mpi4py support created."
+ fi
+ fi
+ echo " To use a specific environment in the future, run: fritz_activate_conda_env --env=\"fritz_env_name\""
+ echo " To use the most performant environment you have available, just run: fritz_activate_conda_env"
+
+ return 0
+}
+
+fritz_activate_conda_env() {
+ echo "fritz_activate_conda_env: activating DALIA conda environment."
+
+ # Parse command line arguments
+ local specified_env=""
+
+ while [[ $# -gt 0 ]]; do
+ case $1 in
+ --env=*)
+ specified_env="${1#*=}"
+ shift
+ ;;
+ --help|-h)
+ echo "fritz_activate_conda_env: Activate DALIA conda environment"
+ echo ""
+ echo "Usage:"
+ echo " fritz_activate_conda_env [OPTIONS]"
+ echo ""
+ echo "Options:"
+ echo " --env=NAME Specific environment name to activate"
+ echo " --help/-h Show this help message"
+ echo ""
+ echo "Examples:"
+ echo " fritz_activate_conda_env # Auto-select best available"
+ echo " fritz_activate_conda_env --env=dalia_base_fritz # Activate specific environment"
+ return 0
+ ;;
+ *)
+ echo " Warning: Unknown parameter '$1' ignored."
+ echo " Use --help for usage information."
+ shift
+ ;;
+ esac
+ done
+
+ # Check if conda is available
+ if ! command -v conda &> /dev/null; then
+ echo " Error: conda command not found. Please ensure conda is installed and in PATH."
+ return 1
+ fi
+
+ # Safely deactivate current environment
+ conda deactivate 2>/dev/null || true
+
+ # Define available environments in order of preference (most performant first)
+ local env_priorities=("dalia_hmpi_fritz" "dalia_base_fritz")
+ local env_name=""
+
+ # If environment name is provided as argument, use it directly
+ if [[ -n "$specified_env" ]]; then
+ env_name="$specified_env"
+ echo " Using specified conda environment '${env_name}'..."
+
+ # Validate that the specified environment exists
+ if ! conda env list | grep -q "^${env_name} "; then
+ echo " Error: Conda environment '${env_name}' does not exist."
+ echo " Available environments:"
+ conda env list
+ return 1
+ fi
+ else
+ # Check which environments are available and select the most performant one
+ echo " Checking available DALIA conda environments..."
+ local available_envs=$(conda env list 2>/dev/null | grep -E "^(dalia_hmpi_fritz|dalia_base_fritz) " | awk '{print $1}')
+
+ for preferred_env in "${env_priorities[@]}"; do
+ if echo "$available_envs" | grep -q "^${preferred_env}$"; then
+ env_name="$preferred_env"
+ echo " Selected most performant available environment: '${env_name}'"
+ break
+ fi
+ done
+
+ # Fallback if no DALIA environments found
+ if [[ -z "$env_name" ]]; then
+ echo " Warning: No DALIA-specific environments found. Available environments:"
+ conda env list
+ echo " Falling back to 'base' environment..."
+ env_name="base"
+ fi
+ fi
+
+ echo " Activating conda environment '${env_name}'..."
+ conda activate ${env_name} || {
+ echo " Error: Failed to activate conda environment '${env_name}'"
+ echo " Please check that the environment exists and conda is properly configured."
+ return 1
+ }
+
+ # Verify activation was successful
+ local current_env=$(conda info --envs | grep '\*' | awk '{print $1}')
+ if [[ "$current_env" == "$env_name" ]]; then
+ echo " Successfully activated conda environment '${env_name}'"
+ else
+ echo " Warning: Environment activation may not have been successful."
+ echo " Expected: ${env_name}, Current: ${current_env}"
+ fi
+
+ return 0
+}
+
+fritz_set_perfenv() {
+ echo "fritz_set_perfenv: setting performance environment variables for Fritz."
+
+ unset SLURM_EXPORT_ENV
+
+ export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK
+ export MKL_NUM_THREADS=$SLURM_CPUS_PER_TASK
+
+ # Set thread affinity
+ CPU_BIND="mask_cpu:0xffff00000000,0xffff000000000000"
+ CPU_BIND="${CPU_BIND},0xffff,0xffff0000"
+ CPU_BIND="${CPU_BIND},0xffff000000000000000000000000,0xffff0000000000000000000000000000"
+ CPU_BIND="${CPU_BIND},0xffff0000000000000000,0xffff00000000000000000000"
+}
\ No newline at end of file
diff --git a/src/dalia/__about__.py b/src/dalia/__about__.py
index 4b3c32fd..6d2725e1 100644
--- a/src/dalia/__about__.py
+++ b/src/dalia/__about__.py
@@ -1,3 +1,3 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
-__version__ = "0.0.1"
+__version__ = "0.1.0"
diff --git a/src/dalia/__init__.py b/src/dalia/__init__.py
index ab8bc5fa..9509485c 100644
--- a/src/dalia/__init__.py
+++ b/src/dalia/__init__.py
@@ -38,16 +38,25 @@
xp.abs(1)
except (ImportError, ImportWarning, ModuleNotFoundError) as e:
- warn(f"'CuPy' is unavailable, defaulting to 'NumPy'. ({e})")
+ warn(
+ f"'CuPy' backend selected but unavailable: defaulting to 'NumPy'. ({e})"
+ )
import numpy as xp
import scipy as sp
xp_host = xp
else:
- raise ValueError(f"Unrecognized ARRAY_MODULE '{backend_flags['array_module']}'")
+ warn(
+ f"Unrecognized ARRAY_MODULE '{backend_flags['array_module']}', defaulting to 'NumPy'."
+ )
+
+ import numpy as xp
+ import scipy as sp
+
+ xp_host = xp
else:
# If the user does not specify the array module, prioritize numpy.
- warn("No `ARRAY_MODULE` specified, DALIA.core defaulting to 'NumPy'.")
+ # LOG: warn("No `ARRAY_MODULE` specified, DALIA.core defaulting to 'NumPy'.")
import numpy as xp
import scipy as sp
@@ -63,7 +72,8 @@
backend_flags["cupy_avail"] = True
except (ImportError, ImportWarning, ModuleNotFoundError) as e:
- warn(f"No 'CuPy' backend detected. ({e})")
+ ...
+ # LOG: warn(f"No 'CuPy' backend detected. ({e})")
try:
@@ -85,15 +95,43 @@
comm.Recv([array, MPI.FLOAT], source=0)
backend_flags["mpi_avail"] = True
+
if backend_flags["cupy_avail"] and os.environ.get("MPI_CUDA_AWARE", "0") == "1":
- backend_flags["mpi_cuda_aware"] = True
+ # If CuPy is available and CUDA-aware MPI is requested, check if it works.
+ try:
+ cupy_array = cupy.array([comm_rank], dtype=cupy.float32)
+ if comm_size > 1:
+ if comm_rank == 0:
+ comm.Send([cupy_array.data.ptr, MPI.FLOAT], dest=1)
+ elif comm_rank == 1:
+ comm.Recv([cupy_array.data.ptr, MPI.FLOAT], source=0)
+ # Only set to True if MPI communication with GPU data succeeded
+ backend_flags["mpi_cuda_aware"] = True
+ else:
+ # For single process, we can't test MPI communication but we can test GPU memory access
+ # Just accessing the GPU pointer suggests CUDA-aware MPI support is available
+ _ = cupy_array.data.ptr # This will fail if CUDA context is broken
+ backend_flags["mpi_cuda_aware"] = True
+ except Exception as e:
+ warn(f"CUDA-aware MPI test failed: {e}, CUDA-aware MPI will be disabled.")
+
if backend_flags["cupy_avail"] and os.environ.get("USE_NCCL", "0") == "1":
- backend_flags["nccl_avail"] = True
- else:
- backend_flags["nccl_avail"] = False
+ # If CuPy is available and NCCL is requested, check if NCCL is available.
+ try:
+ # Check if NCCL is available and functional
+ from cupy.cuda import nccl
+
+ nccl_id = nccl.get_unique_id()
+ backend_flags["nccl_avail"] = True
+ except (ImportError, ImportWarning, ModuleNotFoundError) as e:
+ warn(
+ f"'NCCL' backend requested but unavailable, NCCL will not be used. ({e})"
+ )
+ except (RuntimeError, OSError) as e:
+ warn(f"NCCL test failed: {e}, NCCL will not be used.")
except (ImportError, ImportWarning, ModuleNotFoundError) as e:
- warn(f"No 'MPI' backend detected. ({e})")
+ # LOG: warn(f"No 'MPI' backend detected. ({e})")
comm_rank = 0
comm_size = 1
diff --git a/src/dalia/configs/dalia_config.py b/src/dalia/configs/dalia_config.py
index 0dbbbec0..27f97aaf 100644
--- a/src/dalia/configs/dalia_config.py
+++ b/src/dalia/configs/dalia_config.py
@@ -20,15 +20,16 @@ class BFGSConfig(BaseModel):
max_iter: PositiveInt = 100
jac: bool = True
-
- maxcor: PositiveInt = 10 # maximum number of past gradient vectors to store -> good default: dim(theta)
- maxls: PositiveInt = 20 # maximum number of line search iterations
+
+ maxcor: PositiveInt = (
+ 10 # maximum number of past gradient vectors to store -> good default: dim(theta)
+ )
+ maxls: PositiveInt = 20 # maximum number of line search iterations
gtol: float = 1e-1
# c1: float = 1e-4 # only relevant for BFGS not for L-BFGS-B
# c2: float = 0.9 # only relevant for BFGS not for L-BFGS-B
disp: bool = False
-
class DaliaConfig(BaseModel):
@@ -41,7 +42,7 @@ class DaliaConfig(BaseModel):
# exit BFGS early if the reduction in the objective function is less than f_reduction_tol after f_reduction_lag iterations
f_reduction_lag: int = 3
f_reduction_tol: float = 1e-4
-
+
# exit BFGS early if the change in theta is less than theta_reduction_tol after theta_reduction_lag iterations
theta_reduction_lag: int = 3
theta_reduction_tol: float = 1e-4
@@ -55,6 +56,10 @@ class DaliaConfig(BaseModel):
simulation_dir: Path = Path("./dalia/")
output_dir: Path = Path.joinpath(simulation_dir, "output/")
+ # --- Verbosity level ------------------------------------------------------
+ verbosity: int = 0 # 0: minimal, 1: more info
+
+
def parse_config(config: dict | str) -> DaliaConfig:
if isinstance(config, str):
diff --git a/src/dalia/configs/likelihood_config.py b/src/dalia/configs/likelihood_config.py
index 74f48ecd..33abf64a 100644
--- a/src/dalia/configs/likelihood_config.py
+++ b/src/dalia/configs/likelihood_config.py
@@ -7,7 +7,9 @@
from pydantic import BaseModel, ConfigDict
from dalia.__init__ import ArrayLike, xp
-from dalia.configs.priorhyperparameters_config import PriorHyperparametersConfig
+from dalia.configs.priorhyperparameters_config import (
+ PriorHyperparametersConfig,
+)
from dalia.configs.priorhyperparameters_config import (
parse_config as parse_prior_hyperparameters_config,
)
diff --git a/src/dalia/configs/models_config.py b/src/dalia/configs/models_config.py
index e81fe010..77821e0f 100644
--- a/src/dalia/configs/models_config.py
+++ b/src/dalia/configs/models_config.py
@@ -4,18 +4,18 @@
from abc import ABC, abstractmethod
from typing import Literal
-from pydantic import BaseModel, ConfigDict, Field, PositiveInt
-from pydantic import model_validator
+from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator
from typing_extensions import Annotated
from dalia.__init__ import ArrayLike, xp
-from dalia.configs.priorhyperparameters_config import PriorHyperparametersConfig
+from dalia.configs.priorhyperparameters_config import (
+ PriorHyperparametersConfig,
+)
from dalia.configs.priorhyperparameters_config import (
parse_config as parse_priorhyperparameters_config,
)
-
class ModelConfig(BaseModel, ABC):
model_config = ConfigDict(extra="forbid")
@@ -35,28 +35,36 @@ class CoregionalModelConfig(ModelConfig):
ph_sigmas: list[PriorHyperparametersConfig] = None
ph_lambdas: list[PriorHyperparametersConfig] = None
- @model_validator(mode='after')
+ @model_validator(mode="after")
def check_n_models(self):
assert self.n_models == 2 or self.n_models == 3, "n_models must be 2 or 3"
return self
- @model_validator(mode='after')
+ @model_validator(mode="after")
def check_hyperparameters_length(self):
if self.n_models is not None:
if self.sigmas is not None and len(self.sigmas) != self.n_models:
- raise ValueError(f"Length of sigmas ({len(self.sigmas)}) does not match n_models ({self.n_models})")
- n_lambdas = self.n_models*(self.n_models-1)//2
+ raise ValueError(
+ f"Length of sigmas ({len(self.sigmas)}) does not match n_models ({self.n_models})"
+ )
+ n_lambdas = self.n_models * (self.n_models - 1) // 2
if self.lambdas is not None and len(self.lambdas) != n_lambdas:
- raise ValueError(f"Length of lambdas ({len(self.lambdas)}) does not match the required number of lambdas ({n_lambdas})")
+ raise ValueError(
+ f"Length of lambdas ({len(self.lambdas)}) does not match the required number of lambdas ({n_lambdas})"
+ )
return self
- @model_validator(mode='after')
+ @model_validator(mode="after")
def check_prior_hyperparameters_length(self):
if self.n_models is not None:
if self.sigmas is not None and len(self.ph_sigmas) != len(self.sigmas):
- raise ValueError(f"Length of sigmas prior hyperparameters ({len(self.ph_sigmas)}) does not match number of sigmas ({len(self.sigmas)})")
+ raise ValueError(
+ f"Length of sigmas prior hyperparameters ({len(self.ph_sigmas)}) does not match number of sigmas ({len(self.sigmas)})"
+ )
if self.lambdas is not None and len(self.ph_lambdas) != len(self.lambdas):
- raise ValueError(f"Length of lambdas prior hyperparameters ({len(self.ph_lambdas)}) does not match number of lambdas ({len(self.lambdas)})")
+ raise ValueError(
+ f"Length of lambdas prior hyperparameters ({len(self.ph_lambdas)}) does not match number of lambdas ({len(self.lambdas)})"
+ )
return self
def read_hyperparameters(self):
@@ -65,7 +73,7 @@ def read_hyperparameters(self):
for i in range(self.n_models):
theta_keys.append(f"sigma_{i}")
for i in range(self.n_models):
- for j in range(i+1, self.n_models):
+ for j in range(i + 1, self.n_models):
theta_keys.append(f"lambda_{i}_{j}")
return theta, theta_keys
@@ -79,10 +87,14 @@ def parse_config(config: dict | str) -> ModelConfig:
type = config.get("type")
if type == "coregional":
for i in range(len(config["ph_sigmas"])):
- config["ph_sigmas"][i] = parse_priorhyperparameters_config(config["ph_sigmas"][i])
+ config["ph_sigmas"][i] = parse_priorhyperparameters_config(
+ config["ph_sigmas"][i]
+ )
for i in range(len(config["ph_lambdas"])):
- config["ph_lambdas"][i] = parse_priorhyperparameters_config(config["ph_lambdas"][i])
+ config["ph_lambdas"][i] = parse_priorhyperparameters_config(
+ config["ph_lambdas"][i]
+ )
return CoregionalModelConfig(**config)
# Add more elif branches for other model types
else:
- raise ValueError(f"Invalid submodel type: {type}")
\ No newline at end of file
+ raise ValueError(f"Invalid submodel type: {type}")
diff --git a/src/dalia/configs/priorhyperparameters_config.py b/src/dalia/configs/priorhyperparameters_config.py
index ddbea90c..b1ccd933 100644
--- a/src/dalia/configs/priorhyperparameters_config.py
+++ b/src/dalia/configs/priorhyperparameters_config.py
@@ -3,17 +3,19 @@
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
+from scipy.sparse import spmatrix
from typing_extensions import Annotated
from dalia.__init__ import NDArray
-from scipy.sparse import spmatrix
# --- PRIOR HYPERPARAMETERS ----------------------------------------------------
class PriorHyperparametersConfig(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
- type: Literal["gaussian", "penalized_complexity", "beta", "gaussian_mvn"] = None
+ type: Literal[
+ "gaussian", "penalized_complexity", "beta", "gaussian_mvn", "gamma", "inverse_gamma"
+ ] = None
class GaussianPriorHyperparametersConfig(PriorHyperparametersConfig):
@@ -44,15 +46,27 @@ class BetaPriorHyperparametersConfig(PriorHyperparametersConfig):
beta: float = None
+class GammaPriorHyperparametersConfig(PriorHyperparametersConfig):
+ alpha: float = None
+ beta: float = None
+
+class InverseGammaPriorHyperparametersConfig(PriorHyperparametersConfig):
+ alpha: float = None
+ beta: float = None
+
+
def parse_config(config: dict) -> PriorHyperparametersConfig:
prior_type = config.get("type")
if prior_type == "gaussian":
return GaussianPriorHyperparametersConfig(**config)
- elif prior_type == "gaussian_mvn":
+ if prior_type == "gaussian_mvn":
return GaussianMVNPriorHyperparametersConfig(**config)
- elif prior_type == "penalized_complexity":
+ if prior_type == "penalized_complexity":
return PenalizedComplexityPriorHyperparametersConfig(**config)
- elif prior_type == "beta":
+ if prior_type == "beta":
return BetaPriorHyperparametersConfig(**config)
- else:
- raise ValueError(f"Unknown prior hyperparameters config type: {prior_type}")
+ if prior_type == "gamma":
+ return GammaPriorHyperparametersConfig(**config)
+ if prior_type == "inverse_gamma":
+ return InverseGammaPriorHyperparametersConfig(**config)
+ raise ValueError(f"Unknown prior hyperparameters config type: {prior_type}")
diff --git a/src/dalia/configs/submodels_config.py b/src/dalia/configs/submodels_config.py
index 87c7cf85..44dc4675 100644
--- a/src/dalia/configs/submodels_config.py
+++ b/src/dalia/configs/submodels_config.py
@@ -7,20 +7,22 @@
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
from typing_extensions import Annotated
-from dalia.__init__ import ArrayLike, NDArray, xp
-from dalia.configs.priorhyperparameters_config import PriorHyperparametersConfig, BetaPriorHyperparametersConfig, GaussianMVNPriorHyperparametersConfig
+from dalia.__init__ import ArrayLike, xp
+from dalia.configs.priorhyperparameters_config import (
+ BetaPriorHyperparametersConfig,
+ GaussianMVNPriorHyperparametersConfig,
+ PriorHyperparametersConfig,
+)
from dalia.configs.priorhyperparameters_config import (
parse_config as parse_priorhyperparameters_config,
)
-from dalia.utils import scaled_logit
-
class SubModelConfig(BaseModel, ABC):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
# Input folder for this specific submodel
input_dir: str = None
- type: Literal["spatio_temporal", "spatial", "regression", "brainiac"] = None
+ type: Literal["spatio_temporal", "spatial", "regression", "brainiac", "ar1"] = None
@abstractmethod
def read_hyperparameters(self) -> tuple[ArrayLike, list]: ...
@@ -34,6 +36,34 @@ def read_hyperparameters(self):
return xp.array([]), []
+class AR1SubModelConfig(SubModelConfig):
+
+ ## prior on phi
+ phi: float = None # AR(1) coefficient
+ phi_scaled: float = None
+ ph_phi: PriorHyperparametersConfig = None
+ ## check that phi is between -1 and 1 (use pc prior)
+ # check inla.doc("pc.cor1")
+
+ ## either define tau or sigma2
+ tau: float = None # Precision
+ # sigma2: float = None # Marginal variance
+
+
+ ph_tau: PriorHyperparametersConfig = None
+ # ph_sigma2: PriorHyperparametersConfig = None
+
+ def read_hyperparameters(self):
+
+ # input of phi is in (0,1), rescale to -/+ INF
+ #self.phi_scaled = scaled_logit(self.phi, direction="forward")
+ theta = xp.array([self.phi, self.tau])
+ #theta_internal = xp.array([self.phi, self.tau])
+ theta_keys = ["phi", "tau"]
+
+ return theta, theta_keys
+
+
class SpatioTemporalSubModelConfig(SubModelConfig):
spatial_domain_dimension: PositiveInt = 2
@@ -79,44 +109,42 @@ class BrainiacSubModelConfig(SubModelConfig):
# --- Hyperparameters ---
h2: float = None
h2_scaled: float = None
- alpha: NDArray = None
+ alpha: list[float] = None
# --- Prior hyperparameters ---
ph_h2: BetaPriorHyperparametersConfig = None
ph_alpha: GaussianMVNPriorHyperparametersConfig = None
def read_hyperparameters(self):
-
- # input of h2 is in (0,1), rescale to -/+ INF
- self.h2_scaled = scaled_logit(self.h2, direction="forward")
-
- theta = xp.concatenate(([self.h2_scaled], self.alpha))
+ theta = xp.concatenate([xp.array([self.h2]), xp.array(self.alpha)])
theta_keys = ["h2"] + [f"alpha_{i}" for i in range(len(self.alpha))]
return theta, theta_keys
+
def parse_config(config: dict | str) -> SubModelConfig:
if isinstance(config, str):
with open(config, "rb") as f:
config = tomllib.load(f)
-
- type = config.get("type")
- if type == "spatio_temporal":
+ model_type = config.get("type")
+ if model_type == "spatio_temporal":
config["ph_s"] = parse_priorhyperparameters_config(config["ph_s"])
config["ph_t"] = parse_priorhyperparameters_config(config["ph_t"])
config["ph_st"] = parse_priorhyperparameters_config(config["ph_st"])
return SpatioTemporalSubModelConfig(**config)
- elif type == "spatial":
+ if model_type == "spatial":
config["ph_s"] = parse_priorhyperparameters_config(config["ph_s"])
config["ph_e"] = parse_priorhyperparameters_config(config["ph_e"])
return SpatialSubModelConfig(**config)
- elif type == "regression":
+ if model_type == "regression":
return RegressionSubModelConfig(**config)
- elif type == "brainiac":
+ if model_type == "brainiac":
config["ph_h2"] = parse_priorhyperparameters_config(config["ph_h2"])
config["ph_alpha"] = parse_priorhyperparameters_config(config["ph_alpha"])
return BrainiacSubModelConfig(**config)
- # Add more elif branches for other submodel types
- else:
- raise ValueError(f"Unknown submodel type: {type}")
+ if model_type == "ar1":
+ config["ph_tau"] = parse_priorhyperparameters_config(config["ph_tau"])
+ config["ph_phi"] = parse_priorhyperparameters_config(config["ph_phi"])
+ return AR1SubModelConfig(**config)
+ raise ValueError(f"Unknown submodel type: {model_type}")
diff --git a/src/dalia/core/dalia.py b/src/dalia/core/dalia.py
index 8f52dca4..acddaf97 100644
--- a/src/dalia/core/dalia.py
+++ b/src/dalia/core/dalia.py
@@ -1,38 +1,43 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
import logging
-from tabulate import tabulate
from scipy import optimize
+from tabulate import tabulate
+import copy
-from dalia import ArrayLike, NDArray, backend_flags, comm_rank, comm_size, xp, sp
+from dalia import ArrayLike, NDArray, backend_flags, comm_rank, comm_size, sp, xp
from dalia.configs.dalia_config import DaliaConfig
from dalia.core.model import Model
from dalia.solvers import DenseSolver, DistSerinvSolver, SerinvSolver, SparseSolver
from dalia.utils import (
+ DummyCommunicator,
+ add_str_header,
allreduce,
+ ascii_logo,
+ boxify,
extract_diagonal,
+ format_size,
free_unused_gpu_memory,
get_device,
get_host,
+ memory_report,
print_msg,
set_device,
smartsplit,
synchronize,
synchronize_gpu,
- ascii_logo,
- add_str_header,
- boxify,
- memory_report,
- format_size,
- DummyCommunicator,
+ check_vector_consistency,
+ bcast,
)
if backend_flags["mpi_avail"]:
from mpi4py import MPI
-if backend_flags["cupy_avail"]:
- import cupy as cp
+if backend_flags["nccl_avail"]:
+ from cupy.cuda import nccl
+else:
+ nccl = None
import time
@@ -73,6 +78,8 @@ def __init__(
self.eps_gradient_f = self.config.eps_gradient_f
self.eps_hessian_f = self.config.eps_hessian_f
+ self.verbosity = self.config.verbosity
+
# --- Configure HPC
set_device(comm_rank, comm_size)
@@ -97,7 +104,7 @@ def __init__(
min_group_size=min_solver_size * min_q_parallel,
)
self.world_size = self.comm_world.size
-
+
self.qeval_world, self.comm_qeval, self.color_qeval = smartsplit(
comm=self.comm_feval,
n_parallelizable_evaluations=self.n_qeval,
@@ -160,7 +167,7 @@ def __init__(
raise ValueError(
f"Not enough diagonal blocks ({n_diag_blocks}) to use the distributed solver with {n_processes_solver} processes."
)
-
+
self.nccl_comm = None
if backend_flags["nccl_avail"]:
# --- Initialize NCCL communicator
@@ -169,12 +176,12 @@ def __init__(
f"rank {self.initial_comm_world.rank} initializing NCCL communicator.",
flush=True,
)
- nccl_id = cp.cuda.nccl.get_unique_id()
+ nccl_id = nccl.get_unique_id()
self.comm_qeval.bcast(nccl_id, root=0)
else:
nccl_id = self.comm_qeval.bcast(None, root=0)
- self.nccl_comm = cp.cuda.nccl.NcclCommunicator(
+ self.nccl_comm = nccl.NcclCommunicator(
self.comm_qeval.size,
nccl_id,
self.comm_qeval.rank,
@@ -198,24 +205,28 @@ def __init__(
dtype=xp.float64,
)
self.theta_mat = xp.zeros(
- (self.model.theta.size, self.n_f_evaluations), dtype=xp.float64
+ (self.model.theta_internal.size, self.n_f_evaluations), dtype=xp.float64
)
- self.theta_optimizer = xp.zeros_like(self.model.theta)
- self.theta_optimizer[:] = self.model.theta
+ self.theta_optimizer = xp.zeros_like(self.model.theta_internal)
+ self.theta_optimizer[:] = self.model.theta_internal
+ self.theta_star = None # mode not yet computed
+ self.theta_star_internal = None
+ self.x_star = None # mode not yet computed
+ self.cov_theta_internal = None # covariance not yet computed
# --- Metrics
self.f_values: ArrayLike = []
- self.theta_values: ArrayLike = []
+ self.theta_values_internal: ArrayLike = []
self.objective_function_time: ArrayLike = []
self.solver_time: ArrayLike = []
self.construction_time: ArrayLike = []
-
+ self.accepted_iter = 0
+
# --- Timers
self.t_construction_qprior = 0.0
self.t_construction_qconditional = 0.0
- self.solver.t_cholesky = 0.0
+ self.solver.t_factorize = 0.0
self.solver.t_solve = 0.0
-
self._print_init()
logging.info("DALIA initialized.")
@@ -232,9 +243,18 @@ def _print_init(self) -> None:
# Parallelization strategies header
parallel_strategies_values = [
- ["Participating Processes / Total Processes", f"{self.world_size} / {self.initial_comm_world.size}"],
- ["Parallelization through F()", f"{self.world_size // self.comm_feval.size}"],
- ["Parallelization through Q()", f"{self.comm_feval.size // self.comm_qeval.size}"],
+ [
+ "Participating Processes / Total Processes",
+ f"{self.world_size} / {self.initial_comm_world.size}",
+ ],
+ [
+ "Parallelization through F()",
+ f"{self.world_size // self.comm_feval.size}",
+ ],
+ [
+ "Parallelization through Q()",
+ f"{self.comm_feval.size // self.comm_qeval.size}",
+ ],
["Parallelization through S()", f"{self.comm_qeval.size}"],
]
parallel_strategies_table = tabulate(
@@ -250,10 +270,10 @@ def _print_init(self) -> None:
# HPC modules header
hpc_modules_values = [
- ["Array module", xp.__name__],
- ["MPI available", backend_flags["mpi_avail"]],
- ["Is MPI CUDA aware", backend_flags["mpi_cuda_aware"]],
- ["Is NCCL available", backend_flags["nccl_avail"]],
+ ["Array module", xp.__name__],
+ ["MPI available", backend_flags["mpi_avail"]],
+ ["Is MPI CUDA aware", backend_flags["mpi_cuda_aware"]],
+ ["Is NCCL available", backend_flags["nccl_avail"]],
]
hpc_modules_table = tabulate(
hpc_modules_values,
@@ -269,9 +289,9 @@ def _print_init(self) -> None:
# Memory usage header
used_memory, available_memory = memory_report()
memory_usage_values = [
- ["Solver memory", format_size(self.solver.get_solver_memory())],
- ["Total memory used", format_size(used_memory)],
- ["Total memory available", format_size(available_memory)],
+ ["Solver memory", format_size(self.solver.get_solver_memory())],
+ ["Total memory used", format_size(used_memory)],
+ ["Total memory available", format_size(available_memory)],
]
memory_usage_table = tabulate(
memory_usage_values,
@@ -285,25 +305,35 @@ def _print_init(self) -> None:
str_representation += "\n" + boxify(memory_usage_table)
print_msg(str_representation, flush=True)
-
-
def run(self) -> dict:
"""Run the DALIA"""
+ synchronize(comm=self.comm_world)
+ tic = time.perf_counter()
# compute mode of the hyperparameters theta
minimization_result = self.minimize()
- theta_star = get_device(minimization_result["theta"])
- x_star = get_device(minimization_result["x"])
+ self.theta_star = minimization_result["theta"]
+ self.theta_star_internal = minimization_result["theta_internal"]
+ self.x_star = minimization_result["x"]
+
+ print("Finished the optimization procedure.")
+
+ # need to update theta_star and x_star to be the same across all ranks
+ bcast(data=self.theta_star[:], root=0, comm=self.comm_world)
+ bcast(data=self.theta_star_internal[:], root=0, comm=self.comm_world)
+ bcast(data=self.x_star[:], root=0, comm=self.comm_world)
# compute covariance of the hyperparameters theta at the mode
- cov_theta = self.compute_covariance_hp(theta_star)
+ self.cov_theta_internal = self.compute_covariance_hp(self.theta_star)
+ print("Computed covariance of the hyperparameters at the mode.")
# compute marginal variances of the latent parameters
marginal_variances_latent = self.get_marginal_variances_latent_parameters(
- theta_star, x_star
+ self.theta_star, self.x_star
)
+ print_msg("Computed marginal variances of the latent parameters.")
# compute marginal variances of the observations
# TODO: only run by default when dense multiplcation issue is fixed, see issue #78
@@ -314,18 +344,22 @@ def run(self) -> dict:
# construct new dictionary with the results
results = {
"theta": minimization_result["theta"],
+ "theta_internal": minimization_result["theta_internal"],
"x": minimization_result["x"],
"f": minimization_result["f"],
"grad_f": minimization_result["grad_f"],
"f_values": minimization_result["f_values"],
"theta_values": minimization_result["theta_values"],
- "cov_theta": cov_theta,
+ "cov_theta_internal": self.cov_theta_internal,
"marginal_variances_latent": marginal_variances_latent,
+ "optimization_iterations": self.accepted_iter,
# "marginal_variances_observations": get_host(
# marginal_variances_observations
# ),
}
-
+ synchronize(comm=self.comm_world)
+ toc = time.perf_counter()
+ print_msg(f"DALIA inference took: {toc - tic:0.4f} (s)", flush=True)
return results
def minimize(self) -> optimize.OptimizeResult:
@@ -340,17 +374,29 @@ def minimize(self) -> optimize.OptimizeResult:
minimization_result : scipy.optimize.OptimizeResult
Result of the optimization procedure.
"""
+ # Ensure that all ranks are initialized to the same theta
+ check_vector_consistency(
+ value=self.model.theta_external,
+ comm=self.comm_world,
+ flag="self.model.theta_external",
+ verbose="Full",
+ )
- if len(self.model.theta) == 0:
+ if len(self.model.theta_external) == 0:
# Only run the inner iteration
print_msg("No hyperparameters, just running inner iteration.")
- self.f_value = self._evaluate_f(self.model.theta)
-
+ self.f_value = self._evaluate_f(self.model.theta_external)
self.minimization_result: dict = {
- "theta": self.model.theta,
- "x": self.model.x, # [self.model.inverse_permutation_latent_variables],
- "f": self.f_value,
+ "theta_internal": copy.deepcopy(self.model.theta_internal),
+ "theta": copy.deepcopy(self.model.theta_external),
+ "x": copy.deepcopy(self.model.x), # [self.model.inverse_permutation_latent_variables],
+ "f": copy.deepcopy(self.f_value),
+ "grad_f": [],
+ "f_values": [],
+ ### these values are in internal scale (!!)
+ "theta_values": [],
}
+
else:
print_msg("Starting optimization.")
self.iter = 0
@@ -377,12 +423,12 @@ def callback(intermediate_result: optimize.OptimizeResult):
f"Iteration: {self.accepted_iter:2d} (took: {self.objective_function_time[-1]:.2f}) | "
f"Theta: [{theta_str}] | "
f"Function Value: {fun_i: .6f} | "
- f"Gradient: [{gradient_str}] | ",
+ #f"Gradient: [{gradient_str}] | ",
f"Norm(Grad): [{xp.linalg.norm(self.gradient_f): .6f}]",
flush=True,
)
- self.theta_values.append(theta_i)
+ self.theta_values_internal.append(theta_i)
self.f_values.append(fun_i)
# check if f_values have been decreasing over last iterations
@@ -400,17 +446,15 @@ def callback(intermediate_result: optimize.OptimizeResult):
)
self.minimization_result = {
- "theta": get_host(self.model.theta),
- "x": get_host(
- self.model.x
- # self.model.x[
- # self.model.inverse_permutation_latent_variables
- # ]
- ),
- "f": fun_i,
- "grad_f": self.gradient_f,
+ "theta_internal": copy.deepcopy(self.model.theta_internal),
+ "theta":
+ copy.deepcopy(self.model.theta_external),
+ "x": copy.deepcopy(self.model.x),
+ "f": copy.deepcopy(fun_i),
+ "grad_f": copy.deepcopy(self.gradient_f),
"f_values": self.f_values,
- "theta_values": self.theta_values,
+ ### these values are in internal scale (!!)
+ "theta_values": self.theta_values_internal,
}
raise OptimizationConvergedEarlyExit()
@@ -418,13 +462,13 @@ def callback(intermediate_result: optimize.OptimizeResult):
if self.accepted_iter > self.config.theta_reduction_lag:
if (
xp.linalg.norm(
- self.theta_values[-self.config.theta_reduction_lag]
+ self.theta_values_internal[-self.config.theta_reduction_lag]
- theta_i
)
< self.config.theta_reduction_tol
):
norm_diff = xp.linalg.norm(
- self.theta_values[
+ self.theta_values_internal[
self.accepted_iter - self.config.theta_reduction_lag
]
- theta_i
@@ -438,17 +482,19 @@ def callback(intermediate_result: optimize.OptimizeResult):
)
self.minimization_result = {
- "theta": get_host(self.model.theta),
+ "theta_internal": get_host(self.model.theta_internal),
+ "theta": get_host(
+ self.model.theta_external),
"x": get_host(
self.model.x
# self.model.x[
# self.model.inverse_permutation_latent_variables
# ]
),
- "f": fun_i,
- "grad_f": self.gradient_f,
- "f_values": self.f_values,
- "theta_values": self.theta_values,
+ "f": copy.deepcopy(fun_i),
+ "grad_f": copy.deepcopy(self.gradient_f),
+ "f_values": copy.deepcopy(self.f_values),
+ "theta_values": copy.deepcopy(self.theta_values_internal),
}
raise OptimizationConvergedEarlyExit()
@@ -498,14 +544,13 @@ def callback(intermediate_result: optimize.OptimizeResult):
)
self.minimization_result: dict = {
- "theta": scipy_result.x,
- "x": get_host(
- self.model.x, # [self.model.inverse_permutation_latent_variables]
- ),
- "f": scipy_result.fun,
- "grad_f": self.gradient_f,
- "f_values": self.f_values,
- "theta_values": self.theta_values,
+ "theta_internal": copy.deepcopy(self.model.theta_internal), # scipy_result.x, #
+ "theta": copy.deepcopy(self.model.theta_external),
+ "x": copy.deepcopy(self.model.x), # [self.model.inverse_permutation_latent_variables]
+ "f": copy.deepcopy(scipy_result.fun),
+ "grad_f": copy.deepcopy(self.gradient_f),
+ "f_values": copy.deepcopy(self.f_values),
+ "theta_values": copy.deepcopy(self.theta_values_internal),
}
return self.minimization_result
@@ -529,7 +574,7 @@ def _objective_function(
self.t_construction_qprior = 0.0
self.t_construction_qconditional = 0.0
- self.solver.t_cholesky = 0.0
+ self.solver.t_factorize = 0.0
self.solver.t_solve = 0.0
synchronize(comm=self.comm_world)
@@ -586,12 +631,12 @@ def _objective_function(
synchronize(comm=self.comm_world)
toc = time.perf_counter()
self.objective_function_time.append(toc - tic)
- self.solver_time.append(self.solver.t_cholesky + self.solver.t_solve)
+ self.solver_time.append(self.solver.t_factorize + self.solver.t_solve)
self.construction_time.append(
self.t_construction_qprior + self.t_construction_qconditional
)
- if self.iter > 0:
+ if self.iter > 0 and self.verbosity > 0:
print(
f"rank {comm_rank} | objfunc_time: {self.objective_function_time[1:]} | solver_time: {self.solver_time[1:]} | construction_time: {self.construction_time[1:]}",
flush=True,
@@ -624,14 +669,19 @@ def _evaluate_f(
hyperparameters, log likelihood, log prior of the latent parameters,
and log conditional of the latent parameters.
"""
- self.model.theta[:] = theta_i
+ import time
+
+ tic = time.time()
+
+ # self.model.theta_internal[:] = theta_i
+ self.model.theta_internal = theta_i
f_theta = xp.zeros(1, dtype=xp.float64)
# --- Optimize x and evaluate the conditional of the latent parameters
if self.model.is_likelihood_gaussian():
# Done by both processes
- tic = time.perf_counter()
synchronize_gpu()
+ tic = time.perf_counter()
self.model.construct_Q_prior()
synchronize_gpu()
toc = time.perf_counter()
@@ -644,15 +694,16 @@ def _evaluate_f(
task_mapping = [i % n_qeval_comm for i in range(2)]
if task_mapping[0] == self.color_qeval:
+
# Done by processes "even"
- tic = time.perf_counter()
synchronize_gpu()
+ tic = time.perf_counter()
Q_conditional = self.model.construct_Q_conditional(eta)
synchronize_gpu()
toc = time.perf_counter()
self.t_construction_qconditional += toc - tic
- self.solver.cholesky(A=Q_conditional, sparsity="bta")
+ self.solver.factorize(A=Q_conditional, sparsity="bta")
rhs: NDArray = self.model.construct_information_vector(
eta,
@@ -678,7 +729,8 @@ def _evaluate_f(
log_prior_hyperparameters: float = (
self.model.evaluate_log_prior_hyperparameters()
)
- likelihood: float = float(self.model.evaluate_likelihood(eta=eta))
+
+ likelihood: float = self.model.evaluate_likelihood(eta=eta)
prior_latent_parameters: float = (
self._evaluate_prior_latent_parameters()
)
@@ -696,8 +748,14 @@ def _evaluate_f(
comm=self.comm_feval,
)
synchronize(comm=self.comm_qeval)
+
else:
+ synchronize_gpu()
+ tic = time.perf_counter()
self.model.construct_Q_prior()
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_construction_qprior += toc - tic
log_prior_hyperparameters: float = (
self.model.evaluate_log_prior_hyperparameters()
@@ -744,7 +802,7 @@ def _evaluate_f(
return f_theta[0]
- def compute_covariance_hp(self, theta_i: NDArray) -> NDArray:
+ def compute_covariance_hp(self, theta_external: NDArray) -> NDArray:
"""compute the covariance matrix of the hyperparameters theta.
Parameters
@@ -757,16 +815,44 @@ def compute_covariance_hp(self, theta_i: NDArray) -> NDArray:
cov_theta : NDArray[dim_theta, dim_theta]
Covariance matrix of the hyperparameters theta.
"""
- self.model.theta[:] = theta_i
- hess_theta = self._evaluate_hessian_f(theta_i)
- cov_theta = xp.linalg.inv(hess_theta)
+ # self.model.rescale_hyperparameters_to_internal(theta_interpret, direction="forward")
+ # ensure that all ranks are initialized to the same theta
+ check_vector_consistency(
+ theta_external,
+ comm=self.comm_world,
+ flag="theta_external",
+ verbose="Full",
+ )
+ print_msg(
+ f"Computing covariance of hyperparameters at theta_external {theta_external}.",
+ flush=True,
+ )
+
+ synchronize(comm=self.comm_world)
+ tic = time.perf_counter()
+ self.model.theta_external = theta_external
+
+ hess_theta_internal = self._evaluate_hessian_f(self.model.theta_internal)
+ print_msg(
+ f"hessian_f: \n {hess_theta_internal}",
+ flush=True,
+ )
+ cov_theta_internal = xp.linalg.inv(hess_theta_internal)
- return cov_theta
+ synchronize(comm=self.comm_world)
+ toc = time.perf_counter()
+ print_msg(
+ "Time to compute covariance of hyperparameters:",
+ toc - tic,
+ flush=True,
+ )
+
+ return cov_theta_internal
def _evaluate_hessian_f(
self,
- theta_i: NDArray,
+ theta_internal: NDArray,
) -> NDArray:
"""Approximate the hessian of the function f(theta) = log(p(theta|y)).
@@ -784,7 +870,8 @@ def _evaluate_hessian_f(
Compute finite difference approximation of the hessian of f at theta_i.
"""
- self.model.theta[:] = theta_i
+ ## TODO: this is the quick fix ...
+ # self.model.theta[:] = theta_i
dim_theta = self.model.n_hyperparameters
# pre-allocate storage for the hessian & f_values
@@ -816,8 +903,10 @@ def _evaluate_hessian_f(
counter = 0
# compute f(theta)
if self.color_feval == task_mapping[0]:
+ theta_i = theta_internal.copy()
f_theta = self._evaluate_f(theta_i)
f_ii_loc[1, :] = f_theta
+
counter += 1
for k in range(loop_dim):
@@ -828,46 +917,62 @@ def _evaluate_hessian_f(
if i == j:
if self.color_feval == task_mapping[counter]:
# theta+eps_i
- f_ii_loc[0, i] = self._evaluate_f(
- theta_i + eps_mat[i, :]
- )
+ # theta_i = theta_internal.copy()
+ # f_ii_loc[0, i] = self._evaluate_f(theta_i + eps_mat[i, :])
+ theta_i = theta_internal + eps_mat[i, :]
+ result = self._evaluate_f(theta_i)
+ f_ii_loc[0, i] = result
counter += 1
if self.color_feval == task_mapping[counter]:
# theta-eps_i
- f_ii_loc[2, i] = self._evaluate_f(
- theta_i - eps_mat[i, :]
- )
+ # theta_i = theta_internal.copy()
+ # f_ii_loc[2, i] = self._evaluate_f(theta_i - eps_mat[i, :])
+ theta_i = theta_internal - eps_mat[i, :]
+ result = self._evaluate_f(theta_i)
+ f_ii_loc[2, i] = result
counter += 1
# as hessian is symmetric we only have to compute the upper triangle
elif i < j:
# theta+eps_i+eps_j
if self.color_feval == task_mapping[counter]:
- f_ij_loc[0, k] = self._evaluate_f(
- theta_i + eps_mat[i, :] + eps_mat[j, :]
- )
+ # theta_i = theta_internal.copy()
+ # f_ij_loc[0, k] = self._evaluate_f(
+ # theta_i + eps_mat[i, :] + eps_mat[j, :]
+ # )
+ theta_i = theta_internal + eps_mat[i, :] + eps_mat[j, :]
+ f_ij_loc[0, k] = self._evaluate_f(theta_i)
counter += 1
# theta+eps_i-eps_j
if self.color_feval == task_mapping[counter]:
- f_ij_loc[1, k] = self._evaluate_f(
- theta_i + eps_mat[i, :] - eps_mat[j, :]
- )
+ # theta_i = theta_internal.copy()
+ # f_ij_loc[1, k] = self._evaluate_f(
+ # theta_i + eps_mat[i, :] - eps_mat[j, :]
+ # )
+ theta_i = theta_internal + eps_mat[i, :] - eps_mat[j, :]
+ f_ij_loc[1, k] = self._evaluate_f(theta_i)
counter += 1
# theta-eps_i+eps_j
if self.color_feval == task_mapping[counter]:
- f_ij_loc[2, k] = self._evaluate_f(
- theta_i - eps_mat[i, :] + eps_mat[j, :]
- )
+ # theta_i = theta_internal.copy()
+ # f_ij_loc[2, k] = self._evaluate_f(
+ # theta_i - eps_mat[i, :] + eps_mat[j, :]
+ # )
+ theta_i = theta_internal - eps_mat[i, :] + eps_mat[j, :]
+ f_ij_loc[2, k] = self._evaluate_f(theta_i)
counter += 1
# theta-eps_i-eps_j
if self.color_feval == task_mapping[counter]:
- f_ij_loc[3, k] = self._evaluate_f(
- theta_i - eps_mat[i, :] - eps_mat[j, :]
- )
+ # theta_i = theta_internal.copy()
+ # f_ij_loc[3, k] = self._evaluate_f(
+ # theta_i - eps_mat[i, :] - eps_mat[j, :]
+ # )
+ theta_i = theta_internal - eps_mat[i, :] - eps_mat[j, :]
+ f_ij_loc[3, k] = self._evaluate_f(theta_i)
counter += 1
allreduce(
@@ -909,8 +1014,178 @@ def _evaluate_hessian_f(
return hess
+ def marginal_distributions_hp(self,
+ #quantiles: NDArray = xp.array([0.0001, 0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975, 0.9999])
+ quantiles: NDArray = xp.array([0.025, 0.25, 0.5, 0.75, 0.975])
+ ) -> dict:
+ """Compute the marginal distributions of the hyperparameters theta.
+
+ Parameters
+ ----------
+ quantiles : NDArray
+ Quantiles to compute. If not provided, default quantiles are used. If None, no quantiles are computed.
+
+ Returns
+ -------
+ dict
+ Dictionary containing the marginal distributions of the hyperparameters theta and possibly quantiles / percentiles.
+
+ """
+
+ # check that theta_star and covariance matrix are computed
+ if self.theta_star is None or self.cov_theta_internal is None or self.x_star is None:
+ raise ValueError("theta_star, x_star and covariance matrix of the hyperparameters must be computed before calling marginal_distributions_hp(). Please run the full DALIA pipeline or set them manually.")
+
+ # set up dictionary to store results
+ results = {
+ 'hyperparameters': {},
+ 'summary': {
+ 'n_params': self.model.n_hyperparameters,
+ 'param_names': self.model.theta_keys,
+ 'quantile_levels': quantiles.tolist() if quantiles is not None else None
+ }
+ }
+
+ # Import necessary functions
+ from dalia.utils.gaussian_quadrature import compute_variance_gauss_hermite
+ from dalia.utils.reparametrizations import compute_bounds, compute_transformed_pdf, compute_transformed_quantiles
+ from dalia.prior_hyperparameters import GaussianMVNPriorHyperparameters
+
+ hp_offset = 0
+ for i, prior in enumerate(self.model.prior_hyperparameters):
+ if isinstance(prior, GaussianMVNPriorHyperparameters):
+ n_hp_for_this_prior = prior.mean.shape[0]
+ else:
+ n_hp_for_this_prior = 1
+
+ for j in range(0, n_hp_for_this_prior, 1):
+ param_name = results['summary']['param_names'][i+hp_offset+j]
+
+ # Extract marginal parameters for this hyperparameter
+ theta_internal_i = self.theta_star_internal[i+hp_offset+j]
+ marg_var_internal_i = self.cov_theta_internal[
+ i + hp_offset + j, i + hp_offset + j
+ ]
+
+ # compute external_mean and external_var using
+ # compute_variance_gauss_hermite(mean_internal, variance_internal, transform, n_points=20): from utils gaussian quadrature
+ gauss_hermite_result = compute_variance_gauss_hermite(
+ theta_internal_i, marg_var_internal_i, prior.rescale_hyperparameters_to_internal, n_points=30
+ )
+
+ # compute bounds for theta intervals using compute_bounds() from utils
+ (theta_internal_lower, theta_internal_upper), (theta_external_lower, theta_external_upper) = compute_bounds(
+ theta_internal_i, marg_var_internal_i, prior.rescale_hyperparameters_to_internal, n_std=4
+ )
+
+ # set theta_internal_interval
+ theta_internal_interval = xp.linspace(theta_internal_lower, theta_internal_upper, num=100)
+
+ # Compute PDF values in external scale
+ theta_external_interval, pdf_external = compute_transformed_pdf(theta_internal_i, marg_var_internal_i, theta_internal_interval, prior.rescale_hyperparameters_to_internal)
+
+ # Initialize parameter dictionary
+ param_dict = {
+ 'mean_internal': float(get_host(theta_internal_i)),
+ 'variance_internal': float(get_host(marg_var_internal_i)),
+ 'mean_external': float(get_host(gauss_hermite_result['mean'])),
+ 'variance_external': float(get_host(gauss_hermite_result['variance'])),
+ 'pdf_data': (get_host(theta_external_interval), get_host(pdf_external)) # tuple of xp arrays
+ }
+
+ # if quantiles is not None, compute quantiles using compute_transformed_quantiles()
+ if quantiles is not None:
+ quantiles_external = compute_transformed_quantiles(
+ theta_internal_i, marg_var_internal_i, quantiles, prior.rescale_hyperparameters_to_internal
+ )
+
+ # Also compute internal quantiles for completeness
+ from scipy.stats import norm
+ quantiles_internal = get_device(norm.ppf(get_host(quantiles), loc=get_host(theta_internal_i), scale=get_host(xp.sqrt(marg_var_internal_i))))
+
+ param_dict['quantiles'] = {
+ 'levels': get_host(quantiles).tolist(),
+ 'internal': {
+ 'values': get_host(quantiles_internal).tolist(),
+ 'pairs': list(zip(get_host(quantiles).tolist(), get_host(quantiles_internal).tolist()))
+ },
+ 'external': {
+ 'values': get_host(quantiles_external).tolist(),
+ 'pairs': list(zip(get_host(quantiles).tolist(), get_host(quantiles_external).tolist()))
+ }
+ }
+
+ # Store in main results dictionary
+ results['hyperparameters'][param_name] = param_dict
+
+ hp_offset += n_hp_for_this_prior-1
+
+ # Old code
+ if False:
+ # iterate over all hyperparameters and store outputs in a dictionary
+ for i in range(self.model.n_hyperparameters):
+ param_name = results['summary']['param_names'][i]
+
+ # Extract marginal parameters for this hyperparameter
+ theta_internal_i = self.theta_star_internal[i]
+ marg_var_internal_i = self.cov_theta_internal[i, i]
+
+ # compute external_mean and external_var using
+ # compute_variance_gauss_hermite(mean_internal, variance_internal, transform, n_points=20): from utils gaussian quadrature
+ gauss_hermite_result = compute_variance_gauss_hermite(
+ theta_internal_i, marg_var_internal_i, self.model.prior_hyperparameters[i].rescale_hyperparameters_to_internal, n_points=30
+ )
+
+ # compute bounds for theta intervals using compute_bounds() from utils
+ (theta_internal_lower, theta_internal_upper), (theta_external_lower, theta_external_upper) = compute_bounds(
+ theta_internal_i, marg_var_internal_i, self.model.prior_hyperparameters[i].rescale_hyperparameters_to_internal, n_std=4
+ )
+
+ # set theta_internal_interval
+ theta_internal_interval = xp.linspace(theta_internal_lower, theta_internal_upper, num=100)
+
+ # Compute PDF values in external scale
+ theta_external_interval, pdf_external = compute_transformed_pdf(theta_internal_i, marg_var_internal_i, theta_internal_interval, self.model.prior_hyperparameters[i].rescale_hyperparameters_to_internal)
+
+ # Initialize parameter dictionary
+ param_dict = {
+ 'mean_internal': float(get_host(theta_internal_i)),
+ 'variance_internal': float(get_host(marg_var_internal_i)),
+ 'mean_external': float(get_host(gauss_hermite_result['mean'])),
+ 'variance_external': float(get_host(gauss_hermite_result['variance'])),
+ 'pdf_data': (get_host(theta_external_interval), get_host(pdf_external)) # tuple of xp arrays
+ }
+
+ # if quantiles is not None, compute quantiles using compute_transformed_quantiles()
+ if quantiles is not None:
+ quantiles_external = compute_transformed_quantiles(
+ theta_internal_i, marg_var_internal_i, quantiles, self.model.prior_hyperparameters[i].rescale_hyperparameters_to_internal
+ )
+
+ # Also compute internal quantiles for completeness
+ from scipy.stats import norm
+ quantiles_internal = get_device(norm.ppf(get_host(quantiles), loc=get_host(theta_internal_i), scale=get_host(xp.sqrt(marg_var_internal_i))))
+
+ param_dict['quantiles'] = {
+ 'levels': get_host(quantiles).tolist(),
+ 'internal': {
+ 'values': get_host(quantiles_internal).tolist(),
+ 'pairs': list(zip(get_host(quantiles).tolist(), get_host(quantiles_internal).tolist()))
+ },
+ 'external': {
+ 'values': get_host(quantiles_external).tolist(),
+ 'pairs': list(zip(get_host(quantiles).tolist(), get_host(quantiles_external).tolist()))
+ }
+ }
+
+ # Store in main results dictionary
+ results['hyperparameters'][param_name] = param_dict
+
+ # return dictionary
+ return results
+
def _compute_covariance_latent_parameters(
- self, theta: NDArray, x_star: NDArray
+ self, theta_internal: NDArray, x_star: NDArray
) -> None:
"""Compute the marginal distribution of the latent parameters x.
@@ -926,30 +1201,46 @@ def _compute_covariance_latent_parameters(
marginal_latent_parameters : NDArray
Marginal distribution of the latent parameters x.
"""
- self.model.theta[:] = theta
+
+ self.model.theta_internal = xp.atleast_1d(theta_internal)
self.model.x[:] = x_star
eta = self.model.a @ self.model.x
-
+
+ synchronize_gpu()
+ tic = time.perf_counter()
self.model.construct_Q_conditional(eta)
- self.solver.cholesky(self.model.Q_conditional, sparsity="bta")
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_construction_qconditional += toc - tic
+
+ self.solver.factorize(self.model.Q_conditional, sparsity="bta")
self.solver.selected_inversion(sparsity="bta")
def get_marginal_variances_latent_parameters(
- self, theta: NDArray = None, x_star: NDArray = None
+ self, theta_external: NDArray = None, x_star: NDArray = None
) -> NDArray:
+
# TODO: this should be only called by rank 0?
- if theta is None and x_star is None:
+ if theta_external is None and x_star is None:
print(
"Computing marginal variances for currently stored latent parameters. "
)
x_star = self.model.x
- theta = self.model.theta
+ theta = self.model.theta_internal
+ elif theta_external is not None and x_star is not None:
+ ## assume theta to be in "external" scale
+ self.model.theta_external = xp.atleast_1d(theta_external)
+ theta = self.model.theta_internal
+
elif theta is None or x_star is None:
raise ValueError(
"BOTH or NEITHER theta and x_star must be provided to compute the marginal variances."
)
-
+
+ check_vector_consistency(theta, comm=self.comm_world, flag="theta", verbose="Full")
+ check_vector_consistency(x_star, comm=self.comm_world, flag="x_star", verbose="Minimal")
+
# check order x_star ... -> potentially need to reorder marginal variances
self._compute_covariance_latent_parameters(theta, x_star)
@@ -958,11 +1249,13 @@ def get_marginal_variances_latent_parameters(
sp.sparse.eye(self.model.n_latent_parameters, dtype=xp.float64),
sparsity="bta",
)
+
marginal_variances = extract_diagonal(marginal_variances_sp)
+
return marginal_variances
def get_marginal_variances_observations(
- self, theta: NDArray, x_star: NDArray
+ self, theta_external: NDArray = None, x_star: NDArray = None
) -> NDArray:
"""Extract the marginal variances of the observations.
@@ -985,21 +1278,26 @@ def get_marginal_variances_observations(
Marginal variances of the observations.
"""
+ # TODO: implement this for non-Gaussian likelihoods
+ check_vector_consistency(theta_external, comm=self.comm_world, flag="theta_external", verbose="Full")
+ check_vector_consistency(x_star, comm=self.comm_world, flag="x_star", verbose="Minimal")
+
if self.model.is_likelihood_gaussian():
# TODO: this should be only called by rank 0?
- if theta is None and x_star is None:
+ if theta_external is None and x_star is None:
print(
"Computing marginal variances for currently stored latent parameters. "
)
x_star = self.model.x
- theta = self.model.theta
- elif theta is None or x_star is None:
+ theta_external = self.model.theta_external
+
+ if theta_external is None or x_star is None:
raise ValueError(
"BOTH or NEITHER theta and x_star must be provided to compute the marginal variances."
)
# check order x_star ... -> potentially need to reorder marginal variances
- self._compute_covariance_latent_parameters(theta, x_star)
+ self._compute_covariance_latent_parameters(theta_external, x_star)
# now only extract diagonal elements corresponding to marginal variances of the latent parameters
variances_latent = self.solver._structured_to_spmatrix(
@@ -1015,10 +1313,9 @@ def get_marginal_variances_observations(
return marginal_variances_observations
- else:
- raise NotImplementedError(
- "in compute marginals observations: Only Gaussian likelihood is currently supported."
- )
+ raise NotImplementedError(
+ "in compute marginals observations: Only Gaussian likelihood is currently supported."
+ )
def _inner_iteration(
self,
@@ -1044,7 +1341,7 @@ def _inner_iteration(
if counter > self.inner_iteration_max_iter:
print_msg(
"Theta value at failing of the inner_iteration: ",
- self.model.theta,
+ self.model.theta_internal,
flush=True,
)
raise ValueError(
@@ -1054,8 +1351,14 @@ def _inner_iteration(
x_star[:] += x_update
eta[:] = self.model.a @ x_star
+ synchronize_gpu()
+ tic = time.perf_counter()
Q_conditional = self.model.construct_Q_conditional(eta)
- self.solver.cholesky(A=Q_conditional)
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_construction_qconditional += toc - tic
+
+ self.solver.factorize(A=Q_conditional, sparsity="bta")
rhs: NDArray = self.model.construct_information_vector(
eta,
@@ -1097,7 +1400,7 @@ def _evaluate_prior_latent_parameters(
Log normal:
.. math:: 0.5*log(1/(2*pi)^n * |Q_prior|)) - 0.5 * x.T Q_prior x
"""
- self.solver.cholesky(self.model.Q_prior, sparsity="bt")
+ self.solver.factorize(self.model.Q_prior, sparsity="bt")
logdet_Q_prior: float = self.solver.logdet(sparsity="bt")
log_prior_latent_parameters: float = +0.5 * logdet_Q_prior
@@ -1146,7 +1449,7 @@ def _evaluate_conditional_latent_parameters(
# the else fails if x_mean is None
else:
# Symmetrizing (averaging the tip of the arrow to tame down numerical innaccuracies)
- tip_accu = x_mean[-self.model.total_number_fixed_effects():].copy()
+ tip_accu = x_mean[-self.model.total_number_fixed_effects() :].copy()
synchronize(comm=self.comm_qeval)
allreduce(
tip_accu,
@@ -1155,7 +1458,7 @@ def _evaluate_conditional_latent_parameters(
comm=self.comm_qeval,
)
synchronize(comm=self.comm_qeval)
- x_mean[-self.model.total_number_fixed_effects():] = tip_accu
+ x_mean[-self.model.total_number_fixed_effects() :] = tip_accu
if x is None and x_mean is not None:
quadratic_form = x_mean.T @ Q_conditional @ x_mean
diff --git a/src/dalia/core/likelihood.py b/src/dalia/core/likelihood.py
index 27e13adb..e9a8f632 100644
--- a/src/dalia/core/likelihood.py
+++ b/src/dalia/core/likelihood.py
@@ -37,11 +37,14 @@ def evaluate_likelihood(
**kwargs : optional
Hyperparameters for likelihood.
-
Returns
-------
likelihood : float
Likelihood.
+
+ Implementation Notes:
+ ---------------------
+ - This function does not guarantee that the likelihood is a scalar. If evaluated from numpy/cupy dot product it will be a ndarray of shape (1,).
"""
pass
diff --git a/src/dalia/core/model.py b/src/dalia/core/model.py
index e0f432a9..01b9e6d6 100644
--- a/src/dalia/core/model.py
+++ b/src/dalia/core/model.py
@@ -3,9 +3,9 @@
import os
from abc import ABC
from pathlib import Path
-from tabulate import tabulate
import numpy as np
+from tabulate import tabulate
from dalia import ArrayLike, NDArray, sp, xp
from dalia.configs.likelihood_config import LikelihoodConfig
@@ -14,6 +14,7 @@
GaussianMVNPriorHyperparametersConfig,
GaussianPriorHyperparametersConfig,
PenalizedComplexityPriorHyperparametersConfig,
+ GammaPriorHyperparametersConfig,
)
from dalia.core.likelihood import Likelihood
from dalia.core.prior_hyperparameters import PriorHyperparameters
@@ -24,14 +25,17 @@
GaussianMVNPriorHyperparameters,
GaussianPriorHyperparameters,
PenalizedComplexityPriorHyperparameters,
+ GammaPriorHyperparameters,
)
from dalia.submodels import (
BrainiacSubModel,
RegressionSubModel,
SpatialSubModel,
SpatioTemporalSubModel,
+ AR1SubModel,
)
-from dalia.utils import scaled_logit, add_str_header, boxify
+from dalia.utils import add_str_header, boxify, scaled_logit
+from dalia.utils.scalar_ndarray import ensure_scalar
class Model(ABC):
@@ -56,8 +60,11 @@ def __init__(
self.n_fixed_effects: int = 0
+ self._theta_external: ArrayLike = []
+ self._theta_internal: ArrayLike = []
+
# For each submodel...
- theta: ArrayLike = []
+ theta_external: ArrayLike = []
theta_keys: ArrayLike = []
self.hyperparameters_idx: ArrayLike = [0]
self.prior_hyperparameters: list[PriorHyperparameters] = []
@@ -154,6 +161,42 @@ def __init__(
elif isinstance(submodel, RegressionSubModel):
self.n_fixed_effects += submodel.n_fixed_effects
+ elif isinstance(submodel, AR1SubModel):
+
+ if isinstance(submodel.config.ph_phi, BetaPriorHyperparametersConfig):
+ self.prior_hyperparameters.append(
+ BetaPriorHyperparameters(
+ config=submodel.config.ph_phi,
+ )
+ )
+ elif isinstance(
+ submodel.config.ph_phi,
+ PenalizedComplexityPriorHyperparametersConfig,
+ ):
+ self.prior_hyperparameters.append(
+ PenalizedComplexityPriorHyperparameters(
+ config=submodel.config.ph_phi,
+ hyperparameter_type="phi",
+ )
+ )
+
+ if isinstance(
+ submodel.config.ph_tau, GaussianPriorHyperparametersConfig
+ ):
+ self.prior_hyperparameters.append(
+ GaussianPriorHyperparameters(
+ config=submodel.config.ph_tau,
+ )
+ )
+ if isinstance(submodel.config.ph_tau, GammaPriorHyperparametersConfig):
+ self.prior_hyperparameters.append(
+ GammaPriorHyperparameters(
+ config=submodel.config.ph_tau,
+ )
+ )
+ else:
+ raise ValueError("Unknown prior hyperparameter type for ph_tau")
+
elif isinstance(submodel, BrainiacSubModel):
# h2 hyperparameters
if isinstance(submodel.config.ph_h2, BetaPriorHyperparametersConfig):
@@ -172,33 +215,29 @@ def __init__(
config=submodel.config.ph_alpha,
)
)
+ if isinstance(
+ submodel.config.ph_alpha,
+ PenalizedComplexityPriorHyperparametersConfig,
+ ):
+ self.prior_hyperparameters.append(
+ PenalizedComplexityPriorHyperparameters(
+ config=submodel.config.ph_alpha,
+ hyperparameter_type="alpha",
+ )
+ )
else:
raise ValueError("Unknown submodel type")
# ...and read their hyperparameters
theta_submodel, theta_keys_submodel = submodel.config.read_hyperparameters()
- theta.append(theta_submodel)
+ theta_external.append(theta_submodel)
theta_keys += theta_keys_submodel
self.hyperparameters_idx.append(
self.hyperparameters_idx[-1] + len(theta_submodel)
)
- # Add the likelihood hyperparameters
- (
- lh_hyperparameters,
- lh_hyperparameters_keys,
- ) = likelihood_config.read_hyperparameters()
-
- theta.append(lh_hyperparameters)
- self.theta: NDArray = xp.concatenate(theta)
-
- theta_keys += lh_hyperparameters_keys
- self.theta_keys: NDArray = theta_keys
-
- self.n_hyperparameters = self.theta.size
-
# --- Initialize the latent parameters and the design matrix
self.n_latent_parameters: int = 0
self.latent_parameters_idx: list[int] = [0]
@@ -209,32 +248,54 @@ def __init__(
self.x: NDArray = xp.zeros(self.n_latent_parameters)
- data = []
- rows = []
- cols = []
- for i, submodel in enumerate(self.submodels):
- # Convert csc_matrix to coo_matrix to allow slicing
- coo_submodel_a = submodel.a.tocoo()
- data.append(coo_submodel_a.data)
- rows.append(coo_submodel_a.row)
- cols.append(
- coo_submodel_a.col
- + self.latent_parameters_idx[i]
- * xp.ones(coo_submodel_a.col.size, dtype=int)
+ # check if all a are sparse -> if not construct dense a
+ if all(sp.sparse.issparse(submodel.a) for submodel in self.submodels):
+ data = []
+ rows = []
+ cols = []
+ for i, submodel in enumerate(self.submodels):
+ # Convert csc_matrix to coo_matrix to allow slicing
+ coo_submodel_a = submodel.a.tocoo()
+ data.append(coo_submodel_a.data)
+ rows.append(coo_submodel_a.row)
+ cols.append(
+ coo_submodel_a.col
+ + self.latent_parameters_idx[i]
+ * xp.ones(coo_submodel_a.col.size, dtype=int)
+ )
+
+ self.x[
+ self.latent_parameters_idx[i] : self.latent_parameters_idx[i + 1]
+ ] = submodel.x_initial
+
+ self.a: sp.sparse.spmatrix = sp.sparse.coo_matrix(
+ (xp.concatenate(data), (xp.concatenate(rows), xp.concatenate(cols))),
+ shape=(submodel.a.shape[0], self.n_latent_parameters),
)
+ else:
+ data = []
+ for i, submodel in enumerate(self.submodels):
+ if sp.sparse.issparse(submodel.a):
+ data.append(submodel.a.toarray())
+ else:
+ data.append(submodel.a)
- self.x[
- self.latent_parameters_idx[i] : self.latent_parameters_idx[i + 1]
- ] = submodel.x_initial
+ self.x[
+ self.latent_parameters_idx[i] : self.latent_parameters_idx[i + 1]
+ ] = submodel.x_initial
- self.a: sp.sparse.spmatrix = sp.sparse.coo_matrix(
- (xp.concatenate(data), (xp.concatenate(rows), xp.concatenate(cols))),
- shape=(submodel.a.shape[0], self.n_latent_parameters),
+ self.a: NDArray = xp.concatenate(data, axis=1)
+
+ self.permutation_latent_variables = xp.arange(0, self.n_latent_parameters, 1)
+ self.inverse_permutation_latent_variables = xp.arange(
+ 0, self.n_latent_parameters, 1
)
- # TODO: not so efficient ...
- self.permutation_latent_variables = xp.arange(self.n_latent_parameters)
- self.inverse_permutation_latent_variables = xp.arange(self.n_latent_parameters)
+ # if data is gaussian compute t(A)*A once
+ if likelihood_config.type == "gaussian":
+ self.aTa = self.a.T @ self.a
+ else:
+ self.aTa = None
# --- Load observation vector
input_dir = Path(
@@ -252,7 +313,6 @@ def __init__(
self.n_observations: int = self.y.shape[0]
# --- Initialize likelihood
- # TODO: clean this -> so that for brainiac model we don't add additional hyperperameter
if likelihood_config.type == "gaussian":
self.likelihood: Likelihood = GaussianLikelihood(
n_observations=self.n_observations,
@@ -284,6 +344,24 @@ def __init__(
hyperparameter_type="prec_o",
)
)
+ elif isinstance(
+ likelihood_config.prior_hyperparameters,
+ BetaPriorHyperparametersConfig,
+ ):
+ self.prior_hyperparameters.append(
+ BetaPriorHyperparameters(
+ config=likelihood_config.prior_hyperparameters,
+ )
+ )
+ elif isinstance(
+ likelihood_config.prior_hyperparameters,
+ GammaPriorHyperparametersConfig,
+ ):
+ self.prior_hyperparameters.append(
+ GammaPriorHyperparameters(
+ config=likelihood_config.prior_hyperparameters,
+ )
+ )
elif likelihood_config.type == "poisson":
self.likelihood: Likelihood = PoissonLikelihood(
n_observations=self.n_observations,
@@ -297,12 +375,64 @@ def __init__(
self.likelihood_config: LikelihoodConfig = likelihood_config
+ # Add the likelihood hyperparameters
+ (
+ lh_hyperparameters,
+ lh_hyperparameters_keys,
+ ) = likelihood_config.read_hyperparameters()
+
+ theta_external.append(lh_hyperparameters)
+ self.theta_external = xp.concatenate(theta_external)
+
+ print("Initial hyperparameters (external scale): ", self.theta_external)
+ print("Initial hyperparameters (internal scale): ", self.theta_internal)
+
+ theta_keys += lh_hyperparameters_keys
+ self.theta_keys: NDArray = theta_keys
+
+ self.n_hyperparameters = self.theta_external.size
+
# --- Recurrent variables
self.Q_prior = None
self.Q_prior_data_mapping = [0]
self.Q_conditional = None
self.Q_conditional_data_mapping = [0]
+ ########################################################################
+ @property
+ def theta_external(self):
+ """External/user/interpretable scale theta."""
+ # the copy is important to make sure that in place operations still trigger updating
+ return self._theta_external.copy()
+
+ @theta_external.setter
+ def theta_external(self, value):
+ """Set external theta and automatically update internal.
+
+ Notes
+ -----
+ The re-scaling is implemented for all prios but PenalizedComplexity (identity but already in the correct "log" scale).
+ """
+ self._theta_external = xp.array(value)
+ self._theta_internal = self.rescale_hyperparameters_to_internal(
+ self._theta_external, direction="forward"
+ )
+
+ @property
+ def theta_internal(self):
+ """Internal/BFGS scale theta."""
+ return self._theta_internal.copy()
+
+ @theta_internal.setter
+ def theta_internal(self, value):
+ """Set internal theta and automatically update external."""
+ self._theta_internal = xp.array(value)
+ self._theta_external = self.rescale_hyperparameters_to_internal(
+ self._theta_internal, direction="backward"
+ )
+
+ ########################################################################
+
def construct_Q_prior(self) -> sp.sparse.spmatrix:
kwargs = {}
@@ -313,22 +443,40 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
cols = []
data = []
+ ## TODO: improve the if / elif statements
for i, submodel in enumerate(self.submodels):
if isinstance(submodel, SpatioTemporalSubModel):
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
elif isinstance(submodel, SpatialSubModel):
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
elif isinstance(submodel, BrainiacSubModel):
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
+ elif isinstance(submodel, AR1SubModel):
+ for hp_idx in range(
+ self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
+ ):
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
elif isinstance(submodel, RegressionSubModel):
...
@@ -361,17 +509,34 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
elif isinstance(submodel, SpatialSubModel):
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
elif isinstance(submodel, BrainiacSubModel):
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
+ elif isinstance(submodel, AR1SubModel):
+ for hp_idx in range(
+ self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
+ ):
+ kwargs[self.theta_keys[hp_idx]] = float(
+ self.theta_external[hp_idx]
+ )
+ # kwargs[self.theta_keys[hp_idx]] = float(theta_interpret[hp_idx])
submodel_Q_prior = submodel.construct_Q_prior(**kwargs)
@@ -384,7 +549,7 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
def construct_Q_conditional(
self,
eta: NDArray,
- ) -> float:
+ ):
"""Construct the conditional precision matrix.
Note
@@ -394,16 +559,10 @@ def construct_Q_conditional(
"""
- # TODO: need to vectorize
- # hessian_likelihood_diag = hessian_diag_finite_difference_5pt(
- # self.likelihood.evaluate_likelihood, eta, self.y, theta_likelihood
- # )
- # hessian_likelihood = diags(hessian_likelihood_diag)
-
if self.likelihood_config.type == "gaussian":
kwargs = {
"eta": eta,
- "theta": float(self.theta[-1]),
+ "theta": float(self.theta_external[-1]),
}
else:
kwargs = {
@@ -412,13 +571,29 @@ def construct_Q_conditional(
if isinstance(self.submodels[0], BrainiacSubModel):
# Brainiac specific rule
- kwargs["h2"] = float(self.theta[0])
+ kwargs["h2"] = float(self.theta_external[0])
d_matrix = self.submodels[0].evaluate_d_matrix(**kwargs)
else:
# General rules
d_matrix = self.likelihood.evaluate_hessian_likelihood(**kwargs)
- self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a
+ # if self.a is sparse -> Q_conditional should be sparse, else dense
+ if sp.sparse.issparse(self.a):
+ if self.aTa is not None:
+ self.Q_conditional = self.Q_prior - d_matrix.diagonal()[0] * self.aTa
+ else:
+ self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a
+ # self.Q_conditional = self.Q_prior - self.a.T @ d_matrix @ self.a
+ else:
+ if self.aTa is not None:
+ self.Q_conditional = (
+ self.Q_prior.toarray() - d_matrix.diagonal()[0] * self.aTa
+ )
+ else:
+ self.Q_conditional = (
+ self.Q_prior.toarray() - self.a.T @ d_matrix @ self.a
+ )
+ # self.Q_conditional = self.Q_prior.toarray() - self.a.T @ d_matrix @ self.a
return self.Q_conditional
@@ -430,7 +605,7 @@ def construct_information_vector(
"""Construct the information vector."""
if isinstance(self.submodels[0], BrainiacSubModel):
- kwargs = {"h2": float(self.theta[0])}
+ kwargs = {"h2": float(self.theta_external[0])}
gradient_likelihood = self.submodels[0].evaluate_gradient_likelihood(
eta=eta, y=self.y, **kwargs
)
@@ -439,7 +614,7 @@ def construct_information_vector(
gradient_likelihood = self.likelihood.evaluate_gradient_likelihood(
eta=eta,
y=self.y,
- theta=self.theta[self.hyperparameters_idx[-1] :],
+ theta=self.theta_external[self.hyperparameters_idx[-1] :],
)
information_vector: NDArray = (
@@ -456,24 +631,16 @@ def evaluate_log_prior_hyperparameters(self) -> float:
"""Evaluate the log prior hyperparameters."""
log_prior = 0.0
- # if BFGS and model scale differ: rescale -- generalize
- if isinstance(self.submodels[0], BrainiacSubModel):
- #
- theta_interpret = self.theta.copy()
- theta_interpret[0] = scaled_logit(self.theta[0], direction="backward")
- # TODO: multivariate prior for a ... need to generalize for now:
- log_prior += self.prior_hyperparameters[0].evaluate_log_prior(
- theta_interpret[0]
- )
-
- log_prior += self.prior_hyperparameters[1].evaluate_log_prior(
- theta_interpret[1:]
- )
- else:
- theta_interpret = self.theta
-
- for i, prior_hyperparameter in enumerate(self.prior_hyperparameters):
- log_prior += prior_hyperparameter.evaluate_log_prior(theta_interpret[i])
+ for i, prior_hyperparameter in enumerate(self.prior_hyperparameters):
+ if isinstance(prior_hyperparameter, GaussianMVNPriorHyperparameters):
+ # for MVN prior hyperparameters, we need to pass the full vector
+ log_prior += prior_hyperparameter.evaluate_log_prior(
+ self.theta_external[i : i + prior_hyperparameter.mean.shape[0]]
+ )
+ else:
+ log_prior += prior_hyperparameter.evaluate_log_prior(
+ self.theta_external[i]
+ )
return log_prior
@@ -481,34 +648,82 @@ def get_theta_likelihood(self) -> NDArray:
"""Return the likelihood hyperparameters."""
if isinstance(self.submodels[0], BrainiacSubModel):
- theta_likelihood = 1 - scaled_logit(self.theta[0], direction="backward")
+ theta_likelihood = 1 - self.theta_external[0]
else:
- theta_likelihood = self.theta[self.hyperparameters_idx[-1] :]
+ theta_likelihood = self.theta_external[self.hyperparameters_idx[-1] :]
return theta_likelihood
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+
+ # need to iterate over theta and its prior hyperparameters
+ theta_internal = xp.copy(theta)
+
+ for i, prior_hyperparameter in enumerate(self.prior_hyperparameters):
+ ## how to handle priors that have multiple hyperparameters?
+ if isinstance(prior_hyperparameter, GaussianMVNPriorHyperparameters):
+ pass # no rescaling implemented
+ else:
+ theta_internal[i] = (
+ prior_hyperparameter.rescale_hyperparameters_to_internal(
+ theta[i], direction=direction
+ )
+ )
+
+ return theta_internal
+
def evaluate_likelihood(self, eta: NDArray, **kwargs) -> float:
- """Evaluate the likelihood."""
+ """Evaluate the likelihood.
+
+ Parameters
+ ----------
+ eta : NDArray
+ Linear predictor.
+ kwargs : dict
+ Additional arguments for the likelihood evaluation. These parameters are model dependent.
+
+ Returns
+ -------
+ likelihood : float
+ The evaluated likelihood.
+ """
if isinstance(self.submodels[0], BrainiacSubModel):
- kwargs["h2"] = float(self.theta[0])
+ # kwargs["h2"] = float(self.theta[0])
+ kwargs["h2"] = float(self.theta_external[0])
likelihood = self.submodels[0].evaluate_likelihood(eta, self.y, **kwargs)
else:
likelihood = self.likelihood.evaluate_likelihood(
- eta, self.y, theta=self.theta[self.hyperparameters_idx[-1] :]
+ eta, self.y, theta=self.theta_external[self.hyperparameters_idx[-1] :]
)
- return likelihood
+ return ensure_scalar(likelihood)
+
+
def __str__(self) -> str:
"""String representation of the model."""
str_representation = ""
# --- Make the Model() table ---
- headers = ["Number of Hyperparameters", "Number of Latent Parameters", "Number of Observations", "Type of Likelihood"]
- values = [self.n_hyperparameters, self.n_latent_parameters, self.n_observations, self.likelihood_config.type.capitalize()]
-
- model_table = tabulate([headers, values], tablefmt="fancy_grid", colalign=("center", "center", "center", "center"))
+ headers = [
+ "Number of Hyperparameters",
+ "Number of Latent Parameters",
+ "Number of Observations",
+ "Type of Likelihood",
+ ]
+ values = [
+ self.n_hyperparameters,
+ self.n_latent_parameters,
+ self.n_observations,
+ self.likelihood_config.type.capitalize(),
+ ]
+
+ model_table = tabulate(
+ [headers, values],
+ tablefmt="fancy_grid",
+ colalign=("center", "center", "center", "center"),
+ )
# Add the header title
model_table = add_str_header("Default Model", model_table)
@@ -524,14 +739,16 @@ def __str__(self) -> str:
# Pad each list of lines to the same length
for lines in lines_list:
- lines += [''] * (max_len - len(lines))
+ lines += [""] * (max_len - len(lines))
# Concatenate corresponding lines
- result_lines = [' '.join(parts) for parts in zip(*lines_list)]
- submodel_jointed_representation = '\n'.join(result_lines)
+ result_lines = [" ".join(parts) for parts in zip(*lines_list)]
+ submodel_jointed_representation = "\n".join(result_lines)
# Add the submodel header title
- submodel_jointed_representation = add_str_header("Submodels", submodel_jointed_representation)
+ submodel_jointed_representation = add_str_header(
+ "Submodels", submodel_jointed_representation
+ )
# Combine the model and submodel tables
str_representation = model_table + "\n" + submodel_jointed_representation
@@ -558,15 +775,14 @@ def get_solver_parameters(self) -> dict:
}
return param
-
-
+
def construct_a_predict(self) -> sp.sparse.spmatrix:
"""Construct the design matrix for prediction."""
-
+
data = []
rows = []
cols = []
-
+
rows_a_predict = 0
for i, submodel in enumerate(self.submodels):
# Convert csc_matrix to coo_matrix to allow slicing
@@ -578,10 +794,10 @@ def construct_a_predict(self) -> sp.sparse.spmatrix:
+ self.latent_parameters_idx[i]
* xp.ones(coo_submodel_a_predict.col.size, dtype=int)
)
-
+
# the number of rows in all of them is the same
rows_a_predict = coo_submodel_a_predict.shape[0]
-
+
self.a_predict: sp.sparse.spmatrix = sp.sparse.coo_matrix(
(xp.concatenate(data), (xp.concatenate(rows), xp.concatenate(cols))),
shape=(rows_a_predict, self.n_latent_parameters),
@@ -589,4 +805,4 @@ def construct_a_predict(self) -> sp.sparse.spmatrix:
def total_number_fixed_effects(self) -> int:
"""Get the number of fixed effects."""
- return self.n_fixed_effects
\ No newline at end of file
+ return self.n_fixed_effects
diff --git a/src/dalia/core/prior_hyperparameters.py b/src/dalia/core/prior_hyperparameters.py
index 88c1ba4e..c438b7f5 100644
--- a/src/dalia/core/prior_hyperparameters.py
+++ b/src/dalia/core/prior_hyperparameters.py
@@ -1,6 +1,7 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
from abc import ABC, abstractmethod
+from dalia import xp
from dalia.configs.priorhyperparameters_config import PriorHyperparametersConfig
@@ -16,6 +17,23 @@ def __init__(
self.config: PriorHyperparametersConfig = config
+ @abstractmethod
+ def rescale_hyperparameters_to_internal(self, theta: float, direction: str) -> float:
+ """Rescale hyperparameters to and from internal scale.
+
+ Args:
+ theta: Hyperparameter
+ direction: "forward", "backward", "forward_jacobian", "backward_jacobian"
+ Returns:
+ Rescaled hyperparameter.
+
+ """
+
+ if direction == "forward" or direction == "backward":
+ return theta
+ elif direction == "forward_jacobian" or direction == "backward_jacobian":
+ return xp.ones_like(theta)
+
@abstractmethod
def evaluate_log_prior(self, theta: float) -> float:
"""Evaluate the log prior hyperparameters."""
diff --git a/src/dalia/core/solver.py b/src/dalia/core/solver.py
index f5e3e95b..bcf07653 100644
--- a/src/dalia/core/solver.py
+++ b/src/dalia/core/solver.py
@@ -24,28 +24,49 @@ def __init__(
self.config = config
@abstractmethod
- def cholesky(self, A: ArrayLike, **kwargs) -> None:
- """Compute Cholesky factor of input matrix.
+ def factorize(self, A: ArrayLike, **kwargs) -> None:
+ """Compute the decomposition of a matrix.
Parameters
----------
- A : ArrayLike
- Input matrix.
+ A : NDArray | sp.sparse.spmatrix
+ The input matrix to decompose.
Returns
-------
None
+
+ Note:
+ -----
+ Even though precision matrices are known to be positive definite, depending on the underlying solver implementation, this could be Cholesky, LU, or other factorizations.
"""
...
@abstractmethod
def solve(self, rhs: NDArray, **kwargs) -> NDArray:
- """Solve linear system using Cholesky factor."""
+ """Solve linear system using Cholesky factor.
+
+ Parameters
+ ----------
+ rhs : NDArray
+ Right-hand side of the linear system.
+
+ Returns
+ -------
+ NDArray
+ Solution of the linear system.
+ """
...
@abstractmethod
def logdet(self, **kwargs) -> float:
- """Compute logdet of input matrix using Cholesky factor."""
+ """Compute the log determinant of the matrix.
+
+ Returns
+ -------
+ float
+ The log determinant of the matrix.
+ """
...
@abstractmethod
@@ -61,4 +82,4 @@ def _structured_to_spmatrix(self, **kwargs) -> None:
@abstractmethod
def get_solver_memory(self) -> int:
"""Return the memory used by the solver in number of bytes"""
- ...
\ No newline at end of file
+ ...
diff --git a/src/dalia/core/submodel.py b/src/dalia/core/submodel.py
index f45d572e..0c8123da 100644
--- a/src/dalia/core/submodel.py
+++ b/src/dalia/core/submodel.py
@@ -4,11 +4,12 @@
from pathlib import Path
import numpy as np
-from scipy.sparse import csc_matrix, load_npz, spmatrix
+from scipy.sparse import load_npz, spmatrix
from dalia import NDArray, sp, xp
from dalia.configs.submodels_config import SubModelConfig
+
class SubModel(ABC):
"""Abstract core class for statistical models."""
@@ -22,11 +23,23 @@ def __init__(
self.submodel_type = config.type
# --- Load design matrix
- a: spmatrix = csc_matrix(load_npz(self.input_path.joinpath("a.npz")))
- if xp == np:
- self.a: sp.sparse.spmatrix = a
- else:
- self.a: sp.sparse.spmatrix = sp.sparse.csc_matrix(a)
+
+ try:
+ a: spmatrix = load_npz(self.input_path.joinpath("a.npz"))
+ self.a = sp.sparse.csc_matrix(a)
+ except FileNotFoundError:
+ # check if dense a matrix exists
+ try:
+ a: NDArray = np.load(self.input_path.joinpath("a.npy"))
+ if xp == np:
+ self.a: NDArray = a
+ else:
+ self.a: NDArray = xp.array(a)
+ except FileNotFoundError:
+ raise FileNotFoundError(
+ "No design matrix found. Please provide a valid design matrix."
+ )
+
self.n_latent_parameters: int = self.a.shape[1]
# --- Load latent parameters vector
@@ -44,25 +57,18 @@ def __init__(
def construct_Q_prior(self, **kwargs) -> sp.sparse.coo_matrix:
"""Construct the prior precision matrix."""
...
-
+
def load_a_predict(self) -> sp.sparse.csc_matrix:
"""Load the design matrix for prediction."""
- a_predict: sp.sparse.csc_matrix = csc_matrix(
+ self.a_predict: sp.sparse.csc_matrix = sp.sparse.csc_matrix(
load_npz(self.input_path.joinpath("apr.npz"))
)
-
- if xp == np:
- self.a_predict: sp.sparse.spmatrix = a_predict
- else:
- self.a_predict: sp.sparse.spmatrix = sp.sparse.csc_matrix(a_predict)
-
+
# check that number of columns is the same as in a
if self.a_predict.shape[1] != self.a.shape[1]:
raise ValueError(
f"Number of columns in a_predict ({self.a_predict.shape[1]}) "
f"does not match number of columns in a ({self.a.shape[1]})."
)
-
- return self.a_predict
-
+ return self.a_predict
diff --git a/src/dalia/kernels/blockmapping.py b/src/dalia/kernels/blockmapping.py
index 4d873c35..27ea5bae 100644
--- a/src/dalia/kernels/blockmapping.py
+++ b/src/dalia/kernels/blockmapping.py
@@ -100,6 +100,7 @@ def compute_block_sort_index(
return sort_index
+
def compute_block_slice(
rows: xp.ndarray,
cols: xp.ndarray,
diff --git a/src/dalia/likelihoods/binomial.py b/src/dalia/likelihoods/binomial.py
index 53c2754d..862f958e 100644
--- a/src/dalia/likelihoods/binomial.py
+++ b/src/dalia/likelihoods/binomial.py
@@ -53,14 +53,14 @@ def evaluate_likelihood(
y : NDArray
Vector of the observations.
- Notes
- -----
- For now only a sigmoid link-function is implemented.
-
Returns
-------
likelihood : float
Likelihood.
+
+ Notes
+ -----
+ - For now only a sigmoid link-function is implemented.
"""
linkEta: NDArray = self.link_function(eta)
diff --git a/src/dalia/likelihoods/gaussian.py b/src/dalia/likelihoods/gaussian.py
index ee69bd16..6ed8a864 100644
--- a/src/dalia/likelihoods/gaussian.py
+++ b/src/dalia/likelihoods/gaussian.py
@@ -29,7 +29,7 @@ def evaluate_likelihood(
Evaluate Gaussian log-likelihood for a given set of observations, latent parameters, and design matrix, where
the observations are assumed to be identically and independently distributed given eta (=A*x). Leading to:
- log (p(y|eta)) = -0.5 * n * log(2 * pi) - 0.5 * n * theta_observations - 0.5 * exp(theta_observations) * (y - eta)^T * (y - eta)
+ log (p(y|eta)) = -0.5 * n * log(2 * pi) - 0.5 * n * log(theta) - 0.5 * theta * (y - eta)^T * (y - eta)
where the constant in front of the likelihood is omitted.
Parameters
@@ -40,7 +40,7 @@ def evaluate_likelihood(
Vector of the observations.
kwargs :
theta : float
- Specific parameter for the likelihood calculation.
+ precision parameter for the likelihood calculation. theta > 0.
Returns
-------
@@ -52,11 +52,16 @@ def evaluate_likelihood(
if theta is None:
raise ValueError("theta must be provided to evaluate gaussian likelihood.")
+ if theta <= 0:
+ raise ValueError(f"theta must be positive, got {theta}")
+
yEta = eta - y
# print("xp.exp(theta) in lh:", xp.exp(theta))
likelihood: float = (
- 0.5 * theta * self.n_observations - 0.5 * xp.exp(theta) * yEta.T @ yEta
+ -0.5 * self.n_observations * xp.log(2 * xp.pi)
+ + 0.5 * xp.log(theta) * self.n_observations
+ - 0.5 * theta * yEta.T @ yEta
)
return likelihood
@@ -77,7 +82,7 @@ def evaluate_gradient_likelihood(
Vector of the observations.
kwargs :
theta : float
- Specific parameter for the likelihood calculation.
+ precision parameter for the likelihood calculation. theta > 0.
Returns
-------
@@ -91,7 +96,10 @@ def evaluate_gradient_likelihood(
"theta must be provided to evaluate gradient of gaussian likelihood."
)
- gradient_likelihood: NDArray = -xp.exp(theta) * (eta - y)
+ if theta <= 0:
+ raise ValueError(f"theta must be positive, got {theta}")
+
+ gradient_likelihood: NDArray = -theta * (eta - y)
return gradient_likelihood
@@ -121,11 +129,11 @@ def evaluate_hessian_likelihood(
raise ValueError(
"theta must be provided to evaluate gradient of gaussian likelihood."
)
+ if theta <= 0:
+ raise ValueError(f"theta must be positive, got {theta}")
# print("hessian lh: xp.exp(theta)", xp.exp(theta))
- hessian_likelihood: ArrayLike = -xp.exp(theta) * sp.sparse.eye(
- self.n_observations
- )
+ hessian_likelihood: ArrayLike = -theta * sp.sparse.eye(self.n_observations)
return hessian_likelihood
diff --git a/src/dalia/models/coregional_model.py b/src/dalia/models/coregional_model.py
index 166fbd03..e1c14671 100644
--- a/src/dalia/models/coregional_model.py
+++ b/src/dalia/models/coregional_model.py
@@ -1,11 +1,11 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
import re
-from tabulate import tabulate
import numpy as np
+from tabulate import tabulate
-from dalia import ArrayLike, NDArray, sp, xp, backend_flags
+from dalia import ArrayLike, NDArray, sp, xp
from dalia.configs.models_config import CoregionalModelConfig
from dalia.core.model import Model
from dalia.core.prior_hyperparameters import PriorHyperparameters
@@ -14,8 +14,15 @@
PenalizedComplexityPriorHyperparameters,
)
from dalia.submodels import RegressionSubModel, SpatialSubModel, SpatioTemporalSubModel
-from dalia.utils import bdiag_tiling, free_unused_gpu_memory
-from dalia.utils import add_str_header, align_tables_side_by_side, boxify
+from dalia.utils import (
+ add_str_header,
+ align_tables_side_by_side,
+ bdiag_tiling,
+ boxify,
+ free_unused_gpu_memory,
+)
+from dalia.utils.scalar_ndarray import ensure_scalar
+
class CoregionalModel(Model):
"""Core class for statistical models."""
@@ -95,14 +102,18 @@ def __init__(
f"Model {model} has a different number of fixed effects than the first model"
)
+ self._theta_external: ArrayLike = []
+ self._theta_internal: ArrayLike = []
+
+ # For each model
# Get Models() hyperparameters
- theta: ArrayLike = []
+ theta_external: ArrayLike = []
theta_keys: ArrayLike = []
self.hyperparameters_idx: ArrayLike = [0]
self.prior_hyperparameters: list[PriorHyperparameters] = []
for model in self.models:
- theta_model = model.theta
+ theta_model = model.theta_external
theta_keys_model = model.theta_keys
# remove the theta that correspond to the "sigma_xx" where x can be whatever
@@ -118,7 +129,7 @@ def __init__(
key for i, key in enumerate(theta_keys_model) if i not in sigma_indices
]
- theta.append(xp.array(theta_model))
+ theta_external.append(xp.array(theta_model))
theta_keys += theta_keys_model
self.hyperparameters_idx.append(
@@ -137,7 +148,7 @@ def __init__(
theta_coregional_model,
theta_keys_coregional_model,
) = coregional_model_config.read_hyperparameters()
- theta.append(xp.array(theta_coregional_model))
+ theta_external.append(xp.array(theta_coregional_model))
theta_keys += theta_keys_coregional_model
self.hyperparameters_idx.append(
@@ -145,8 +156,8 @@ def __init__(
)
# Finalize the hyperparameters
- self.theta: NDArray = xp.concatenate(theta)
- self.n_hyperparameters = self.theta.size
+ self.theta_external: NDArray = xp.concatenate(theta_external)
+ self.n_hyperparameters = self.theta_external.size
self.theta_keys: NDArray = theta_keys
# Initialize the Coregional Prior Hyperparameters
@@ -201,19 +212,20 @@ def __init__(
self.latent_parameters_idx[i] : self.latent_parameters_idx[i + 1]
] = model.x
- self.y[
- self.n_observations_idx[i] : self.n_observations_idx[i + 1]
- ] = model.y
+ self.y[self.n_observations_idx[i] : self.n_observations_idx[i + 1]] = (
+ model.y
+ )
+
+ self.a: sp.sparse.spmatrix = bdiag_tiling(
+ [model.a for model in self.models]
+ ).tocsc()
- self.a: sp.sparse.spmatrix = bdiag_tiling([model.a for model in self.models]).tocsc()
-
for model in self.models:
model.a = None
model.y = None
model.x = None
-
- free_unused_gpu_memory()
-
+
+ free_unused_gpu_memory()
if self.coregionalization_type == "spatio_temporal":
self.permutation_Qst = self._generate_permutation_indices(
@@ -237,7 +249,7 @@ def __init__(
# don't permute when nt is 1
self.permutation_Qst = xp.arange(0, self.n_models * n_re, 1)
-
+
# permute fixed effects to the end
self.permutation_latent_variables = (
self._generate_permutation_indices_spatial(
@@ -247,7 +259,7 @@ def __init__(
self.a = self.a[:, self.permutation_latent_variables]
self.x = self.x[self.permutation_latent_variables]
-
+
# self.inverse_permutation_latent_variables = xp.argsort(self.permutation_latent_variables)
# self.perm2 = self._generate_permutation_indices_for_a_new(
# self.n_temporal_nodes,
@@ -259,7 +271,7 @@ def __init__(
# --- Recurrent variables
self.Q_prior_data_mapping = [0]
-
+
self.rows_Qprior_re = None
self.columns_Qprior_re = None
self.data_Qprior_re = None
@@ -270,17 +282,18 @@ def __init__(
self.Q_conditional = None
self.Q_conditional_data_mapping = [0]
-
- self.Q_prior: sp.sparse.spmatrix = None # need this otherwise the construct will fail
+
+ self.Q_prior: sp.sparse.spmatrix = (
+ None # need this otherwise the construct will fail
+ )
self.construct_Q_prior()
-
def construct_Q_prior(self) -> sp.sparse.spmatrix:
# number of random effects per model
n_re = self.n_spatial_nodes * self.n_temporal_nodes
-
- Qu_list: list = [None] * self.n_models
+
+ Qu_list: list = [None] * self.n_models
Q_r: list = [None] * self.n_models
for i, model in enumerate(self.models):
@@ -290,7 +303,7 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
for hp_idx in range(
self.hyperparameters_idx[i], self.hyperparameters_idx[i + 1]
):
- kwargs_st[self.theta_keys[hp_idx]] = float(self.theta[hp_idx])
+ kwargs_st[self.theta_keys[hp_idx]] = float(self.theta_external[hp_idx])
Qu_list[i] = submodel_st.construct_Q_prior(**kwargs_st).tocsc()
@@ -301,10 +314,10 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
kwargs_r = {}
Q_r[i] = submodel_r.construct_Q_prior(**kwargs_r).tocsc()
- sigma_0 = xp.exp(self.theta[self.theta_keys.index("sigma_0")])
- sigma_1 = xp.exp(self.theta[self.theta_keys.index("sigma_1")])
+ sigma_0 = xp.exp(self.theta_external[self.theta_keys.index("sigma_0")])
+ sigma_1 = xp.exp(self.theta_external[self.theta_keys.index("sigma_1")])
- lambda_0_1 = self.theta[self.theta_keys.index("lambda_0_1")]
+ lambda_0_1 = self.theta_external[self.theta_keys.index("lambda_0_1")]
if self.n_models == 2:
q11 = sp.sparse.coo_matrix(
@@ -312,22 +325,22 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
+ (lambda_0_1**2 / sigma_1**2) * Qu_list[1]
)
if not q11.has_canonical_format:
- q11.sum_duplicates()
-
+ q11.sum_duplicates()
+
Qu_list[0] = None
q21 = sp.sparse.coo_matrix((-lambda_0_1 / sigma_1**2) * Qu_list[1])
if not q21.has_canonical_format:
- q21.sum_duplicates()
-
+ q21.sum_duplicates()
+
q12 = sp.sparse.coo_matrix((-lambda_0_1 / sigma_1**2) * Qu_list[1])
if not q12.has_canonical_format:
- q12.sum_duplicates()
-
+ q12.sum_duplicates()
+
q22 = sp.sparse.coo_matrix((1 / sigma_1**2) * Qu_list[1])
if not q22.has_canonical_format:
- q22.sum_duplicates()
-
+ q22.sum_duplicates()
+
Qu_list[1] = None
if self.data_Qprior_re is None:
@@ -342,14 +355,14 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
q22_rows = q22.row + n_re
q22_columns = q22.col + n_re
-
+
self.rows_Qprior_re = xp.concatenate(
[q11_rows, q12_rows, q21_rows, q22_rows]
)
self.columns_Qprior_re = xp.concatenate(
[q11_columns, q12_columns, q21_columns, q22_columns]
)
-
+
self.data_Qprior_re = xp.concatenate(
[q11.data, q12.data, q21.data, q22.data]
)
@@ -357,11 +370,11 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
# Qprior_st = sp.sparse.bmat([[q11, q12], [q21, q22]]).tocsc()
elif self.n_models == 3:
- sigma_2 = xp.exp(self.theta[self.theta_keys.index("sigma_2")])
+ sigma_2 = xp.exp(self.theta_external[self.theta_keys.index("sigma_2")])
+
+ lambda_0_2 = self.theta_external[self.theta_keys.index("lambda_0_2")]
+ lambda_1_2 = self.theta_external[self.theta_keys.index("lambda_1_2")]
- lambda_0_2 = self.theta[self.theta_keys.index("lambda_0_2")]
- lambda_1_2 = self.theta[self.theta_keys.index("lambda_1_2")]
-
q11 = sp.sparse.coo_matrix(
(1 / sigma_0**2) * Qu_list[0]
+ (lambda_0_1**2 / sigma_1**2) * Qu_list[1]
@@ -369,83 +382,82 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
)
Qu_list[0] = None
if not q11.has_canonical_format:
- q11.sum_duplicates()
+ q11.sum_duplicates()
q21 = sp.sparse.coo_matrix(
(-lambda_0_1 / sigma_1**2) * Qu_list[1]
+ (lambda_0_2 * lambda_1_2 / sigma_2**2) * Qu_list[2]
)
if not q21.has_canonical_format:
- q21.sum_duplicates()
-
+ q21.sum_duplicates()
+
q31 = sp.sparse.coo_matrix(-lambda_1_2 / sigma_2**2 * Qu_list[2])
if not q31.has_canonical_format:
- q31.sum_duplicates()
-
+ q31.sum_duplicates()
+
q22 = sp.sparse.coo_matrix(
(1 / sigma_1**2) * Qu_list[1]
+ (lambda_0_2**2 / sigma_2**2) * Qu_list[2]
)
if not q22.has_canonical_format:
- q22.sum_duplicates()
+ q22.sum_duplicates()
Qu_list[1] = None
q32 = sp.sparse.coo_matrix(-lambda_0_2 / sigma_2**2 * Qu_list[2])
if not q32.has_canonical_format:
- q32.sum_duplicates()
-
+ q32.sum_duplicates()
+
q33 = sp.sparse.coo_matrix((1 / sigma_2**2) * Qu_list[2])
if not q33.has_canonical_format:
- q33.sum_duplicates()
+ q33.sum_duplicates()
Qu_list[2] = None
-
+
# not the most elegant way but im afraid that without the copy it might break in some cases
q12 = (q21.copy()).T
if not q12.has_canonical_format:
- q12.sum_duplicates()
-
+ q12.sum_duplicates()
+
q13 = (q31.copy()).T
if not q13.has_canonical_format:
- q13.sum_duplicates()
-
+ q13.sum_duplicates()
+
q23 = (q32.copy()).T
if not q23.has_canonical_format:
- q23.sum_duplicates()
-
- free_unused_gpu_memory()
+ q23.sum_duplicates()
+
+ free_unused_gpu_memory()
# we only need these indices once in the beginning
# then they can be none again and we can only collect data array
if self.data_Qprior_re is None:
q11_rows = q11.row
q11_columns = q11.col
-
+
q21_rows = q21.row + n_re
q21_columns = q21.col
-
+
q31_rows = q31.row + 2 * n_re
- q31_columns = q31.col
-
+ q31_columns = q31.col
+
q22_rows = q22.row + n_re
q22_columns = q22.col + n_re
-
+
q32_rows = q32.row + 2 * n_re
q32_columns = q32.col + n_re
-
+
q33_rows = q33.row + 2 * n_re
- q33_columns = q33.col + 2 * n_re
-
- ## CAREFUL IF THIS IS NOT A "TRUE" COPY ...
+ q33_columns = q33.col + 2 * n_re
+
+ ## CAREFUL IF THIS IS NOT A "TRUE" COPY ...
q12_rows = q12.row
q12_columns = q12.col + n_re
-
+
q13_rows = q13.row
q13_columns = q13.col + 2 * n_re
-
+
q23_rows = q23.row + n_re
q23_columns = q23.col + 2 * n_re
-
-
+
self.rows_Qprior_re = xp.concatenate(
[
q11_rows,
@@ -472,7 +484,7 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
q33_columns,
]
)
-
+
# this changes every time -> need to keep
self.data_Qprior_re = xp.concatenate(
[
@@ -486,15 +498,15 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
q32.data,
q33.data,
]
- )
-
- free_unused_gpu_memory()
+ )
+
+ free_unused_gpu_memory()
# Qprior_st = sp.sparse.bmat(
# [[q11, q12, q13], [q21, q22, q23], [q31, q32, q33]]
# ).tocsc()
-
- #Qprior_re = sp.sparse.coo_matrix((self.data_Qprior_re, (self.rows_Qprior_re, self.columns_Qprior_re)), shape=( self.n_models * n_re, self.n_models * n_re))
+
+ # Qprior_re = sp.sparse.coo_matrix((self.data_Qprior_re, (self.rows_Qprior_re, self.columns_Qprior_re)), shape=( self.n_models * n_re, self.n_models * n_re))
# Permute matrix
# Qprior_st_perm = Qprior_st[self.permutation_Qst, :][:, self.permutation_Qst]
@@ -504,7 +516,7 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
(self.n_models * n_re, self.n_models * n_re),
dtype=self.data_Qprior_re.dtype,
)
-
+
# self.permutation_Qst is the identity for spatial models
self.set_data_array_permutation_indices(
self.permutation_Qst,
@@ -512,72 +524,73 @@ def construct_Q_prior(self) -> sp.sparse.spmatrix:
self.columns_Qprior_re,
self.n_models * n_re,
)
-
+
# we only need to set these once
self.Qprior_re_perm.indices = self.permutation_indices_Q_prior
self.Qprior_re_perm.indptr = self.permutation_indptr_Q_prior
-
+
# dont need these anymore
self.rows_Qprior_re = None
self.columns_Qprior_re = None
-
- free_unused_gpu_memory()
-
+
+ free_unused_gpu_memory()
+
self.data_Qprior_re = self.data_Qprior_re[self.permutation_vector_Q_prior]
if Q_r != []:
if self.Q_prior is None:
self.Qprior_re_perm.data = self.data_Qprior_re
-
+
Qprior_reg = bdiag_tiling(Q_r).tocsc()
self.Q_prior = bdiag_tiling([self.Qprior_re_perm, Qprior_reg]).tocsc()
self.nnz_Qprior_re_perm = self.Qprior_re_perm.nnz
-
+
# free all memory not needed anymore
self.Qprior_re_perm = None
self.permutation_indices_Q_prior = None
self.permutation_indptr_Q_prior = None
- else:
- free_unused_gpu_memory()
-
+ else:
+ free_unused_gpu_memory()
+
self.Q_prior.tocsc()
self.Q_prior.sort_indices()
- self.Q_prior.data[:self.nnz_Qprior_re_perm] = self.data_Qprior_re
+ self.Q_prior.data[: self.nnz_Qprior_re_perm] = self.data_Qprior_re
else:
self.Q_prior = self.Qprior_re_perm
-
- free_unused_gpu_memory()
+
+ free_unused_gpu_memory()
return self.Q_prior
-
+
def spgemm(self, A, B, rows: int = 5408):
-
- free_unused_gpu_memory()
-
+ free_unused_gpu_memory()
+
C = None
for i in range(0, A.shape[0], rows):
- A_block = A[i:min(A.shape[0], i+rows)]
+ A_block = A[i : min(A.shape[0], i + rows)]
C_block = A_block @ B
if C is None:
C = C_block
else:
C = sp.sparse.vstack([C, C_block], format="csr")
-
- free_unused_gpu_memory()
+
+ free_unused_gpu_memory()
return C.tocsc()
-
- def custom_Q_ATDA(self, Q: sp.sparse.csc_matrix, A: sp.sparse.csc_matrix, D_diag: xp.ndarray) -> sp.sparse.csr_matrix:
+
+ def custom_Q_ATDA(
+ self, Q: sp.sparse.csc_matrix, A: sp.sparse.csc_matrix, D_diag: xp.ndarray
+ ) -> sp.sparse.csr_matrix:
"""
Computes A^T * D * A with minimal memory and maximum speed.
- Uses sparse diagonal multiplication.
- No temporary dense matrices.
"""
-
- DA = A.multiply(D_diag[:, xp.newaxis]).T.tocsr()
-
- mem_used_bytes = free_unused_gpu_memory()
-
+
+ DA = A.multiply(D_diag[:, xp.newaxis]).T.tocsr()
+
+ mem_used_bytes = free_unused_gpu_memory()
+
# use batched spgemm if mempool is full
if mem_used_bytes > 80 * 1024**3:
batch_size = int(xp.ceil(A.shape[0] / 2))
@@ -585,11 +598,11 @@ def custom_Q_ATDA(self, Q: sp.sparse.csc_matrix, A: sp.sparse.csc_matrix, D_diag
batch_size = A.shape[0]
ATDA = self.spgemm(DA, A, rows=batch_size)
- free_unused_gpu_memory()
-
- self.Qconditional = Q - ATDA
- free_unused_gpu_memory()
-
+ free_unused_gpu_memory()
+
+ self.Qconditional = Q - ATDA
+ free_unused_gpu_memory()
+
return self.Qconditional
def construct_Q_conditional(
@@ -605,14 +618,14 @@ def construct_Q_conditional(
"""
d_vec = xp.zeros(self.n_observations)
-
+
for i, model in enumerate(self.models):
if model.likelihood_config.type == "gaussian":
kwargs = {
"eta": eta[
self.n_observations_idx[i] : self.n_observations_idx[i + 1]
],
- "theta": float(self.theta[self.hyperparameters_idx[i + 1] - 1]),
+ "theta": float(self.theta_external[self.hyperparameters_idx[i + 1] - 1]),
}
else:
kwargs = {
@@ -621,12 +634,10 @@ def construct_Q_conditional(
],
}
- #d_list[i] = model.likelihood.evaluate_hessian_likelihood(**kwargs)
- d_vec[
- self.n_observations_idx[i] : self.n_observations_idx[i + 1]
- ] = model.likelihood.evaluate_hessian_likelihood(
- **kwargs
- ).diagonal()
+ # d_list[i] = model.likelihood.evaluate_hessian_likelihood(**kwargs)
+ d_vec[self.n_observations_idx[i] : self.n_observations_idx[i + 1]] = (
+ model.likelihood.evaluate_hessian_likelihood(**kwargs).diagonal()
+ )
self.Qconditional = self.custom_Q_ATDA(
Q=self.Q_prior,
@@ -634,8 +645,8 @@ def construct_Q_conditional(
D_diag=d_vec,
)
self.Q_conditional = self.Qconditional.tocsc()
- free_unused_gpu_memory()
-
+ free_unused_gpu_memory()
+
return self.Q_conditional
def construct_information_vector(
@@ -650,7 +661,7 @@ def construct_information_vector(
gradient_likelihood = model.likelihood.evaluate_gradient_likelihood(
eta=eta[self.n_observations_idx[i] : self.n_observations_idx[i + 1]],
y=self.y[self.n_observations_idx[i] : self.n_observations_idx[i + 1]],
- theta=float(self.theta[self.hyperparameters_idx[i + 1] - 1]),
+ theta=float(self.theta_external[self.hyperparameters_idx[i + 1] - 1]),
)
gradient_vector_list.append(gradient_likelihood)
@@ -664,7 +675,13 @@ def construct_information_vector(
return information_vector
def is_likelihood_gaussian(self) -> bool:
- """Check if the likelihood is Gaussian."""
+ """Check if the likelihood is Gaussian.
+
+ Returns
+ -------
+ is_gaussian : bool
+ True if the likelihood is Gaussian, False otherwise.
+ """
for model in self.models:
if not model.is_likelihood_gaussian():
return False
@@ -674,21 +691,40 @@ def evaluate_likelihood(
self,
eta: NDArray,
) -> float:
+ """Evaluate the likelihood.
+
+ Parameters
+ ----------
+ eta : NDArray
+ Linear predictor.
+ kwargs : dict
+ Additional arguments for the likelihood evaluation. These parameters are model dependent.
+
+ Returns
+ -------
+ likelihood : float
+ The evaluated likelihood.
+
+ Implementation Notes:
+ ---------------------
+ - The likelihood is evaluated for each model and then summed up to get the total likelihood of the CoregionalModel.
+ - Returned as a scalar for consistency, even if the likelihood is computed as a sum of multiple likelihoods from different models.
+ """
likelihood: float = 0.0
for i, model in enumerate(self.models):
likelihood += model.likelihood.evaluate_likelihood(
eta=eta[self.n_observations_idx[i] : self.n_observations_idx[i + 1]],
y=self.y[self.n_observations_idx[i] : self.n_observations_idx[i + 1]],
- theta=float(self.theta[self.hyperparameters_idx[i + 1] - 1]),
+ theta=float(self.theta_external[self.hyperparameters_idx[i + 1] - 1]),
)
-
- return likelihood
+
+ return ensure_scalar(likelihood)
def evaluate_log_prior_hyperparameters(self) -> float:
"""Evaluate the log prior hyperparameters."""
log_prior = 0.0
- theta_interpret = self.theta
+ theta_interpret = self.theta_external
for i, prior_hyperparameter in enumerate(self.prior_hyperparameters):
log_prior += prior_hyperparameter.evaluate_log_prior(theta_interpret[i])
@@ -700,13 +736,23 @@ def __str__(self) -> str:
str_representation = ""
# --- Make the Coregional Model() table ---
- headers = ["Number of Hyperparameters", "Number of Latent Parameters", "Number of Observations"]
+ headers = [
+ "Number of Hyperparameters",
+ "Number of Latent Parameters",
+ "Number of Observations",
+ ]
values = [self.n_hyperparameters, self.n_latent_parameters, self.n_observations]
- model_table = tabulate([headers, values], tablefmt="fancy_grid", colalign=("center", "center", "center"))
+ model_table = tabulate(
+ [headers, values],
+ tablefmt="fancy_grid",
+ colalign=("center", "center", "center"),
+ )
# Add the header title
- model_table = add_str_header(f"Coregional Model ({self.n_models} variates)", model_table)
+ model_table = add_str_header(
+ f"Coregional Model ({self.n_models} variates)", model_table
+ )
# --- Add the model information ---
# Create headers and values for the model table
@@ -715,10 +761,14 @@ def __str__(self) -> str:
models_str_representation.append(str(model))
# Create the model table
- model_jointed_representation = align_tables_side_by_side(models_str_representation)
+ model_jointed_representation = align_tables_side_by_side(
+ models_str_representation
+ )
# Add the model header title
- model_jointed_representation = add_str_header("Models", model_jointed_representation)
+ model_jointed_representation = add_str_header(
+ "Models", model_jointed_representation
+ )
# Combine the model and model tables
str_representation = model_table + "\n" + boxify(model_jointed_representation)
@@ -917,42 +967,44 @@ def _generate_permutation_indices_for_a(
def set_data_array_permutation_indices(
self, permutation, a_rows: NDArray, a_cols: NDArray, n: int
) -> None:
-
a_data_placeholder = xp.arange(0, len(a_rows), 1)
a = sp.sparse.csc_matrix(
- sp.sparse.coo_matrix((a_data_placeholder, (a_rows, a_cols)), shape=(n, n), dtype=xp.float64)
+ sp.sparse.coo_matrix(
+ (a_data_placeholder, (a_rows, a_cols)), shape=(n, n), dtype=xp.float64
+ )
)
a_perm = a[permutation, :][:, permutation]
- a_perm.sort_indices() ## new
-
+ a_perm.sort_indices() ## new
+
self.permutation_vector_Q_prior = a_perm.data.astype(xp.int32)
self.permutation_indices_Q_prior = a_perm.indices
self.permutation_indptr_Q_prior = a_perm.indptr
-
-
+
def construct_a_predict(self) -> sp.sparse.spmatrix:
-
# iterate through the models to load their respective a_predict
for i, model in enumerate(self.models):
model.construct_a_predict()
-
- self.a_predict: sp.sparse.spmatrix = bdiag_tiling([model.a_predict for model in self.models]).tocsc()
-
+
+ self.a_predict: sp.sparse.spmatrix = bdiag_tiling(
+ [model.a_predict for model in self.models]
+ ).tocsc()
+
# Reorder a_predict in the same way as a
self.a_predict = self.a_predict[:, self.permutation_latent_variables]
-
+
return self.a_predict
-
- def compare_matrices(self, a1_data_vec, a1_indices, a1_indptr, a2_data_vec, a2_indices, a2_indptr):
+ def compare_matrices(
+ self, a1_data_vec, a1_indices, a1_indptr, a2_data_vec, a2_indices, a2_indptr
+ ):
"""
Compare two sparse matrices represented by their data vectors, indices, and indptr arrays.
"""
# Check if the shapes of the matrices are the same
# if len(a1_data_vec) != len(a2_data_vec):
# return False
-
+
# Check if the indices arrays are equal
if not xp.array_equal(a1_indices, a2_indices):
print("indices arrays are not equal")
@@ -968,7 +1020,7 @@ def compare_matrices(self, a1_data_vec, a1_indices, a1_indptr, a2_data_vec, a2_i
if ptr1 != ptr2:
print(f"Indptr arrays differ at index {i}: {ptr1} != {ptr2}")
return False
-
+
if not xp.array_equal(a1_data_vec, a2_data_vec):
print("data vectors are not equal")
print("theta: ", self.theta)
@@ -978,7 +1030,6 @@ def compare_matrices(self, a1_data_vec, a1_indices, a1_indptr, a2_data_vec, a2_i
return False
return True
-
def get_solver_parameters(self) -> dict:
"""Get the solver parameters."""
@@ -996,4 +1047,4 @@ def get_solver_parameters(self) -> dict:
def total_number_fixed_effects(self) -> int:
"""Get the number of fixed effects."""
- return self.n_fixed_effects_per_model * self.n_models
\ No newline at end of file
+ return self.n_fixed_effects_per_model * self.n_models
diff --git a/src/dalia/prior_hyperparameters/__init__.py b/src/dalia/prior_hyperparameters/__init__.py
index b561764e..274d58cc 100644
--- a/src/dalia/prior_hyperparameters/__init__.py
+++ b/src/dalia/prior_hyperparameters/__init__.py
@@ -6,10 +6,14 @@
PenalizedComplexityPriorHyperparameters,
)
from dalia.prior_hyperparameters.beta import BetaPriorHyperparameters
+from dalia.prior_hyperparameters.gamma import GammaPriorHyperparameters
+from dalia.prior_hyperparameters.inverse_gamma import InverseGammaPriorHyperparameters
__all__ = [
"GaussianPriorHyperparameters",
"GaussianMVNPriorHyperparameters",
"PenalizedComplexityPriorHyperparameters",
"BetaPriorHyperparameters",
+ "GammaPriorHyperparameters",
+ "InverseGammaPriorHyperparameters",
]
diff --git a/src/dalia/prior_hyperparameters/beta.py b/src/dalia/prior_hyperparameters/beta.py
index c21ca175..ad70a253 100644
--- a/src/dalia/prior_hyperparameters/beta.py
+++ b/src/dalia/prior_hyperparameters/beta.py
@@ -1,11 +1,12 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
import scipy.stats as stats
-
+from dalia import sp, xp
from dalia.configs.priorhyperparameters_config import (
GaussianPriorHyperparametersConfig,
)
from dalia.core.prior_hyperparameters import PriorHyperparameters
+from dalia.utils.link_functions import scaled_logit
class BetaPriorHyperparameters(PriorHyperparameters):
@@ -21,6 +22,24 @@ def __init__(
self.alpha: float = config.alpha
self.beta: float = config.beta
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+
+ ### TODO: on longer term make scaled_logit default but let it be configurable in config
+ ## beta prior is defined on [0,1], while BFGS works on (-inf, inf)
+ if direction == "forward":
+ theta_scaled = scaled_logit(theta, direction="forward")
+ elif direction == "backward":
+ theta_scaled = scaled_logit(theta, direction="backward")
+ elif direction == "forward_jacobian":
+ theta_scaled = scaled_logit(theta, direction="forward_jacobian")
+ elif direction == "backward_jacobian":
+ theta_scaled = scaled_logit(theta, direction="backward_jacobian")
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ return theta_scaled
+
def evaluate_log_prior(self, theta: float, **kwargs) -> float:
"""Evaluate the log prior hyperparameters."""
@@ -29,5 +48,15 @@ def evaluate_log_prior(self, theta: float, **kwargs) -> float:
"Beta distribution is defined on the interval [0, 1]. theta: {theta}"
)
- log_prior = stats.beta.logpdf(theta, self.alpha, self.beta)
+ log_beta = (
+ sp.special.gammaln(self.alpha)
+ + sp.special.gammaln(self.beta)
+ - sp.special.gammaln(self.alpha + self.beta)
+ )
+ log_prior = (
+ (self.alpha - 1) * xp.log(theta)
+ + (self.beta - 1) * xp.log(1 - theta)
+ - log_beta
+ )
+
return log_prior
diff --git a/src/dalia/prior_hyperparameters/gamma.py b/src/dalia/prior_hyperparameters/gamma.py
new file mode 100644
index 00000000..61dc966c
--- /dev/null
+++ b/src/dalia/prior_hyperparameters/gamma.py
@@ -0,0 +1,347 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+from dalia import NDArray
+from scipy.sparse import spmatrix
+from dalia import sp, xp
+
+import numpy as np
+
+from dalia.configs.priorhyperparameters_config import (
+ GammaPriorHyperparametersConfig,
+)
+from dalia.core.prior_hyperparameters import PriorHyperparameters
+
+
+class GammaPriorHyperparameters(PriorHyperparameters):
+ """Gamma prior hyperparameters.
+
+ p(theta) = (beta^alpha / Gamma(alpha)) * theta^(alpha - 1) * exp(-beta * theta)
+
+ and in log scale:
+ log p(theta) = alpha * log(beta) - log(Gamma(alpha)) + (alpha - 1) * log(theta) - beta * theta
+
+ where theta is typically a positive parameter such as a precision or rate.
+
+ Parameters
+ ----------
+ config : GammaPriorHyperparametersConfig
+ Configuration object containing alpha and beta parameters.
+
+ Attributes
+ ----------
+ alpha : float. alpha > 0
+ Shape parameter of the Gamma distribution.
+ beta : float. beta > 0
+ Rate parameter of the Gamma distribution.
+ normalizing_constant : float
+ Precomputed normalizing constant for log probability evaluation.
+ """
+
+ def __init__(
+ self,
+ config: GammaPriorHyperparametersConfig,
+ ) -> None:
+ """
+ Initialize the Gamma prior hyperparameters.
+
+ Parameters
+ ----------
+ config : GammaPriorHyperparametersConfig
+ Configuration containing alpha (shape) and beta (rate) parameters.
+
+ Raises
+ ------
+ ValueError
+ If alpha or beta are not positive.
+ """
+ super().__init__(config)
+
+ self.alpha: float = config.alpha
+ self.beta: float = config.beta
+
+ # Validate alpha and beta are positive
+ if self.alpha <= 0:
+ raise ValueError(f"Alpha must be positive, got {self.alpha}")
+ if self.beta <= 0:
+ raise ValueError(f"Beta must be positive, got {self.beta}")
+
+ self.normalizing_constant: float = self.alpha * xp.log(self.beta) - float(
+ sp.special.gammaln(self.alpha)
+ )
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """
+ Transform between external and internal parameter representations.
+
+ The Gamma distribution is defined for positive values, but optimization
+ often works better in unconstrained space. This method transforms
+ between theta (positive) and log(theta) (unconstrained).
+
+ Parameters
+ ----------
+ theta : float or NDArray
+ Parameter value(s) to transform.
+ direction : str
+ Transformation direction:
+ - "forward": theta -> log(theta) (external to internal)
+ - "backward": log(theta) -> theta (internal to external)
+
+ Returns
+ -------
+ float or NDArray
+ Transformed parameter value(s).
+
+ Raises
+ ------
+ ValueError
+ If direction is not "forward" or "backward".
+ """
+ if direction == "forward":
+ theta_scaled = xp.log(theta)
+ elif direction == "backward":
+ theta_scaled = xp.exp(theta)
+ elif direction == "forward_jacobian":
+ theta_scaled = 1 / theta # d(log(theta))/d(theta) = 1/theta
+ elif direction == "backward_jacobian":
+ theta_scaled = theta # d(exp(theta))/d(theta) = exp(theta) = theta
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ return theta_scaled
+
+ def evaluate_log_prior(self, theta: float, **kwargs) -> float:
+ """
+ Evaluate the log prior probability density.
+
+ Computes the log probability density of the Gamma distribution
+ at the given theta value in its external/user representation.
+
+ Parameters
+ ----------
+ theta : float
+ Parameter value at which to evaluate the log prior.
+ Must be positive (external/user representation).
+ **kwargs
+ Additional keyword arguments (unused).
+
+ Returns
+ -------
+ float
+ Log prior probability density at theta.
+
+ Notes
+ -----
+ The computation follows:
+ log p(θ) = C + (α - 1) * log(θ) - β * θ
+ where C is the normalizing constant, α is the shape parameter,
+ and β is the rate parameter.
+
+ Raises
+ ------
+ ValueError
+ If theta is not positive (implicitly through log computation).
+ """
+
+ log_prior = (
+ self.normalizing_constant
+ + (self.alpha - 1) * xp.log(theta)
+ - self.beta * theta
+ )
+
+ return log_prior
+
+if __name__ == "__main__":
+ """
+ Test Gaussian quadrature for functions using rescale_hyperparameters_to_internal()
+ from gamma prior hyperparameters.
+
+ We start with normally distributed random variables in internal space (unconstrained)
+ that get reparametrized to external space (positive) using the gamma prior's
+ rescaling function.
+ """
+
+ from dalia.utils.gaussian_quadrature import compute_variance_gauss_hermite
+
+ print("=" * 80)
+ print("Testing Gaussian Quadrature with Gamma Prior Rescaling")
+ print("=" * 80)
+
+ # Create a inverse gamma prior configuration
+ alpha_values = [1.0, 3.0, 5.0]
+ beta_values = [0.5, 1.0, 2.0]
+
+ for alpha, beta in zip(alpha_values, beta_values):
+ print(f"\nTesting alpha={alpha}, beta={beta}")
+ config = GammaPriorHyperparametersConfig(alpha=alpha, beta=beta)
+ gamma_prior = GammaPriorHyperparameters(config=config)
+
+ ## compare against scipy implementation
+ from scipy.stats import gamma
+
+ test_values = [0.1, 0.5, 1.0, 2.0, 5.0]
+ print("Comparing log prior evaluations with scipy.stats.gamma:")
+ for val in test_values:
+ logp_dalia = gamma_prior.evaluate_log_prior(val)
+ ## note: scipy's gamma takes scale = 1/beta
+ logp_scipy = gamma.logpdf(val, a=alpha, scale=1/beta)
+ print(f" θ = {val:4.1f}: DALIA logp = {logp_dalia:.6f}, "
+ f"scipy logp = {logp_scipy:.6f}, diff = {abs(logp_dalia - logp_scipy):.2e}")
+ if abs(logp_dalia - logp_scipy) > 1e-6:
+ raise ValueError("Log prior evaluation does not match scipy implementation.")
+
+ print()
+ print("All tests passed!")
+
+ # Create a gamma prior configuration
+ config = GammaPriorHyperparametersConfig(alpha=2.0, beta=1.0)
+ gamma_prior = GammaPriorHyperparameters(config=config)
+
+ # Define parameters for the normal distribution in internal space
+ # These represent log(theta) where theta > 0 is the gamma-distributed parameter
+ mean_internal = 0.5 # Mean of log(theta)
+ variance_internal = 0.25 # Variance of log(theta)
+
+ print(f"Internal space (log-scale) parameters:")
+ print(f" Mean: {mean_internal}")
+ print(f" Variance: {variance_internal}")
+ print(f" Standard deviation: {xp.sqrt(variance_internal)}")
+ print()
+
+ # Test 1: Compute statistics using Gaussian quadrature
+ print("1. Computing statistics using Gaussian quadrature:")
+
+ # Use the rescaling function as the transform
+ def transform_func(x, direction):
+ return gamma_prior.rescale_hyperparameters_to_internal(x, direction)
+
+ # Compute statistics using different numbers of quadrature points
+ for n_points in [10, 20, 30, 50]:
+ result = compute_variance_gauss_hermite(
+ mean_internal,
+ variance_internal,
+ transform_func,
+ n_points=n_points
+ )
+
+ print(f" n_points = {n_points:2d}: Mean = {result['mean']:.6f}, "
+ f"Std = {result['std']:.6f}, Var = {result['variance']:.6f}")
+
+ print()
+
+ # Test 2: Compare with analytical solution
+ print("2. Comparison with analytical log-normal distribution:")
+ print(" For log(Y) ~ N(μ, σ²), we have:")
+ print(" E[Y] = exp(μ + σ²/2)")
+ print(" Var[Y] = (exp(σ²) - 1) * exp(2μ + σ²)")
+
+ # Analytical moments for log-normal distribution
+ mu = mean_internal
+ sigma2 = variance_internal
+
+ analytical_mean = xp.exp(mu + sigma2/2)
+ analytical_variance = (xp.exp(sigma2) - 1) * xp.exp(2*mu + sigma2)
+ analytical_std = xp.sqrt(analytical_variance)
+
+ print(f" Analytical mean: {analytical_mean:.6f}")
+ print(f" Analytical std: {analytical_std:.6f}")
+ print(f" Analytical var: {analytical_variance:.6f}")
+ print()
+
+ # Compare with quadrature result (using 50 points)
+ quad_result = compute_variance_gauss_hermite(
+ mean_internal, variance_internal, transform_func, n_points=50
+ )
+
+ print("3. Comparison of quadrature vs analytical:")
+ print(f" Mean difference: {abs(quad_result['mean'] - analytical_mean):.2e}")
+ print(f" Std difference: {abs(quad_result['std'] - analytical_std):.2e}")
+ print(f" Var difference: {abs(quad_result['variance'] - analytical_variance):.2e}")
+
+ # Relative errors
+ mean_rel_error = abs(quad_result['mean'] - analytical_mean) / analytical_mean
+ std_rel_error = abs(quad_result['std'] - analytical_std) / analytical_std
+ var_rel_error = abs(quad_result['variance'] - analytical_variance) / analytical_variance
+
+ print(f" Mean rel. error: {mean_rel_error:.2e}")
+ print(f" Std rel. error: {std_rel_error:.2e}")
+ print(f" Var rel. error: {var_rel_error:.2e}")
+ print()
+
+ # Test 3: Test with different internal parameters
+ print("4. Testing with different internal parameters:")
+
+ test_cases = [
+ {"mean": 0.0, "var": 0.1, "name": "Small variance"},
+ {"mean": 1.0, "var": 0.5, "name": "Medium variance"},
+ {"mean": -0.5, "var": 1.0, "name": "Large variance"},
+ {"mean": 2.0, "var": 0.01, "name": "Large mean, small variance"}
+ ]
+
+ for case in test_cases:
+ mu_test = case["mean"]
+ var_test = case["var"]
+
+ # Quadrature result
+ quad_result = compute_variance_gauss_hermite(
+ mu_test, var_test, transform_func, n_points=30
+ )
+
+ # Analytical result
+ anal_mean = xp.exp(mu_test + var_test/2)
+ anal_var = (xp.exp(var_test) - 1) * xp.exp(2*mu_test + var_test)
+
+ rel_mean_error = abs(quad_result['mean'] - anal_mean) / anal_mean
+ rel_var_error = abs(quad_result['variance'] - anal_var) / anal_var
+
+ print(f" {case['name']:25s}: Mean rel. err = {rel_mean_error:.2e}, "
+ f"Var rel. err = {rel_var_error:.2e}")
+
+ print()
+
+ # Test 4: Test the rescaling function directions
+ print("5. Testing rescaling function directions:")
+
+ # Test some values
+ test_values = [0.1, 0.5, 1.0, 2.0, 5.0]
+
+ print(" Testing forward (external -> internal) and backward (internal -> external):")
+ for theta in test_values:
+ # Forward: theta -> log(theta)
+ log_theta = gamma_prior.rescale_hyperparameters_to_internal(theta, "forward")
+
+ # Backward: log(theta) -> theta
+ theta_recovered = gamma_prior.rescale_hyperparameters_to_internal(log_theta, "backward")
+
+ error = abs(theta - theta_recovered)
+ print(f" θ = {theta:4.1f} -> log(θ) = {log_theta:6.3f} -> θ = {theta_recovered:6.3f}, "
+ f"error = {error:.2e}")
+
+ print()
+
+ # Test 5: Convergence study
+ print("6. Convergence study (increasing number of quadrature points):")
+
+ mu_conv = 0.3
+ var_conv = 0.4
+ analytical_mean_conv = xp.exp(mu_conv + var_conv/2)
+
+ n_points_list = [5, 10, 15, 20, 25, 30, 40, 50, 75, 100]
+
+ print(" n_points | Mean | Rel. Error")
+ print(" -----------|-------------|------------")
+
+ for n in n_points_list:
+ result = compute_variance_gauss_hermite(
+ mu_conv, var_conv, transform_func, n_points=n
+ )
+ rel_error = abs(result['mean'] - analytical_mean_conv) / analytical_mean_conv
+
+ print(f" {n:8d} | {result['mean']:10.6f} | {rel_error:.3e}")
+
+ print()
+ print("=" * 80)
+ print("Test completed successfully!")
+ print("All tests show that Gaussian quadrature accurately approximates")
+ print("the moments of log-normal distributions obtained through gamma")
+ print("prior rescaling transformations.")
+ print("=" * 80)
+
diff --git a/src/dalia/prior_hyperparameters/gaussian.py b/src/dalia/prior_hyperparameters/gaussian.py
index d18215ed..f7e5b227 100644
--- a/src/dalia/prior_hyperparameters/gaussian.py
+++ b/src/dalia/prior_hyperparameters/gaussian.py
@@ -1,9 +1,9 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
-from dalia import NDArray
-from scipy.sparse import spmatrix
-
import numpy as np
+from scipy.sparse import spmatrix
+from dalia import sp, xp
+from dalia import NDArray
from dalia.configs.priorhyperparameters_config import (
GaussianPriorHyperparametersConfig,
)
@@ -11,19 +11,104 @@
class GaussianPriorHyperparameters(PriorHyperparameters):
- """Gaussian prior hyperparameters."""
+ """
+ Univariate Gaussian prior hyperparameters.
+
+ This class implements prior hyperparameters following a univariate normal
+ (Gaussian) distribution with specified mean and precision (inverse variance).
+
+ Parameters
+ ----------
+ config : GaussianPriorHyperparametersConfig
+ Configuration object containing mean and precision parameters.
+
+ Attributes
+ ----------
+ mean : float
+ Mean of the Gaussian distribution.
+ precision : float
+ Precision (inverse variance) of the Gaussian distribution.
+ normalizing_constant : float
+ Precomputed normalizing constant for log probability evaluation.
+ """
def __init__(
self,
config: GaussianPriorHyperparametersConfig,
) -> None:
- """Initializes the Gaussian prior hyperparameters."""
+ """
+ Initialize the Gaussian prior hyperparameters.
+
+ Parameters
+ ----------
+ config : GaussianPriorHyperparametersConfig
+ Configuration containing mean and precision parameters.
+
+ Raises
+ ------
+ ValueError
+ If the precision is not positive.
+ """
super().__init__(config)
self.mean: float = config.mean
self.precision: float = config.precision
+ # Validate precision is positive
+ if self.precision <= 0:
+ raise ValueError(f"Precision must be positive, got {self.precision}")
+
+ self.normalizing_constant = -0.5 * xp.log(2 * xp.pi) + 0.5 * xp.log(
+ self.precision
+ )
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """
+ Rescale hyperparameters between internal and external representations.
+
+ Parameters
+ ----------
+ theta : NDArray or float
+ Hyperparameter values to rescale.
+ direction : str
+ Direction of rescaling ('forward' or 'backward', i.e. from interpretable/user/external to internal or vice versa).
+
+ Returns
+ -------
+ NDArray or float
+ Rescaled hyperparameter values. In this case it is the identity, therefore unchanged.
+
+ Notes
+ -----
+ For Gaussian priors, the rescaling is the identity function since the internal and external representations are the same.
+ """
+ return super().rescale_hyperparameters_to_internal(theta, direction)
+
def evaluate_log_prior(self, theta: float, **kwargs) -> float:
- """Evaluate the log prior hyperparameters."""
+ """
+ Evaluate the log prior probability density.
+
+ Computes the log probability density of the univariate normal
+ distribution at the given theta value.
+
+ Parameters
+ ----------
+ theta : float
+ Parameter value at which to evaluate the log prior.
+ **kwargs
+ Additional keyword arguments (unused).
+
+ Returns
+ -------
+ float
+ Log prior probability density at theta.
- return -0.5 * self.precision * (theta - self.mean) ** 2
+ Notes
+ -----
+ The computation follows:
+ log p(θ) = C - 0.5 * τ * (θ - μ)²
+ where C is the normalizing constant, τ is precision, and μ is the mean.
+ """
+ return (
+ self.normalizing_constant - 0.5 * self.precision * (theta - self.mean) ** 2
+ )
diff --git a/src/dalia/prior_hyperparameters/gaussian_mvn.py b/src/dalia/prior_hyperparameters/gaussian_mvn.py
index 8310ef10..bbec0e6e 100644
--- a/src/dalia/prior_hyperparameters/gaussian_mvn.py
+++ b/src/dalia/prior_hyperparameters/gaussian_mvn.py
@@ -1,10 +1,8 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
-from dalia import NDArray
-from scipy.sparse import spmatrix
-
import numpy as np
-from dalia import sp, xp
+from scipy.sparse import spmatrix
+from dalia import NDArray, sp, xp
from dalia.configs.priorhyperparameters_config import (
GaussianMVNPriorHyperparametersConfig,
)
@@ -12,13 +10,44 @@
class GaussianMVNPriorHyperparameters(PriorHyperparameters):
- """Gaussian MVN prior hyperparameters."""
+ """
+ Gaussian multivariate normal (MVN) prior hyperparameters.
+
+ This class implements prior hyperparameters following a multivariate normal
+ distribution with specified mean and precision matrix.
+
+ Parameters
+ ----------
+ config : GaussianMVNPriorHyperparametersConfig
+ Configuration object containing mean and precision matrix.
+
+ Attributes
+ ----------
+ mean : NDArray
+ Mean vector of the multivariate normal distribution.
+ precision : spmatrix
+ Precision matrix (inverse covariance) of the distribution.
+ normalizing_constant : float
+ Precomputed normalizing constant for log probability evaluation.
+ """
def __init__(
self,
config: GaussianMVNPriorHyperparametersConfig,
) -> None:
- """Initializes the Gaussian MVN prior hyperparameters."""
+ """
+ Initialize the Gaussian MVN prior hyperparameters.
+
+ Parameters
+ ----------
+ config : GaussianMVNPriorHyperparametersConfig
+ Configuration containing mean vector and precision matrix.
+
+ Raises
+ ------
+ ValueError
+ If the precision matrix is not positive definite.
+ """
super().__init__(config)
self.mean: NDArray = config.mean
@@ -31,8 +60,58 @@ def __init__(
self.mean: NDArray = xp.asarray(self.mean)
self.precision: sp.sparse.spmatrix = sp.sparse.csc_matrix(self.precision)
- def evaluate_log_prior(self, theta: float, **kwargs) -> float:
- """Evaluate the log prior hyperparameters."""
+ sign, logabsdet = np.linalg.slogdet(self.precision.toarray())
+ if sign != 1:
+ raise ValueError("Precision matrix must be positive definite.")
+
+ self.normalizing_constant = (
+ -0.5 * self.mean.shape[0] * xp.log(2 * xp.pi) + 0.5 * logabsdet
+ )
+ print("Normalizing constant: ", self.normalizing_constant)
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """
+ Rescale hyperparameters between internal and external/user representations.
+
+ Parameters
+ ----------
+ theta : NDArray
+ Hyperparameter values to rescale.
+ direction : str
+ Direction of rescaling ('forward' or 'backward').
+
+ Returns
+ -------
+ NDArray
+ Rescaled hyperparameter values, which is the identity in this case, therefore unchanged.
+ """
+ return super().rescale_hyperparameters_to_internal(theta, direction)
+
+ def evaluate_log_prior(self, theta: NDArray, **kwargs) -> float:
+ """
+ Evaluate the log prior probability density.
+
+ Computes the log probability density of the multivariate normal
+ distribution at the given theta values.
+
+ Parameters
+ ----------
+ theta : NDArray
+ Parameter values at which to evaluate the log prior.
+ Must have the same shape as the mean vector.
+ **kwargs
+ Additional keyword arguments (unused).
+
+ Returns
+ -------
+ float
+ Log prior probability density at theta.
+
+ Raises
+ ------
+ ValueError
+ If theta and mean have incompatible shapes.
+ """
# TODO: add check in config or somewhere else that dim(theta) and dim(mean) match
if self.mean.shape != theta.shape:
@@ -41,7 +120,11 @@ def evaluate_log_prior(self, theta: float, **kwargs) -> float:
)
if isinstance(self.mean, float):
- return -0.5 * self.precision * (theta - self.mean) ** 2
+ return (
+ self.normalizing_constant
+ - 0.5 * self.precision * (theta - self.mean) ** 2
+ )
else:
- # neglect constant as the precision is fixed
- return -0.5 * (theta - self.mean).T @ self.precision @ (theta - self.mean)
+ return self.normalizing_constant - 0.5 * (
+ theta - self.mean
+ ).T @ self.precision @ (theta - self.mean)
diff --git a/src/dalia/prior_hyperparameters/inverse_gamma.py b/src/dalia/prior_hyperparameters/inverse_gamma.py
new file mode 100644
index 00000000..aa9bc403
--- /dev/null
+++ b/src/dalia/prior_hyperparameters/inverse_gamma.py
@@ -0,0 +1,192 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+from dalia import sp, xp
+
+import numpy as np
+
+from dalia.configs.priorhyperparameters_config import (
+ InverseGammaPriorHyperparametersConfig,
+)
+from dalia.core.prior_hyperparameters import PriorHyperparameters
+
+
+class InverseGammaPriorHyperparameters(PriorHyperparameters):
+ """Inverse Gamma prior hyperparameters.
+
+ p(theta) = (beta^alpha / Gamma(alpha)) * (1/theta)^(alpha + 1) * exp(-beta / theta)
+
+ and in log scale:
+ log p(theta) = alpha * log(beta) - log(Gamma(alpha)) - (alpha + 1) * log(theta) - beta / theta
+
+ where theta is typically a positive parameter such as a variance.
+
+ Parameters
+ ----------
+ config : InverseGammaPriorHyperparametersConfig
+ Configuration object containing alpha and beta parameters.
+
+ Attributes
+ ----------
+ alpha : float. alpha > 0
+ Shape parameter of the Gamma distribution.
+ beta : float. beta > 0
+ Rate parameter of the Gamma distribution.
+ normalizing_constant : float
+ Precomputed normalizing constant for log probability evaluation.
+ """
+
+ def __init__(
+ self,
+ config: InverseGammaPriorHyperparametersConfig,
+ ) -> None:
+ """
+ Initialize the Inverse Gamma prior hyperparameters.
+
+ Parameters
+ ----------
+ config : InverseGammaPriorHyperparametersConfig
+ Configuration containing alpha (shape) and beta (rate) parameters.
+
+ Raises
+ ------
+ ValueError
+ If alpha or beta are not positive.
+ """
+ super().__init__(config)
+
+ self.alpha: float = config.alpha
+ self.beta: float = config.beta
+
+ # Validate alpha and beta are positive
+ if self.alpha <= 0:
+ raise ValueError(f"Alpha must be positive, got {self.alpha}")
+ if self.beta <= 0:
+ raise ValueError(f"Beta must be positive, got {self.beta}")
+
+ self.normalizing_constant: float = self.alpha * xp.log(self.beta) - float(
+ sp.special.gammaln(self.alpha)
+ )
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """
+ Transform between external and internal parameter representations.
+
+ The Inverse Gamma distribution is defined for positive values, but the optimization
+ happens in an unconstrained space. This method transforms
+ between theta (positive) and log(theta) (unconstrained).
+
+ Parameters
+ ----------
+ theta : float or NDArray
+ Parameter value(s) to transform.
+ direction : str
+ Transformation direction:
+ - "forward": theta -> log(theta) (external to internal)
+ - "backward": log(theta) -> theta (internal to external)
+ - "forward jacobian": derivative of forward transformation
+ - "backward jacobian": 1 / derivative of forward transformation
+
+ Returns
+ -------
+ float or NDArray
+ Transformed parameter value(s).
+
+ Raises
+ ------
+ ValueError
+ If direction is not "forward" or "backward".
+ """
+ if direction == "forward":
+ theta_scaled = xp.log(theta)
+ elif direction == "backward":
+ theta_scaled = xp.exp(theta)
+ elif direction == "forward_jacobian":
+ theta_scaled = 1 / theta # d(log(theta))/d(theta) = 1/theta
+ elif direction == "backward_jacobian":
+ theta_scaled = theta
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ return theta_scaled
+
+ def evaluate_log_prior(self, theta: float, **kwargs) -> float:
+ """
+ Evaluate the log prior probability density.
+
+ Computes the log probability density of the Inverse Gamma distribution
+ at the given theta value in its external/user representation.
+
+ Parameters
+ ----------
+ theta : float
+ Parameter value at which to evaluate the log prior.
+ Must be positive (external/user representation).
+ **kwargs
+ Additional keyword arguments (unused).
+
+ Returns
+ -------
+ float
+ Log prior probability density at theta.
+
+ Notes
+ -----
+ The computation follows:
+ log p(θ) = C - (α + 1) * log(θ) - β / θ
+ where C is the normalizing constant, α is the shape parameter,
+ and β is the rate parameter.
+
+ Raises
+ ------
+ ValueError
+ If theta is not positive (implicitly through log computation).
+ """
+
+ log_prior = (
+ self.normalizing_constant
+ - (self.alpha + 1) * xp.log(theta)
+ - self.beta / theta
+ )
+
+ return log_prior
+
+if __name__ == "__main__":
+ """
+ Test Gaussian quadrature for functions using rescale_hyperparameters_to_internal()
+ from gamma prior hyperparameters.
+
+ We start with normally distributed random variables in internal space (unconstrained)
+ that get reparametrized to external space (positive) using the gamma prior's
+ rescaling function.
+ """
+
+ from dalia.utils.gaussian_quadrature import compute_variance_gauss_hermite
+
+ print("=" * 80)
+ print("Testing Gaussian Quadrature with Inverse Gamma Prior Rescaling")
+ print("=" * 80)
+
+ # Create a inverse gamma prior configuration
+ alpha_values = [1.0, 3.0, 5.0]
+ beta_values = [0.5, 1.0, 2.0]
+
+ for alpha, beta in zip(alpha_values, beta_values):
+ print(f"\nTesting alpha={alpha}, beta={beta}")
+ config = InverseGammaPriorHyperparametersConfig(alpha=alpha, beta=beta)
+ inverse_gamma_prior = InverseGammaPriorHyperparameters(config=config)
+
+ ## compare against scipy implementation
+ from scipy.stats import invgamma
+
+ test_values = [0.1, 0.5, 1.0, 2.0, 5.0]
+ print("Comparing log prior evaluations with scipy.stats.invgamma:")
+ for val in test_values:
+ logp_dalia = inverse_gamma_prior.evaluate_log_prior(val)
+ logp_scipy = invgamma.logpdf(val, a=alpha, scale=beta)
+ print(f" θ = {val:4.1f}: DALIA logp = {logp_dalia:.6f}, "
+ f"scipy logp = {logp_scipy:.6f}, diff = {abs(logp_dalia - logp_scipy):.2e}")
+ if abs(logp_dalia - logp_scipy) > 1e-6:
+ raise ValueError("Log prior evaluation does not match scipy implementation.")
+
+ print()
+ print("All tests passed!")
+
diff --git a/src/dalia/prior_hyperparameters/penalized_complexity.py b/src/dalia/prior_hyperparameters/penalized_complexity.py
index 2c9788d2..158562d0 100644
--- a/src/dalia/prior_hyperparameters/penalized_complexity.py
+++ b/src/dalia/prior_hyperparameters/penalized_complexity.py
@@ -44,6 +44,13 @@ def __init__(
# print("lambda_theta: ", self.lambda_theta)
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """Rescale hyperparameters to and from internal scale.
+
+ TODO: Implement the re-scaling.
+ """
+ return super().rescale_hyperparameters_to_internal(theta, direction)
+
def evaluate_log_prior(self, theta: float, **kwargs) -> float:
"""Evaluate the prior hyperparameters."""
log_prior: float = 0.0
diff --git a/src/dalia/solvers/__init__.py b/src/dalia/solvers/__init__.py
index 16cd72ca..69ed201e 100644
--- a/src/dalia/solvers/__init__.py
+++ b/src/dalia/solvers/__init__.py
@@ -1,8 +1,8 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
from dalia.solvers.dense_solver import DenseSolver
+from dalia.solvers.distributed_structured_solver import DistSerinvSolver
from dalia.solvers.sparse_solver import SparseSolver
from dalia.solvers.structured_solver import SerinvSolver
-from dalia.solvers.distributed_structured_solver import DistSerinvSolver
__all__ = ["DenseSolver", "SparseSolver", "SerinvSolver", "DistSerinvSolver"]
diff --git a/src/dalia/solvers/dense_solver.py b/src/dalia/solvers/dense_solver.py
index b93a920a..deaaeb67 100644
--- a/src/dalia/solvers/dense_solver.py
+++ b/src/dalia/solvers/dense_solver.py
@@ -1,8 +1,19 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
+import time
+
from dalia import NDArray, sp, xp
from dalia.configs.dalia_config import SolverConfig
from dalia.core.solver import Solver
+from dalia.utils import synchronize_gpu
+
+
+## check if sparse matrix is diagonal
+def is_diagonal(A: NDArray) -> bool:
+ """Check if a matrix is diagonal."""
+
+ coo = A.tocoo()
+ return xp.all(coo.row == coo.col)
class DenseSolver(Solver):
@@ -17,8 +28,6 @@ def __init__(
----------
config : SolverConfig
Configuration object for the solver.
- n : int
- Size of the matrix.
Returns
-------
@@ -32,35 +41,103 @@ def __init__(
self.L: NDArray = xp.zeros((self.n, self.n), dtype=xp.float64)
self.A_inv = None
- def cholesky(self, A: NDArray, **kwargs) -> None:
- self.L[:] = A.todense()
+ # Solver Metrics
+ self.t_factorize = 0.0
+ self.t_solve = 0.0
+
+ def factorize(self, A: NDArray, **kwargs) -> None:
+ """Compute the Cholesky decomposition of a matrix.
+
+ Parameters
+ ----------
+ A : NDArray
+ The input matrix to decompose.
+
+ Returns
+ -------
+ None
+
+ Note:
+ -----
+ Uses the Cholesky decomposition.
+ """
+ synchronize_gpu()
+ tic = time.perf_counter()
+
+ if sp.sparse.issparse(A):
+ # if A is diagonal, we can use the diagonal directly
+ if is_diagonal(A):
+ self.L[:] = 0
+ self.L[xp.arange(self.n), xp.arange(self.n)] = xp.sqrt(A.diagonal())
+ return
+
+ else:
+ self.L[:] = A.todense()
+ else:
+ self.L[:] = A
self.L = xp.linalg.cholesky(self.L)
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_factorize += toc - tic
+
def solve(
self,
rhs: NDArray,
**kwargs,
) -> NDArray:
- rhs[:] = sp.linalg.solve_triangular(self.L, rhs, lower=True, overwrite_b=True)
+ """Solve linear system using Cholesky factor.
+
+ Parameters
+ ----------
+ rhs : NDArray
+ Right-hand side of the linear system.
+
+ Returns
+ -------
+ NDArray
+ Solution of the linear system.
+ """
+ synchronize_gpu()
+ tic = time.perf_counter()
+
rhs[:] = sp.linalg.solve_triangular(
- self.L.T, rhs, lower=False, overwrite_b=True
+ self.L,
+ rhs,
+ lower=True,
+ )
+ rhs[:] = sp.linalg.solve_triangular(
+ self.L,
+ rhs,
+ trans="T",
+ lower=True,
)
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_solve += toc - tic
+
return rhs
def logdet(
self,
**kwargs,
) -> float:
+ """Compute the log determinant of the matrix.
+
+ Returns
+ -------
+ float
+ The log determinant of the matrix.
+ """
return 2 * xp.sum(xp.log(xp.diag(self.L)))
- # TODO: optimize for memory??
def selected_inversion(self, **kwargs) -> None:
-
- L_inv = xp.eye(self.L.shape[0])
- L_inv[:] = sp.linalg.solve_triangular(
- self.L, L_inv, lower=True, overwrite_b=True
+ L_inv = sp.linalg.solve_triangular(
+ self.L,
+ xp.eye(self.L.shape[0]),
+ lower=True,
)
self.A_inv = L_inv.T @ L_inv
@@ -76,4 +153,4 @@ def get_solver_memory(self) -> int:
"""Return the memory used by the solver in number of bytes."""
solver_mem = 2 * self.n * self.n * xp.dtype(xp.float64).itemsize
- return solver_mem
\ No newline at end of file
+ return solver_mem
diff --git a/src/dalia/solvers/distributed_structured_solver.py b/src/dalia/solvers/distributed_structured_solver.py
index 1b9fa8a5..b9f3f4d9 100644
--- a/src/dalia/solvers/distributed_structured_solver.py
+++ b/src/dalia/solvers/distributed_structured_solver.py
@@ -1,8 +1,7 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
-from warnings import warn
-
import time
+from warnings import warn
from dalia import NDArray, backend_flags, sp, xp, xp_host
from dalia.configs.dalia_config import SolverConfig
@@ -20,6 +19,7 @@
from serinv.utils import allocate_pobtax_permutation_buffers
from serinv.wrappers import (
allocate_pobtars,
+ allocate_pobtrs,
ppobtaf,
ppobtas,
ppobtasi,
@@ -158,17 +158,29 @@ def __init__(
self.buffer = allocate_pobtax_permutation_buffers(
self.A_diagonal_blocks,
)
- self.pobtars: dict = allocate_pobtars(
- A_diagonal_blocks=self.A_diagonal_blocks,
- A_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
- A_lower_arrow_blocks=self.A_arrow_bottom_blocks,
- A_arrow_tip_block=self.A_arrow_tip_block,
- B=self.dist_rhs,
- comm=self.comm,
- array_module=xp.__name__,
- strategy="allgather",
- nccl_comm=self.nccl_comm,
- )
+
+ if self.arrowhead_blocksize > 0:
+ self.reduced_system: dict = allocate_pobtars(
+ A_diagonal_blocks=self.A_diagonal_blocks,
+ A_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
+ A_lower_arrow_blocks=self.A_arrow_bottom_blocks,
+ A_arrow_tip_block=self.A_arrow_tip_block,
+ B=self.dist_rhs,
+ comm=self.comm,
+ array_module=xp.__name__,
+ strategy="allgather",
+ nccl_comm=self.nccl_comm,
+ )
+ else:
+ self.reduced_system: dict = allocate_pobtrs(
+ A_diagonal_blocks=self.A_diagonal_blocks,
+ A_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
+ B=self.dist_rhs,
+ comm=self.comm,
+ array_module=xp.__name__,
+ strategy="allgather",
+ nccl_comm=self.nccl_comm,
+ )
# Initialize the caching strategy
self.bta_cache_block_sort_index = None
@@ -177,21 +189,38 @@ def __init__(
# Solver Metrics
self.total_bytes: int = 0
- self.t_cholesky = 0.0
+ self.t_factorize = 0.0
self.t_solve = 0.0
- def cholesky(
+ def factorize(
self,
A: sp.sparse.spmatrix,
sparsity: str,
) -> None:
- """Compute Cholesky factor of input matrix."""
+ """Compute the decomposition of a matrix.
+
+ Parameters
+ ----------
+ A : sp.sparse.spmatrix
+ The input matrix to decompose.
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+
+ Returns
+ -------
+ None
+
+ Note:
+ -----
+ Uses the Cholesky decomposition.
+ """
+ synchronize(comm=self.comm)
+ tic = time.perf_counter()
+
# Reset the tip block for reccurrent calls
# print(f"WorldRank {self.rank} ENTERING {sparsity} cholesky.", flush=True)
self._spmatrix_to_structured(A, sparsity)
- tic = time.perf_counter()
- synchronize(comm=self.comm)
if sparsity == "bta":
ppobtaf(
self.A_diagonal_blocks,
@@ -199,7 +228,7 @@ def cholesky(
self.A_arrow_bottom_blocks,
self.A_arrow_tip_block,
buffer=self.buffer,
- pobtars=self.pobtars,
+ pobtars=self.reduced_system,
comm=self.comm,
strategy="allgather",
nccl_comm=self.nccl_comm,
@@ -209,7 +238,7 @@ def cholesky(
self.A_diagonal_blocks,
self.A_lower_diagonal_blocks,
buffer=self.buffer,
- pobtrs=self.pobtars,
+ pobtrs=self.reduced_system,
comm=self.comm,
strategy="allgather",
nccl_comm=self.nccl_comm,
@@ -218,53 +247,96 @@ def cholesky(
raise ValueError(
f"Unknown sparsity pattern: {sparsity}. Use 'bt' or 'bta'."
)
+
synchronize(comm=self.comm)
toc = time.perf_counter()
- self.t_cholesky += toc - tic
+ self.t_factorize += toc - tic
def solve(
self,
rhs: NDArray,
sparsity: str,
) -> NDArray:
- """Solve linear system using Cholesky factor."""
- self._slice_rhs(rhs, sparsity)
-
- tic = time.perf_counter()
+ """Solve linear system using Cholesky factor.
+
+ Parameters
+ ----------
+ rhs : NDArray
+ Right-hand side of the linear system.
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+
+ Returns
+ -------
+ NDArray
+ Solution of the linear system.
+
+ Raises
+ ------
+ ValueError
+ If the sparsity pattern is unknown.
+ """
synchronize(comm=self.comm)
- if sparsity == "bta":
- ppobtas(
- L_diagonal_blocks=self.A_diagonal_blocks,
- L_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
- L_lower_arrow_blocks=self.A_arrow_bottom_blocks,
- L_arrow_tip_block=self.A_arrow_tip_block,
- B=self.dist_rhs,
- buffer=self.buffer,
- pobtars=self.pobtars,
- comm=self.comm,
- strategy="allgather",
- nccl_comm=self.nccl_comm,
- )
- elif sparsity == "bt":
- ppobts(
- L_diagonal_blocks=self.A_diagonal_blocks,
- L_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
- B=self.dist_rhs[: -self.arrowhead_blocksize],
- buffer=self.buffer,
- pobtars=self.pobtars,
- comm=self.comm,
- strategy="allgather",
- nccl_comm=self.nccl_comm,
- )
- else:
- raise ValueError(
- f"Unknown sparsity pattern: {sparsity}. Use 'bt' or 'bta'."
- )
+ tic = time.perf_counter()
+
+ # Store the original shape of rhs to reshape the solution back after solving
+ in_rhs_shape = rhs.shape
+
+ # Ensure rhs is a 2D array
+ if rhs.ndim == 1:
+ rhs = rhs[:, None]
+
+ # Handle multiple RHS by processing each column separately
+ for col in range(rhs.shape[1]):
+ rhs_col = rhs[:, col : col + 1] # Keep 2D shape with single column
+
+ self._slice_rhs(rhs_col, sparsity)
+
+ if sparsity == "bta":
+ ppobtas(
+ L_diagonal_blocks=self.A_diagonal_blocks,
+ L_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
+ L_lower_arrow_blocks=self.A_arrow_bottom_blocks,
+ L_arrow_tip_block=self.A_arrow_tip_block,
+ B=self.dist_rhs,
+ buffer=self.buffer,
+ pobtars=self.reduced_system,
+ comm=self.comm,
+ strategy="allgather",
+ nccl_comm=self.nccl_comm,
+ )
+ elif sparsity == "bt":
+ ppobts(
+ L_diagonal_blocks=self.A_diagonal_blocks,
+ L_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
+ B=(
+ self.dist_rhs[: -self.arrowhead_blocksize]
+ if self.arrowhead_blocksize > 0
+ else self.dist_rhs
+ ),
+ buffer=self.buffer,
+ pobtrs=self.reduced_system,
+ comm=self.comm,
+ strategy="allgather",
+ nccl_comm=self.nccl_comm,
+ )
+ else:
+ raise ValueError(
+ f"Unknown sparsity pattern: {sparsity}. Use 'bt' or 'bta'."
+ )
+
+ self._gather_rhs(rhs_col, sparsity)
+
+ # Copy the solved column back to the original rhs
+ rhs[:, col : col + 1] = rhs_col
+
synchronize(comm=self.comm)
toc = time.perf_counter()
self.t_solve += toc - tic
- self._gather_rhs(rhs, sparsity)
+ # Reshape the solution back to the original shape if needed
+ if in_rhs_shape != rhs.shape:
+ rhs = rhs.reshape(in_rhs_shape)
return rhs
@@ -272,7 +344,18 @@ def logdet(
self,
sparsity: str,
) -> float:
- """Compute logdet of input matrix using Cholesky factor."""
+ """Compute the log determinant of the matrix.
+
+ Parameters
+ ----------
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+
+ Returns
+ -------
+ float
+ The log determinant of the matrix.
+ """
logdet = xp.array(0.0, dtype=xp.float64)
if self.rank == 0:
@@ -282,14 +365,16 @@ def logdet(
# Rank 0 do the reduced system; The loop start from 1 because of the
# AllGather strategy and the size of the reduced system associated.
- _n = self.pobtars["A_diagonal_blocks"].shape[0]
+ _n = self.reduced_system["A_diagonal_blocks"].shape[0]
for i in range(1, _n):
logdet += xp.sum(
- xp.log(self.pobtars["A_diagonal_blocks"][i].diagonal())
+ xp.log(self.reduced_system["A_diagonal_blocks"][i].diagonal())
)
if sparsity == "bta":
- logdet += xp.sum(xp.log(self.pobtars["A_arrow_tip_block"].diagonal()))
+ logdet += xp.sum(
+ xp.log(self.reduced_system["A_arrow_tip_block"].diagonal())
+ )
else:
for i in range(1, self.n_locals[self.rank] - 1):
logdet += xp.sum(xp.log(self.A_diagonal_blocks[i].diagonal()))
@@ -303,10 +388,9 @@ def logdet(
synchronize(comm=self.comm)
if xp.isnan(logdet):
- print(
+ raise ValueError(
f"WorldRank {MPI.COMM_WORLD.rank} logdet is NaN for {sparsity} matrix."
)
- exit()
return 2 * logdet
@@ -322,7 +406,7 @@ def selected_inversion(
L_lower_arrow_blocks=self.A_arrow_bottom_blocks,
L_arrow_tip_block=self.A_arrow_tip_block,
buffer=self.buffer,
- pobtars=self.pobtars,
+ pobtars=self.reduced_system,
comm=self.comm,
strategy="allgather",
nccl_comm=self.nccl_comm,
@@ -332,7 +416,7 @@ def selected_inversion(
L_diagonal_blocks=self.A_diagonal_blocks,
L_lower_diagonal_blocks=self.A_lower_diagonal_blocks,
buffer=self.buffer,
- pobtrs=self.pobtars,
+ pobtrs=self.reduced_system,
comm=self.comm,
strategy="allgather",
nccl_comm=self.nccl_comm,
@@ -352,8 +436,10 @@ def _spmatrix_to_structured(
"""Map sp.spmatrix to BT or BTA."""
self.A_diagonal_blocks[:] = 0.0
self.A_lower_diagonal_blocks[:] = 0.0
- self.A_arrow_bottom_blocks[:] = 0.0
- self.A_arrow_tip_block[:] = 0.0
+ if self.A_arrow_bottom_blocks is not None:
+ self.A_arrow_bottom_blocks[:] = 0.0
+ if self.A_arrow_tip_block is not None:
+ self.A_arrow_tip_block[:] = 0.0
if xp.__name__ == "cupy" and sparsity == "bta":
if sparsity == "bta":
@@ -408,9 +494,9 @@ def _spmatrix_to_structured(
block_slice = A_csc[
-self.arrowhead_blocksize :, -self.arrowhead_blocksize :
].tocoo()
- self.A_arrow_tip_block[
- block_slice.row, block_slice.col
- ] = block_slice.data
+ self.A_arrow_tip_block[block_slice.row, block_slice.col] = (
+ block_slice.data
+ )
def _spmatrix_to_bta(
self,
@@ -646,6 +732,7 @@ def _structured_to_spmatrix(
self,
A: sp.sparse.spmatrix,
sparsity: str,
+ symmetrize: bool = True,
) -> sp.sparse.spmatrix:
"""Map BT or BTA matrix to sp.spmatrix using sparsity pattern provided in A."""
# A is assumed to be symmetric, only use lower triangular part
@@ -717,18 +804,19 @@ def _structured_to_spmatrix(
rows = xp.concatenate(rows)
cols = xp.concatenate(cols)
- # TODO: Need to communicate to agregates/Map the local B matrix to all ranks
+ # Need to communicate to agregates/Map the local B matrix to all ranks
# Need to operate on the datas
- l_data = allgather(data, comm=self.comm)
- l_rows = allgather(rows, comm=self.comm)
- l_cols = allgather(cols, comm=self.comm)
+ l_data = xp.concatenate(allgather(data, comm=self.comm))
+ l_rows = xp.concatenate(allgather(rows, comm=self.comm))
+ l_cols = xp.concatenate(allgather(cols, comm=self.comm))
synchronize(comm=self.comm)
- B_out = sp.sparse.coo_matrix((l_data, (l_rows, l_cols)), shape=B.shape).tocsc()
- # Symmetrize B
- B_out = B_out + sp.sparse.tril(B_out, k=-1).T
+ B_out = sp.sparse.coo_matrix((l_data, (l_rows, l_cols)), shape=B.shape).tocsc()
- return B_out
+ if symmetrize:
+ return B_out + sp.sparse.tril(B_out, k=-1).T
+ else:
+ return B_out
def _slice_rhs(
self,
@@ -740,14 +828,16 @@ def _slice_rhs(
start_idx = int(xp.cumsum(n_idx)[self.rank])
end_idx = int(xp.cumsum(n_idx)[self.rank + 1])
- # Ensure rhs is a 2D array with shape (n, 1)
- if rhs.ndim == 1:
- rhs = rhs[:, None]
-
+ # rhs is guaranteed to be 2D with shape (n, 1) at this point
# print(f"Rank {self.rank} rhs.shape: {rhs.shape}, self.dist_rhs.shape: {self.dist_rhs.shape}, start_idx: {start_idx* self.diagonal_blocksize}, end_idx: {end_idx* self.diagonal_blocksize}")
- self.dist_rhs[: -self.arrowhead_blocksize] = rhs[
- start_idx * self.diagonal_blocksize : end_idx * self.diagonal_blocksize
- ]
+ if self.arrowhead_blocksize > 0:
+ self.dist_rhs[: -self.arrowhead_blocksize] = rhs[
+ start_idx * self.diagonal_blocksize : end_idx * self.diagonal_blocksize
+ ]
+ else:
+ self.dist_rhs[:] = rhs[
+ start_idx * self.diagonal_blocksize : end_idx * self.diagonal_blocksize
+ ]
if sparsity == "bta":
self.dist_rhs[-self.arrowhead_blocksize :] = rhs[
-self.arrowhead_blocksize :, :
@@ -769,15 +859,27 @@ def _gather_rhs(
and not backend_flags["mpi_cuda_aware"]
and not backend_flags["nccl_avail"]
):
- self.dist_rhs[: -self.arrowhead_blocksize].flatten().get(
- out=self.send_rhs[
- self.remainders[self.rank] * self.diagonal_blocksize :
- ]
- )
+ if self.arrowhead_blocksize > 0:
+ self.dist_rhs[: -self.arrowhead_blocksize].flatten().get(
+ out=self.send_rhs[
+ self.remainders[self.rank] * self.diagonal_blocksize :
+ ]
+ )
+ else:
+ self.dist_rhs.flatten().get(
+ out=self.send_rhs[
+ self.remainders[self.rank] * self.diagonal_blocksize :
+ ]
+ )
else:
- self.send_rhs[
- self.remainders[self.rank] * self.diagonal_blocksize :
- ] = self.dist_rhs[: -self.arrowhead_blocksize].flatten()
+ if self.arrowhead_blocksize > 0:
+ self.send_rhs[
+ self.remainders[self.rank] * self.diagonal_blocksize :
+ ] = self.dist_rhs[: -self.arrowhead_blocksize].flatten()
+ else:
+ self.send_rhs[
+ self.remainders[self.rank] * self.diagonal_blocksize :
+ ] = self.dist_rhs.flatten()
synchronize(comm=self.comm)
self.comm.Allgather(
@@ -786,6 +888,9 @@ def _gather_rhs(
)
synchronize(comm=self.comm)
+ # This part needs a major re-work as the multiple RHS are handled one column at a time
+ # but Serinv can handle multiple RHS in one go. The problem is linked to DALIA internal
+ # representation (2nd dimension of rhs when only 1 rhs) and the collectives operations.
if (
backend_flags["array_module"] == "cupy"
and not backend_flags["mpi_cuda_aware"]
@@ -795,18 +900,29 @@ def _gather_rhs(
start_idx = int(xp.cumsum(n_idx)[i])
end_idx = int(xp.cumsum(n_idx)[i + 1])
+ # Extract data from pinned host memory
+ host_data = self.recv_rhs[
+ (i * self.max_n_locals + self.remainders[i])
+ * self.diagonal_blocksize : (i + 1)
+ * self.max_n_locals
+ * self.diagonal_blocksize
+ ]
+
+ # Since we're processing one column at a time, rhs.shape[1] is always 1
+ # Reshape the 1D host data to 2D (n_rows, 1)
+ n_rows = (
+ end_idx * self.diagonal_blocksize
+ - start_idx * self.diagonal_blocksize
+ )
+ host_data_2d = host_data[:n_rows].reshape(-1, 1)
+
+ # Transfer to device and assign
+ # Here we can't use `.set()` because multidimmensional rhs is not contiguous array.
rhs[
start_idx
* self.diagonal_blocksize : end_idx
* self.diagonal_blocksize
- ].set(
- arr=self.recv_rhs[
- (i * self.max_n_locals + self.remainders[i])
- * self.diagonal_blocksize : (i + 1)
- * self.max_n_locals
- * self.diagonal_blocksize
- ]
- )
+ ] = xp.asarray(host_data_2d)
else:
for i in range(self.comm_size):
start_idx = int(xp.cumsum(n_idx)[i])
@@ -821,22 +937,24 @@ def _gather_rhs(
* self.diagonal_blocksize : (i + 1)
* self.max_n_locals
* self.diagonal_blocksize
- ]
+ ].reshape(
+ -1, rhs.shape[1]
+ )
if sparsity == "bta":
# Map the arrow-tip of self.dist_rhs to the global rhs
rhs[-self.arrowhead_blocksize :] = self.dist_rhs[
-self.arrowhead_blocksize :
- ].flatten()
+ ]
def get_solver_memory(self) -> int:
"""Return the memory used by the solver in number of bytes"""
bytes_pobtars: int = (
self.buffer.nbytes
- + self.pobtars["A_diagonal_blocks"].nbytes
- + self.pobtars["A_lower_diagonal_blocks"].nbytes
- + self.pobtars["A_lower_arrow_blocks"].nbytes
- + self.pobtars["A_arrow_tip_block"].nbytes
- + self.pobtars["B"].nbytes
+ + self.reduced_system["A_diagonal_blocks"].nbytes
+ + self.reduced_system["A_lower_diagonal_blocks"].nbytes
+ + self.reduced_system["A_lower_arrow_blocks"].nbytes
+ + self.reduced_system["A_arrow_tip_block"].nbytes
+ + self.reduced_system["B"].nbytes
)
bytes_local_system: int = (
self.A_diagonal_blocks.nbytes
@@ -846,4 +964,4 @@ def get_solver_memory(self) -> int:
)
self.total_bytes += bytes_pobtars + bytes_local_system
- return self.total_bytes
\ No newline at end of file
+ return self.total_bytes
diff --git a/src/dalia/solvers/sparse_solver.py b/src/dalia/solvers/sparse_solver.py
index 7403cb60..8fd9a802 100644
--- a/src/dalia/solvers/sparse_solver.py
+++ b/src/dalia/solvers/sparse_solver.py
@@ -1,8 +1,18 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
+import time
+
from dalia import NDArray, sp, xp
from dalia.configs.dalia_config import SolverConfig
from dalia.core.solver import Solver
+from dalia.utils import synchronize_gpu
+
+# This is a workaround a problem in cupyx, where linalg is not properly namespaced (directly accessible).
+# May be removed in future versions of cupy (tested on cupy 13.4.1).
+if xp.__name__ == "cupy":
+ from cupyx.scipy.sparse.linalg import splu
+else:
+ from scipy.sparse.linalg import splu
class SparseSolver(Solver):
@@ -14,55 +24,253 @@ def __init__(
"""Initializes the solver."""
super().__init__(config)
- self.L: sp.sparse.spmatrix = None
+ self.LU_factor = None # Store the LU factorization object
+ self.A_inv = None # Store the inverse of A if needed
+
+ # Solver Metrics
+ self.t_factorize = 0.0
+ self.t_solve = 0.0
+
+ def factorize(self, A: sp.sparse.spmatrix, **kwargs) -> None:
+ """Compute the decomposition of a matrix.
+
+ Note: This uses LU decomposition since sparse Cholesky is not readily available.
+
- def cholesky(self, A: sp.sparse.spmatrix, **kwargs) -> None:
- """Compute Cholesky factor of input matrix."""
+ Parameters
+ ----------
+ A : sp.sparse.spmatrix
+ The input matrix to decompose.
+
+ Returns
+ -------
+ None
+
+ Note:
+ -----
+ Uses the LU decomposition by default as scipy.sparse doesn't implement Cholesky.
+ """
+ synchronize_gpu()
+ tic = time.perf_counter()
A = sp.sparse.csc_matrix(A)
- LU = sp.sparse.linalg.splu(A, diag_pivot_thresh=0, permc_spec="NATURAL")
+ # Use LU decomposition as the factorization method
+ self.LU_factor = splu(A, diag_pivot_thresh=0, permc_spec="NATURAL")
- if (LU.U.diagonal() > 0).all(): # Check the matrix A is positive definite.
- self.L = LU.L.dot(sp.sparse.diags(LU.U.diagonal() ** 0.5))
- else:
- raise ValueError("The matrix is not positive definite")
+ # Check if the matrix appears to be positive definite
+ if not (self.LU_factor.U.diagonal() > 0).all():
+ raise ValueError("The matrix does not appear to be positive definite")
+
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_factorize += toc - tic
def solve(
self,
rhs: NDArray,
**kwargs,
) -> NDArray:
- """Solve linear system using Cholesky factor."""
+ """Solve linear system using LU factorization.
- if self.L is None:
- raise ValueError("Cholesky factor not computed")
+ Parameters
+ ----------
+ rhs : NDArray
+ Right-hand side of the linear system.
- sp.sparse.linalg.spsolve_triangular(self.L, rhs, lower=True, overwrite_b=True)
- sp.sparse.linalg.spsolve_triangular(
- self.L.T, rhs, lower=False, overwrite_b=True
- )
+ Returns
+ -------
+ NDArray
+ Solution of the linear system.
+ """
+ synchronize_gpu()
+ tic = time.perf_counter()
- return rhs
+ if self.LU_factor is None:
+ raise ValueError("Matrix factorization not computed")
+
+ # Handle multiple RHS cases
+ if rhs.ndim == 1:
+ # Single RHS as 1D array
+ x = self.LU_factor.solve(rhs)
+ elif rhs.ndim == 2 and rhs.shape[1] == 1:
+ # Single RHS as column vector
+ x = self.LU_factor.solve(rhs.flatten())
+ x = x.reshape(rhs.shape)
+ elif rhs.ndim == 2 and rhs.shape[1] > 1:
+ # Multiple RHS (batched) - scipy splu can handle this directly
+ x = self.LU_factor.solve(rhs)
+ else:
+ raise ValueError(f"Unsupported RHS shape: {rhs.shape}")
+
+ synchronize_gpu()
+ toc = time.perf_counter()
+ self.t_solve += toc - tic
+
+ return x
def logdet(
self,
**kwargs,
) -> float:
- """Compute logdet of input matrix using Cholesky factor."""
+ """Compute the log determinant of the matrix.
+
+ Returns
+ -------
+ float
+ The log determinant of the matrix.
+ """
+
+ if self.LU_factor is None:
+ raise ValueError("Matrix factorization not computed")
+
+ # For LU decomposition: det(A) = det(L) * det(U)
+ # Since L has 1s on diagonal: det(L) = 1
+ # So det(A) = det(U) = product of diagonal elements of U
+ log_det_U = xp.sum(xp.log(xp.abs(self.LU_factor.U.diagonal())))
+
+ return float(log_det_U)
+
+ def selected_inversion(self, batch_size: int = 64, **kwargs) -> sp.sparse.spmatrix:
+ """Compute selected entries of the inverse using sparsity pattern of L and U factors.
+
+ This implementation:
+ - Processes the matrix in batches to avoid storing full dense inverse in memory
+ - Leverages the existing self.solve() method with sparse LU factors
+ - Only stores entries matching the sparsity pattern of L and U
+
+ Parameters
+ ----------
+ batch_size : int, optional
+ Number of columns to process in each batch (default: 64).
+ Smaller batches use less memory but may be slower.
- if self.L is None:
- raise ValueError("Cholesky factor not computed")
+ Returns
+ -------
+ sp.sparse.spmatrix
+ Sparse matrix with the inverse entries at positions where L or U have non-zeros
+ """
- return 2 * xp.sum(xp.log(self.L.diagonal()))
+ n = self.LU_factor.L.shape[0]
- def selected_inversion(self, **kwargs):
- # Placeholder for the selected inversion method.
- return super().selected_inversion(**kwargs)
+ # Get the combined sparsity pattern of L and U factors
+ # This determines which entries of A_inv we need to extract
+ L_coo = self.LU_factor.L.tocoo()
+ U_coo = self.LU_factor.U.tocoo()
+
+ # Combine patterns: collect all (row, col) pairs where L or U have non-zeros
+ # . move to CPU if on GPU to handle set operations
+ if xp.__name__ == "cupy":
+ L_rows = L_coo.row.get()
+ L_cols = L_coo.col.get()
+ U_rows = U_coo.row.get()
+ U_cols = U_coo.col.get()
+ else:
+ L_rows = L_coo.row
+ L_cols = L_coo.col
+ U_rows = U_coo.row
+ U_cols = U_coo.col
+
+ # Combine patterns and remove duplicates
+ pattern_set = set()
+ for r, c in zip(L_rows, L_cols):
+ pattern_set.add((int(r), int(c)))
+ for r, c in zip(U_rows, U_cols):
+ pattern_set.add((int(r), int(c)))
+
+ # Create a set for fast lookup: entries to extract
+ pattern_entries = pattern_set
+
+ # Storage for sparse result
+ data = []
+ row_indices = []
+ col_indices = []
+
+ # Process matrix in batches of columns
+ for batch_start in range(0, n, batch_size):
+ batch_end = min(batch_start + batch_size, n)
+ batch_cols = batch_end - batch_start
+
+ # Build batched RHS: identity columns for this batch
+ rhs_batch = xp.zeros((n, batch_cols))
+ for i in range(batch_cols):
+ rhs_batch[batch_start + i, i] = 1.0
+
+ # Solve A @ X = RHS using the existing batched solve method
+ # This leverages the sparse LU factorization efficiently
+ X = self.solve(rhs_batch)
+
+ # Extract only the entries that match the sparsity pattern
+ for row, col in pattern_entries:
+ # Check if this (row, col) pair is in the current batch
+ if batch_start <= col < batch_end:
+ batch_idx = col - batch_start
+ # Move to CPU if on GPU for data extraction
+ if xp.__name__ == "cupy":
+ value = float(X[row, batch_idx].get())
+ else:
+ value = float(X[row, batch_idx])
+ data.append(value)
+ row_indices.append(row)
+ col_indices.append(col)
+
+ # Convert lists to proper 1D arrays for sparse matrix construction
+ data_array = xp.array(data, dtype=xp.float64)
+ row_array = xp.array(row_indices, dtype=xp.int32)
+ col_array = xp.array(col_indices, dtype=xp.int32)
+
+ # Create sparse matrix from the selected entries
+ self.A_inv = sp.sparse.coo_matrix(
+ (data_array, (row_array, col_array)), shape=(n, n)
+ ).tocsr()
+
+ return self.A_inv
+
+ def _structured_to_spmatrix(
+ self, A: sp.sparse.spmatrix, **kwargs
+ ) -> sp.sparse.spmatrix:
+ """Convert the A_inv matrix to a sparse matrix masked using the given sparsity pattern A.
+
+ Extracts entries from self.A_inv at positions specified by the non-zero pattern of A,
+ maintaining sparsity throughout the operation.
+
+ Parameters
+ ----------
+ A : sp.sparse.spmatrix
+ Sparse matrix defining the sparsity pattern to extract.
+
+ Returns
+ -------
+ sp.sparse.spmatrix
+ Sparse matrix with values from A_inv at positions matching A's pattern.
+ """
+
+ B = A.tocoo()
+ # Extract values from A_inv using element-wise indexing
+ B.data = xp.array(
+ [float(self.A_inv[int(r), int(c)]) for r, c in zip(B.row, B.col)]
+ )
+ return B.tocsr()
def get_solver_memory(self) -> int:
"""Return the memory used by the solver in number of bytes"""
- if self.L is None:
+ if self.LU_factor is None:
return 0
- return self.L.data.nbytes + self.L.indptr.nbytes + self.L.indices.nbytes
\ No newline at end of file
+ # Estimate memory usage from L and U matrices
+ L_memory = (
+ self.LU_factor.L.data.nbytes
+ + self.LU_factor.L.indptr.nbytes
+ + self.LU_factor.L.indices.nbytes
+ )
+ U_memory = (
+ self.LU_factor.U.data.nbytes
+ + self.LU_factor.U.indptr.nbytes
+ + self.LU_factor.U.indices.nbytes
+ )
+
+ if self.A_inv is not None:
+ A_inv_memory = self.A_inv.nbytes
+ return L_memory + U_memory + A_inv_memory
+
+ return L_memory + U_memory
diff --git a/src/dalia/solvers/structured_solver.py b/src/dalia/solvers/structured_solver.py
index ed319d66..816f4cd0 100644
--- a/src/dalia/solvers/structured_solver.py
+++ b/src/dalia/solvers/structured_solver.py
@@ -1,7 +1,7 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
-from warnings import warn
import time
+from warnings import warn
from dalia import NDArray, sp, xp, xp_host
from dalia.configs.dalia_config import SolverConfig
@@ -15,7 +15,6 @@
warn(f"The serinv package is required to use the SerinvSolver: {e}")
-
class SerinvSolver(Solver):
"""Serinv Solver class."""
@@ -66,82 +65,118 @@ def __init__(
# Solver Metrics
self.total_bytes: int = 0
- self.t_cholesky = 0.0
+ self.t_factorize = 0.0
self.t_solve = 0.0
- def cholesky(
+ def factorize(
self,
A: sp.sparse.spmatrix,
sparsity: str,
) -> None:
- """Compute Cholesky factor of input matrix."""
- self._spmatrix_to_structured(A, sparsity)
-
- tic = time.perf_counter()
+ """Compute the decomposition of a matrix.
+
+ Parameters
+ ----------
+ A : sp.sparse.spmatrix
+ The input matrix to decompose.
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+
+ Returns
+ -------
+ None
+
+ Note:
+ -----
+ Uses the Cholesky decomposition.
+ """
synchronize_gpu()
+ tic = time.perf_counter()
+
+ self._spmatrix_to_structured(A, sparsity)
if sparsity == "bta":
- pobtaf(
- self.A_diagonal_blocks,
- self.A_lower_diagonal_blocks,
- self.A_arrow_bottom_blocks,
- self.A_arrow_tip_block,
- )
+ pobtaf(
+ self.A_diagonal_blocks,
+ self.A_lower_diagonal_blocks,
+ self.A_arrow_bottom_blocks,
+ self.A_arrow_tip_block,
+ )
elif sparsity == "bt":
- pobtf(
- self.A_diagonal_blocks,
- self.A_lower_diagonal_blocks,
- )
+ pobtf(
+ self.A_diagonal_blocks,
+ self.A_lower_diagonal_blocks,
+ )
else:
raise ValueError(
f"Unknown sparsity pattern: {sparsity}. Use 'bt' or 'bta'."
)
+
synchronize_gpu()
toc = time.perf_counter()
- self.t_cholesky += toc - tic
+ self.t_factorize += toc - tic
def solve(
self,
rhs: NDArray,
sparsity: str,
) -> NDArray:
- """Solve linear system using Cholesky factor."""
-
- tic = time.perf_counter()
+ """Solve linear system using Cholesky factor.
+
+ Parameters
+ ----------
+ rhs : NDArray
+ Right-hand side of the linear system.
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+
+ Returns
+ -------
+ NDArray
+ Solution of the linear system.
+
+ Raises
+ ------
+ ValueError
+ If the sparsity pattern is unknown.
+ """
synchronize_gpu()
+ tic = time.perf_counter()
+
if sparsity == "bta":
- pobtas(
- self.A_diagonal_blocks,
- self.A_lower_diagonal_blocks,
- self.A_arrow_bottom_blocks,
- self.A_arrow_tip_block,
- rhs,
- trans="N",
- )
- pobtas(
- self.A_diagonal_blocks,
- self.A_lower_diagonal_blocks,
- self.A_arrow_bottom_blocks,
- self.A_arrow_tip_block,
- rhs,
- trans="C",
- )
+ pobtas(
+ self.A_diagonal_blocks,
+ self.A_lower_diagonal_blocks,
+ self.A_arrow_bottom_blocks,
+ self.A_arrow_tip_block,
+ rhs,
+ trans="N",
+ )
+ pobtas(
+ self.A_diagonal_blocks,
+ self.A_lower_diagonal_blocks,
+ self.A_arrow_bottom_blocks,
+ self.A_arrow_tip_block,
+ rhs,
+ trans="C",
+ )
elif sparsity == "bt":
- pobts(
- self.A_diagonal_blocks,
- self.A_lower_diagonal_blocks,
- rhs,
- trans="N",
- )
- pobts(
- self.A_diagonal_blocks,
- self.A_lower_diagonal_blocks,
- rhs,
- trans="C",
- )
+ pobts(
+ self.A_diagonal_blocks,
+ self.A_lower_diagonal_blocks,
+ rhs,
+ trans="N",
+ )
+ pobts(
+ self.A_diagonal_blocks,
+ self.A_lower_diagonal_blocks,
+ rhs,
+ trans="C",
+ )
else:
raise ValueError(
f"Unknown sparsity pattern: {sparsity}. Use 'bt' or 'bta'."
)
+
synchronize_gpu()
toc = time.perf_counter()
self.t_solve += toc - tic
@@ -152,7 +187,18 @@ def logdet(
self,
sparsity: str,
) -> float:
- """Compute logdet of input matrix using Cholesky factor."""
+ """Compute the log determinant of the matrix.
+
+ Parameters
+ ----------
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+
+ Returns
+ -------
+ float
+ The log determinant of the matrix.
+ """
logdet: float = 0.0
for i in range(self.n_diag_blocks):
logdet += xp.sum(xp.log(self.A_diagonal_blocks[i].diagonal()))
@@ -197,8 +243,10 @@ def _spmatrix_to_structured(
"""Map sp.spmatrix to BT or BTA."""
self.A_diagonal_blocks[:] = 0.0
self.A_lower_diagonal_blocks[:] = 0.0
- self.A_arrow_bottom_blocks[:] = 0.0
- self.A_arrow_tip_block[:] = 0.0
+ if self.A_arrow_bottom_blocks is not None:
+ self.A_arrow_bottom_blocks[:] = 0.0
+ if self.A_arrow_tip_block is not None:
+ self.A_arrow_tip_block[:] = 0.0
if xp.__name__ == "cupy":
if sparsity == "bta":
@@ -245,9 +293,9 @@ def _spmatrix_to_structured(
block_slice = A_csc[
-self.arrowhead_blocksize :, -self.arrowhead_blocksize :
].tocoo()
- self.A_arrow_tip_block[
- block_slice.row, block_slice.col
- ] = block_slice.data
+ self.A_arrow_tip_block[block_slice.row, block_slice.col] = (
+ block_slice.data
+ )
def _spmatrix_to_bta(
self,
@@ -471,8 +519,36 @@ def _structured_to_spmatrix(
self,
A: sp.sparse.spmatrix,
sparsity: str,
+ symmetrize: bool = True,
) -> sp.sparse.spmatrix:
- """Map BT or BTA matrix to sp.spmatrix using sparsity pattern provided in A."""
+ """Map a BT or BTA structured matrix to a sparse `csc` format given the sparsity pattern.
+
+ Parameters
+ ----------
+ A : sp.sparse.spmatrix
+ Input matrix in sparse format.
+ sparsity : str
+ The sparsity pattern of the matrix. Either 'bt' or 'bta'.
+ symmetrize : bool, optional
+ Whether to symmetrize the output matrix, by default True.
+
+ Returns
+ -------
+ sp.sparse.spmatrix
+ The output matrix in sparse format.
+
+ Notes
+ -----
+ By default this function symmetrize the matrix as it is mostly use in DALIA on the selected
+ inverse of SPD matrices. Because of further matrix operations on these matrices, we symmetrize
+ them to avoid numerical asymmetries.
+
+ Raises
+ ------
+ ValueError
+ If the sparsity pattern is unknown.
+
+ """
# A is assumed to be symmetric, only use lower triangular part
B = sp.sparse.csc_matrix(sp.sparse.tril(sp.sparse.csc_matrix(A)))
@@ -532,10 +608,10 @@ def _structured_to_spmatrix(
B_out = sp.sparse.coo_matrix((data, (rows, cols)), shape=B.shape).tocsc()
- # Symmetrize B
- B_out = B_out + sp.sparse.tril(B_out, k=-1).T
-
- return B_out
+ if symmetrize:
+ return B_out + sp.sparse.tril(B_out, k=-1).T
+ else:
+ return B_out
def get_solver_memory(self) -> int:
"""Return the memory used by the solver in number of bytes"""
@@ -547,4 +623,4 @@ def get_solver_memory(self) -> int:
)
self.total_bytes += bytes_pobtars
- return self.total_bytes
\ No newline at end of file
+ return self.total_bytes
diff --git a/src/dalia/submodels/__init__.py b/src/dalia/submodels/__init__.py
index 2e2b78fb..9338009c 100644
--- a/src/dalia/submodels/__init__.py
+++ b/src/dalia/submodels/__init__.py
@@ -1,8 +1,16 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
+from dalia.submodels.brainiac import BrainiacSubModel
from dalia.submodels.regression import RegressionSubModel
from dalia.submodels.spatial import SpatialSubModel
from dalia.submodels.spatio_temporal import SpatioTemporalSubModel
from dalia.submodels.brainiac import BrainiacSubModel
+from dalia.submodels.ar1 import AR1SubModel
-__all__ = ["RegressionSubModel", "SpatialSubModel", "SpatioTemporalSubModel", "BrainiacSubModel"]
+__all__ = [
+ "RegressionSubModel",
+ "SpatialSubModel",
+ "SpatioTemporalSubModel",
+ "BrainiacSubModel",
+ "AR1SubModel",
+]
diff --git a/src/dalia/submodels/ar1.py b/src/dalia/submodels/ar1.py
new file mode 100644
index 00000000..41a5ce25
--- /dev/null
+++ b/src/dalia/submodels/ar1.py
@@ -0,0 +1,71 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+from tabulate import tabulate
+
+import numpy as np
+
+from dalia import sp, xp
+from dalia.configs.submodels_config import AR1SubModelConfig
+from dalia.core.submodel import SubModel
+from dalia.utils import add_str_header
+
+
+class AR1SubModel(SubModel):
+ """Fit an AR(1) model."""
+
+ def __init__(
+ self,
+ config: AR1SubModelConfig,
+ ) -> None:
+ """Initializes the model."""
+ super().__init__(config)
+
+ # check that dimensions match
+
+
+ def construct_Q_prior(self, **kwargs) -> sp.sparse.coo_matrix:
+ """Construct the prior precision matrix."""
+
+ # kwargs expects hyperparameters in external scale
+ phi = kwargs.get("phi")
+ tau = kwargs.get("tau")
+
+ s2 = 1 / tau
+ denom = s2 * (1 - phi**2)
+
+ diag = [(1 + phi**2) / denom] * self.n_latent_parameters
+ diag[0] = diag[-1] = 1 / denom
+ off_diag = [-phi / denom] * (self.n_latent_parameters - 1)
+
+ Q_prior = sp.sparse.diags([off_diag, diag, off_diag], [-1, 0, 1])
+
+ # need this -> otherwise there might be a sorting issue
+ Q_prior = Q_prior.tocsr()
+ Q_prior.sort_indices()
+
+ return Q_prior.tocoo()
+
+ def __str__(self) -> str:
+ """String representation of the submodel."""
+ str_representation = ""
+
+ # --- Make the Submodel table ---
+ values = [
+ ["Submodel Type", self.submodel_type],
+ ["Number of Latent Parameters", self.n_latent_parameters],
+ ["Phi", f"{self.config.phi:.3f}"],
+ ["tau", f"{self.config.tau:.3f}"],
+ ]
+ submodel_table = tabulate(
+ values,
+ tablefmt="fancy_grid",
+ colalign=("left", "center"),
+ )
+
+ # Add the header title
+ submodel_table = add_str_header(
+ title=self.submodel_type.replace("_", " ").title(),
+ table=submodel_table,
+ )
+ str_representation += submodel_table
+
+ return str_representation
diff --git a/src/dalia/submodels/brainiac.py b/src/dalia/submodels/brainiac.py
index 054006e1..21100c1c 100644
--- a/src/dalia/submodels/brainiac.py
+++ b/src/dalia/submodels/brainiac.py
@@ -1,13 +1,11 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
+import numpy as np
from tabulate import tabulate
-from dalia import sp, NDArray
+from dalia import NDArray, sp, xp
from dalia.configs.submodels_config import BrainiacSubModelConfig
from dalia.core.submodel import SubModel
-from dalia.utils import scaled_logit, add_str_header
-
-import numpy as np
-from dalia import sp, xp
+from dalia.utils import add_str_header, scaled_logit
class BrainiacSubModel(SubModel):
@@ -52,25 +50,13 @@ def _check_dimensions_matrices(self) -> None:
self.z.shape[0] == self.a.shape[1]
), f"Numbers rows in z ({self.z.shape[0]}) must match number of columns in a ({self.a.shape[1]})."
- def rescale_hyperparameters_to_interpret(self, **kwargs) -> NDArray:
-
- h2_scaled = kwargs.get("h2")
- # rescale h2 to (0,1) as it's currently between -INF:+INF
- h2 = scaled_logit(h2_scaled, direction="backward")
-
- theta_interpret = np.array([h2, *kwargs["alpha"]])
- print("theta_interpret: ", theta_interpret)
- return theta_interpret
def construct_Q_prior(self, **kwargs) -> sp.sparse.coo_matrix:
"""Construct the prior precision matrix."""
# Extract all alpha_x values and put them into an array
alpha_keys = sorted([key for key in kwargs if key.startswith("alpha_")])
alpha = xp.array([kwargs[key] for key in alpha_keys])
- h2_scaled = kwargs.get("h2")
-
- # rescale h2 to (0,1) as it's currently between -INF:+INF
- h2 = scaled_logit(h2_scaled, direction="backward")
+ h2 = kwargs.get("h2")
# \Phi = 1 / \sum_k=1^B exp(Z^k \alpha) * diag(exp(Z_1 \alpha), exp(Z_2 \alpha), ... )
exp_Z_alpha = xp.exp(self.z @ alpha)
@@ -86,20 +72,16 @@ def construct_Q_prior(self, **kwargs) -> sp.sparse.coo_matrix:
return Q_prior.tocoo()
- def evaluate_likelihood(
- self, eta: NDArray, y: NDArray, **kwargs
- ) -> float:
+ def evaluate_likelihood(self, eta: NDArray, y: NDArray, **kwargs) -> float:
n_observations = y.shape[0]
- h2_scaled = kwargs.get("h2")
- # rescale h2 to (0,1) as it's currently between -INF:+INF
- h2 = scaled_logit(h2_scaled, direction="backward")
- if(h2 == 1):
+ h2 = kwargs.get("h2")
+ if h2 == 1:
raise ValueError("h2 is 1. Will lead to division by zero.")
yEta = y - eta
likelihood: float = (
- 0.5 * - np.log(1 - h2) * n_observations - 0.5 / (1 - h2) * yEta.T @ yEta
+ 0.5 * -np.log(1 - h2) * n_observations - 0.5 / (1 - h2) * yEta.T @ yEta
)
return likelihood
@@ -107,11 +89,9 @@ def evaluate_likelihood(
def evaluate_gradient_likelihood(
self, eta: NDArray, y: NDArray, **kwargs
) -> NDArray:
- h2_scaled = kwargs.get("h2")
-
- # rescale h2 to (0,1) as it's currently between -INF:+INF
- h2 = scaled_logit(h2_scaled, direction="backward")
- if(h2 == 1):
+ h2 = kwargs.get("h2")
+
+ if h2 == 1:
raise ValueError("h2 is 1. Will lead to division by zero.")
gradient = -1 / (1 - h2) * (eta - y)
@@ -119,11 +99,9 @@ def evaluate_gradient_likelihood(
return gradient
def evaluate_d_matrix(self, **kwargs) -> NDArray:
- h2_scaled = kwargs.get("h2")
- # rescale h2 to (0,1) as it's currently between -INF:+INF
- h2 = scaled_logit(h2_scaled, direction="backward")
+ h2 = kwargs.get("h2")
- if(h2 == 1):
+ if h2 == 1:
raise ValueError("h2 is 1. Will lead to division by zero.")
d_matrix = -1 / (1 - h2) * sp.sparse.eye(self.a.shape[0])
@@ -135,20 +113,20 @@ def __str__(self) -> str:
# --- Make the Submodel table ---
values = [
- ["h2", f"{self.config.h2:.3f}"],
- ["alpha", f"{self.config.alpha:.3f}"],
+ ["h2", f"{self.config.h2:.3f}"],
+ ["alpha", [f"{a:.3f}" for a in self.config.alpha]],
]
submodel_table = tabulate(
values,
tablefmt="fancy_grid",
colalign=("left", "center"),
)
-
+
# Add the header title
submodel_table = add_str_header(
title=self.submodel_type.replace("_", " ").title(),
table=submodel_table,
)
str_representation += submodel_table
-
+
return str_representation
diff --git a/src/dalia/submodels/regression.py b/src/dalia/submodels/regression.py
index 3de5e70f..986d05e2 100644
--- a/src/dalia/submodels/regression.py
+++ b/src/dalia/submodels/regression.py
@@ -6,6 +6,7 @@
from dalia.core.submodel import SubModel
from dalia.utils import add_str_header
+
class RegressionSubModel(SubModel):
"""Fit a regression model."""
@@ -40,15 +41,15 @@ def __str__(self) -> str:
# --- Make the Submodel table ---
values = [
- ["Number of Fixed Effects", self.n_fixed_effects],
- ["Prior Precision of Fixed Effects", self.fixed_effects_prior_precision],
+ ["Number of Fixed Effects", self.n_fixed_effects],
+ ["Prior Precision of Fixed Effects", self.fixed_effects_prior_precision],
]
submodel_table = tabulate(
values,
tablefmt="fancy_grid",
colalign=("left", "center"),
)
-
+
# Add the header title
submodel_table = add_str_header(
title=self.submodel_type.replace("_", " ").title(),
@@ -56,4 +57,4 @@ def __str__(self) -> str:
)
str_representation += submodel_table
- return str_representation
\ No newline at end of file
+ return str_representation
diff --git a/src/dalia/submodels/spatial.py b/src/dalia/submodels/spatial.py
index 2e441e87..7c98b720 100644
--- a/src/dalia/submodels/spatial.py
+++ b/src/dalia/submodels/spatial.py
@@ -1,16 +1,17 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
import math
-from tabulate import tabulate
import numpy as np
from scipy.sparse import csc_matrix, load_npz, spmatrix
+from tabulate import tabulate
from dalia import sp, xp
from dalia.configs.submodels_config import SpatialSubModelConfig
from dalia.core.submodel import SubModel
from dalia.utils import add_str_header
+
class SpatialSubModel(SubModel):
"""Fit a spatial model."""
@@ -123,7 +124,7 @@ def __str__(self) -> str:
# --- Make the Submodel table ---
values = [
- ["Number of Spatial Nodes", self.ns],
+ ["Number of Spatial Nodes", self.ns],
["Spatial Range (r_s)", f"{self.config.r_s:.3f}"],
["Spatial Variation (sigma_e)", f"{self.sigma_e:.3f}"],
]
@@ -132,12 +133,12 @@ def __str__(self) -> str:
tablefmt="fancy_grid",
colalign=("left", "center"),
)
-
+
# Add the header title
submodel_table = add_str_header(
title=self.submodel_type.replace("_", " ").title(),
table=submodel_table,
)
str_representation += submodel_table
-
- return str_representation
\ No newline at end of file
+
+ return str_representation
diff --git a/src/dalia/submodels/spatio_temporal.py b/src/dalia/submodels/spatio_temporal.py
index baead545..728c4176 100644
--- a/src/dalia/submodels/spatio_temporal.py
+++ b/src/dalia/submodels/spatio_temporal.py
@@ -1,16 +1,17 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
import math
-from tabulate import tabulate
import numpy as np
from scipy.sparse import csc_matrix, load_npz, spmatrix
+from tabulate import tabulate
from dalia import sp, xp
from dalia.configs.submodels_config import SpatioTemporalSubModelConfig
from dalia.core.submodel import SubModel
from dalia.utils import add_str_header
+
class SpatioTemporalSubModel(SubModel):
"""Fit a spatio-temporal model."""
@@ -268,9 +269,9 @@ def __str__(self) -> str:
# --- Make the Submodel table ---
values = [
- ["Number of Spatial Nodes", self.ns],
- ["Number of Temporal Nodes", self.nt],
- ["Manifold", self.manifold.capitalize()],
+ ["Number of Spatial Nodes", self.ns],
+ ["Number of Temporal Nodes", self.nt],
+ ["Manifold", self.manifold.capitalize()],
["Spatial Range (r_s)", f"{self.config.r_s:.3f}"],
["Temporal Range (r_t)", f"{self.config.r_t:.3f}"],
["Spatio-temporal Variation (sigma_st)", f"{self.sigma_st:.3f}"],
@@ -280,7 +281,7 @@ def __str__(self) -> str:
tablefmt="fancy_grid",
colalign=("left", "center"),
)
-
+
# Add the header title
submodel_table = add_str_header(
title=self.submodel_type.replace("_", " ").title(),
@@ -289,4 +290,3 @@ def __str__(self) -> str:
str_representation += submodel_table
return str_representation
-
diff --git a/src/dalia/utils/__init__.py b/src/dalia/utils/__init__.py
index b1a25fe8..fd63b5b7 100644
--- a/src/dalia/utils/__init__.py
+++ b/src/dalia/utils/__init__.py
@@ -1,30 +1,42 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
from dalia.utils.gpu_utils import (
+ format_size,
+ free_unused_gpu_memory,
get_array_module_name,
get_available_devices,
get_device,
get_host,
- set_device,
- free_unused_gpu_memory,
memory_report,
- format_size,
+ set_device,
)
from dalia.utils.host import get_host_configuration
from dalia.utils.link_functions import cloglog, scaled_logit, sigmoid
+from dalia.utils.correlation import compute_outer_covariance_matrix
+from dalia.utils.gaussian_quadrature import compute_variance_gauss_hermite
+from dalia.utils.bivariate_gaussian_quadrature import compute_bivariate_expectation
from dalia.utils.multiprocessing import (
- allreduce,
+ DummyCommunicator,
allgather,
+ allreduce,
bcast,
get_active_comm,
print_msg,
smartsplit,
synchronize,
synchronize_gpu,
- DummyCommunicator,
+ check_vector_consistency,
+)
+from dalia.utils.print_utils import (
+ add_str_header,
+ align_tables_side_by_side,
+ ascii_logo,
+ boxify,
)
from dalia.utils.spmatrix_utils import bdiag_tiling, extract_diagonal, memory_footprint
from dalia.utils.print_utils import add_str_header, align_tables_side_by_side, boxify, ascii_logo
+from dalia.utils.plotting import plot_marginal_distributions_hp, plot_prior_hp
+from .scalar_ndarray import ensure_scalar
__all__ = [
"get_available_devices",
@@ -36,6 +48,9 @@
"sigmoid",
"cloglog",
"scaled_logit",
+ "compute_outer_covariance_matrix",
+ "compute_variance_gauss_hermite",
+ "compute_bivariate_expectation",
"print_msg",
"synchronize",
"synchronize_gpu",
@@ -44,6 +59,7 @@
"allreduce",
"allgather",
"bcast",
+ "check_vector_consistency",
"bdiag_tiling",
"extract_diagonal",
"memory_footprint",
@@ -55,4 +71,6 @@
"memory_report",
"format_size",
"DummyCommunicator",
+ "plot_marginal_distributions_hp",
+ "plot_prior_hp",
]
diff --git a/src/dalia/utils/bivariate_gaussian_quadrature.py b/src/dalia/utils/bivariate_gaussian_quadrature.py
new file mode 100644
index 00000000..d9a113fb
--- /dev/null
+++ b/src/dalia/utils/bivariate_gaussian_quadrature.py
@@ -0,0 +1,254 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import numpy as np
+from dalia import xp
+from scipy.special import roots_hermite
+
+
+def compute_bivariate_expectation(func1, func2, mu1, mu2, Sigma, n_points=20):
+ """
+ Compute E[f(Z₁, Z₂)] where (Z₁, Z₂) ~ N(0, Σ) using bivariate Gauss-Hermite quadrature.
+
+ E[f(Z₁, Z₂)] = ∬ f(z₁, z₂) φ_ρ(z₁, z₂) dz₁ dz₂
+
+ where φ_ρ(z₁, z₂) is the bivariate standard normal density with correlation ρ.
+
+ Parameters
+ ----------
+ func : callable
+ Function f(z₁, z₂) to compute expectation of
+ rho : float, optional
+ Correlation coefficient between Z₁ and Z₂ (default: 0.0)
+ n_points : int, optional
+ Number of quadrature points per dimension (default: 20)
+
+ Returns
+ -------
+ float
+ Expected value E[f(Z₁, Z₂)]
+ """
+ # Get Gauss-Hermite quadrature points and weights
+ nodes, weights = roots_hermite(n_points)
+
+ nodes = xp.array(nodes)
+ weights = xp.array(weights)
+
+ # Transform nodes from Hermite polynomial roots to standard normal
+ z_nodes = xp.sqrt(2) * nodes
+ adjusted_weights = weights / xp.sqrt(xp.pi)
+
+ rho = Sigma[0, 1] # Correlation coefficient
+ sigma1 = xp.sqrt(Sigma[0, 0])
+ sigma2 = xp.sqrt(Sigma[1, 1])
+
+ # For bivariate case with correlation ρ and means μ₁, μ₂:
+ # If (U₁, U₂) are independent N(0,1), then:
+ # Z₁ = μ₁ + U₁
+ # Z₂ = μ₂ + ρU₁ + √(1-ρ²)U₂
+ # gives (Z₁, Z₂) ~ N([μ₁, μ₂], [[1, ρ], [ρ, 1]])
+
+ if xp.abs(rho) > 1:
+ raise ValueError("Correlation coefficient rho must be in [-1, 1]")
+
+ sqrt_one_minus_rho_sq = xp.sqrt(1 - rho**2)
+
+ expectation = 0.0
+
+ # Double loop over all combinations of quadrature points
+ for i, (u1, w1) in enumerate(zip(z_nodes, adjusted_weights)):
+ for j, (u2, w2) in enumerate(zip(z_nodes, adjusted_weights)):
+ # Transform to correlated variables
+ z1 = mu1 + sigma1 * u1
+ z2 = mu2 + rho * sigma1 * u1 + sigma2 * sqrt_one_minus_rho_sq * u2
+
+ # Evaluate function at transformed points
+ f_val = func1(z1) * func2(z2)
+
+ # Combined weight (product of univariate weights)
+ weight = w1 * w2
+
+ expectation += weight * f_val
+
+ return expectation
+
+
+def test_quadrature_accuracy():
+ """Test quadrature accuracy against known analytical results."""
+
+ print("=" * 80)
+ print("Testing Bivariate Gaussian Quadrature")
+ print("=" * 80)
+
+ # Test 2: Bivariate expectations with independence (ρ = 0)
+ print("2. Testing bivariate expectations with independence (ρ = 0):")
+
+ # E[1] = 1
+ biv_expectation_1 = compute_bivariate_expectation(lambda z: 1.0, lambda z: 1.0, rho=0.0, n_points=10)
+ print(f" E[1] = {biv_expectation_1:.8f} (should be 1.0, error: {abs(1.0 - biv_expectation_1):.2e})")
+
+ # E[Z₁Z₂] = 0 when independent
+ biv_expectation_z1z2 = compute_bivariate_expectation(lambda z: z, lambda z: z, rho=0.0, n_points=20)
+ print(f" E[Z₁Z₂] = {biv_expectation_z1z2:.8f} (should be 0.0, error: {abs(biv_expectation_z1z2):.2e})")
+
+ # E[Z₁² + Z₂²] = E[Z₁²] + E[Z₂²] = 1 + 1 = 2 (using independence)
+ biv_expectation_sq = compute_bivariate_expectation(lambda z: z**2, lambda z: z**2, rho=0.0, n_points=20)
+ print(f" E[Z₁² + Z₂²] = {biv_expectation_sq:.8f} (should be 2.0, error: {abs(2.0 - biv_expectation_sq):.2e})")
+ print()
+
+ # Test 3: Product expectations with independence
+ print("3. Testing product expectations E[f₁(Z₁)f₂(Z₂)] with independence:")
+
+ # E[Z₁ * Z₂] = E[Z₁] * E[Z₂] = 0 * 0 = 0 when independent
+ prod_exp_z1z2 = compute_bivariate_expectation(lambda z: z, lambda z: z, rho=0.0, n_points=20)
+ print(f" E[Z₁ * Z₂] = {prod_exp_z1z2:.8f} (should be 0.0, error: {abs(prod_exp_z1z2):.2e})")
+
+ # E[Z₁² * 1] = E[Z₁²] * E[1] = 1 * 1 = 1
+ prod_exp_z1sq_1 = compute_bivariate_expectation(lambda z: z**2, lambda z: 1.0, rho=0.0, n_points=20)
+ print(f" E[Z₁² * 1] = {prod_exp_z1sq_1:.8f} (should be 1.0, error: {abs(1.0 - prod_exp_z1sq_1):.2e})")
+
+ # E[exp(Z₁) * exp(Z₂)] = E[exp(Z₁)] * E[exp(Z₂)] = exp(0.5) * exp(0.5) = exp(1.0)
+ prod_exp_exp = compute_bivariate_expectation(lambda z: xp.exp(z), lambda z: xp.exp(z), rho=0.0, n_points=30)
+ analytical_prod_exp = xp.exp(1.0)
+ print(f" E[exp(Z₁) * exp(Z₂)] = {prod_exp_exp:.8f} (should be {analytical_prod_exp:.8f}, error: {abs(analytical_prod_exp - prod_exp_exp):.2e})")
+ print()
+
+ # Test 4: Bivariate expectations with correlation
+ print("4. Testing bivariate expectations with correlation:")
+
+ correlations = [-0.9, -0.5, -0.3, 0.0, 0.3, 0.5, 0.8, 0.9]
+
+ print(" Correlation | E[Z₁Z₂] | Analytical | Error")
+ print(" ------------|------------|------------|----------")
+
+ for rho in correlations:
+ # E[Z₁Z₂] = ρ for bivariate normal
+ biv_exp_corr = compute_bivariate_expectation(lambda z: z, lambda z: z, rho=rho, n_points=25)
+ error = abs(rho - biv_exp_corr)
+ print(f" {rho:10.1f} | {biv_exp_corr:10.6f} | {rho:10.6f} | {error:.2e}")
+ print()
+
+ # Test 4b: Extreme correlation values (±1)
+ print("4b. Testing extreme correlation values (ρ = ±1):")
+ print(" Note: Perfect correlation means Z₂ = ±Z₁")
+
+ extreme_correlations = [-1.0, 1.0]
+
+ print(" Correlation | E[Z₁Z₂] | Analytical | Error | Note")
+ print(" ------------|------------|------------|----------|------------------")
+
+ for rho in extreme_correlations:
+ # For extreme correlations, use more quadrature points for accuracy
+ biv_exp_corr = compute_bivariate_expectation(lambda z: z, lambda z: z, rho=rho, n_points=40)
+ error = abs(rho - biv_exp_corr)
+ note = "Perfect positive" if rho == 1.0 else "Perfect negative"
+ print(f" {rho:10.1f} | {biv_exp_corr:10.6f} | {rho:10.6f} | {error:.2e} | {note}")
+
+ # Additional test: For ρ = ±1, E[Z₁²Z₂²] should equal E[Z₁⁴] = 3
+ z1_sq_z2_sq = compute_bivariate_expectation(lambda z: z**2, lambda z: z**2, rho=rho, n_points=40)
+ expected_z1_4 = 3.0 # Fourth moment of standard normal
+ error_z4 = abs(z1_sq_z2_sq - expected_z1_4)
+ print(f" {rho:10.1f} | E[Z₁²Z₂²]={z1_sq_z2_sq:6.4f} | E[Z₁⁴]={expected_z1_4:6.1f} | {error_z4:.2e} | Should equal E[Z₁⁴]")
+
+ print()
+
+ # Test 5: Product expectations with correlation
+ print("5. Testing product expectations with correlation:")
+ print(" For E[f₁(Z₁)f₂(Z₂)], correlation affects the result when f₁ and f₂ are nonlinear")
+
+ def f1(z):
+ return z**2
+
+ def f2(z):
+ return z**3
+
+ def f_prod(z):
+ return f1(z[0]) * f2(z[1])
+
+ print(" E[Z₁² * Z₂³] for different correlations:")
+ # print(" Correlation | E[Z₁²Z₂³]")
+ # print(" ------------|----------")
+ print(" Correlation | E[Z₁²Z₂³] Numerical | E[Z₁²Z₂³] GHQ | Error")
+
+ for rho in [-0.9, -0.5, 0.0, 0.5, 0.9]:
+ prod_exp_nonlinear = compute_bivariate_expectation(f1, f2, rho=rho, n_points=30)
+
+ Sigma = xp.array([[1.0, rho], [rho, 1.0]])
+ mu = xp.array([0.0, 0.0])
+ import ghq
+
+ prod_ref = ghq.multivariate(f_prod, mu, Sigma, n_points=30)
+ error = abs(prod_exp_nonlinear - prod_ref)
+
+ print(f" {rho:10.1f} | {prod_exp_nonlinear:10.6f} | {prod_ref:10.6f} | {error:.2e}")
+ print()
+
+ # Test 6: Convergence study
+ print("6. Convergence study (increasing number of quadrature points):")
+
+ # For this test, we approximate E[exp(0.1*Z₁ + 0.2*Z₂)] using E[exp(0.1*Z₁) * exp(0.2*Z₂)]
+ # This is exact since exp(a+b) = exp(a)*exp(b)
+ def func1_test(z):
+ return xp.exp(0.1 * z)
+
+ def func2_test(z):
+ return xp.exp(0.2 * z)
+
+ # Analytical result for E[exp(aZ₁ + bZ₂)] with (Z₁,Z₂) ~ N(0, Σ)
+ # For a=0.1, b=0.2, ρ=0.5: E[exp(aZ₁ + bZ₂)] = exp(0.5 * (a² + b² + 2abρ))
+ a, b, rho_test = 0.1, 0.2, 0.5
+ analytical_mgf = xp.exp(0.5 * (a**2 + b**2 + 2*a*b*rho_test))
+
+ print(f" Testing E[exp(0.1*Z₁) * exp(0.2*Z₂)] with ρ = {rho_test}")
+ print(f" Analytical result: {analytical_mgf:.8f}")
+ print()
+ print(" n_points | Numerical | Error")
+ print(" ---------|--------------|----------")
+
+ for n in [5, 10, 15, 20, 25, 30]:
+ numerical_mgf = compute_bivariate_expectation(func1_test, func2_test, rho=rho_test, n_points=n)
+ error = abs(analytical_mgf - numerical_mgf)
+ print(f" {n:7d} | {numerical_mgf:12.8f} | {error:.2e}")
+
+ print()
+
+ # Test 7: Practical example with transformations
+ print("7. Practical example: Log-normal random variables")
+ print(" Let Y₁ = exp(Z₁), Y₂ = exp(Z₂) where (Z₁, Z₂) have correlation ρ")
+ print(" Computing E[Y₁ * Y₂] = E[exp(Z₁ + Z₂)]")
+
+ rho_examples = [-0.8, -0.3, 0.0, 0.5, 0.9]
+
+ print(" Correlation | E[Y₁Y₂] Numerical | E[Y₁Y₂] Analytical | Error")
+ print(" ------------|------------------|-------------------|----------")
+
+ for rho in rho_examples:
+ # Numerical computation
+ numerical_lognormal = compute_bivariate_expectation(
+ lambda z: xp.exp(z), lambda z: xp.exp(z),
+ rho=rho, n_points=30
+ )
+
+ # Analytical: E[exp(Z₁ + Z₂)] = exp(E[Z₁ + Z₂] + 0.5*Var[Z₁ + Z₂])
+ # E[Z₁ + Z₂] = 0, Var[Z₁ + Z₂] = Var[Z₁] + Var[Z₂] + 2*Cov[Z₁,Z₂] = 1 + 1 + 2*ρ = 2 + 2*ρ
+ analytical_lognormal = xp.exp(0.5 * (2 + 2*rho))
+
+ error = abs(numerical_lognormal - analytical_lognormal)
+
+ print(f" {rho:10.1f} | {numerical_lognormal:16.8f} | {analytical_lognormal:17.8f} | {error:.2e}")
+
+ print()
+ print("=" * 80)
+ print("Bivariate Gaussian Quadrature Test Summary:")
+ print("✓ Univariate expectations computed accurately")
+ print("✓ Bivariate expectations with independence verified")
+ print("✓ Product expectations working correctly")
+ print("✓ Correlation effects properly captured")
+ print("✓ Convergence behavior as expected")
+ print("✓ Practical log-normal example validated")
+ print()
+ print("All bivariate Gaussian quadrature functions are working correctly!")
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ test_quadrature_accuracy()
\ No newline at end of file
diff --git a/src/dalia/utils/correlation.py b/src/dalia/utils/correlation.py
new file mode 100644
index 00000000..80dc4c0c
--- /dev/null
+++ b/src/dalia/utils/correlation.py
@@ -0,0 +1,594 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import numpy as np
+from dalia import xp
+import matplotlib.pyplot as plt
+
+# Import our quadrature functions
+from dalia.utils.gaussian_quadrature import compute_variance_gauss_hermite
+from dalia.utils.bivariate_gaussian_quadrature import compute_bivariate_expectation
+# Import reparametrization functions
+from dalia.utils.reparametrizations import (
+ compute_transformed_quantiles,
+ compute_transformed_pdf,
+ compute_bounds
+)
+
+def compute_outer_covariance_matrix(mean_internal, cov_internal, transform_list, n_points=25):
+ """
+ Compute covariance matrix between all pairs of transformed parameters using bivariate quadrature.
+ The results are to be considered with caution. Covariance matrices provide reliable information in
+ the Gaussian context, however, not necessarily in the transformed space.
+
+ Parameters
+ ----------
+ mean_internal : ndarray
+ Mean vector of internal distribution
+ cov_internal : ndarray
+ Covariance matrix of internal distribution
+ transform_func : function
+ Transformation function that takes (theta_vector, direction) and returns transformed vector
+ This should be the model's rescale_hyperparameters_to_internal method
+ n_points : int
+ Number of quadrature points per dimension
+
+ Returns
+ -------
+ ndarray
+ Covariance matrix between transformed parameters (outer space)
+ """
+
+ n_dim = len(mean_internal)
+ outer_cov_matrix = np.zeros((n_dim, n_dim))
+
+ mean_outer = []
+ marginal_vars = []
+
+ print("Computing marginal means and variances for correlation calculations...")
+
+ for i in range(n_dim):
+ mu_i = mean_internal[i]
+ var_i = cov_internal[i, i]
+
+ # Compute marginal statistics
+ result = compute_variance_gauss_hermite(mu_i, var_i, transform_list[i], n_points)
+ mean_outer.append(result['mean'])
+ marginal_vars.append(result['variance'])
+
+ outer_cov_matrix[i, i] = result['variance']
+
+ print("Computing pairwise covariances...")
+
+ # Compute pairwise covariances
+ for i in range(n_dim):
+ for j in range(n_dim):
+ if i < j: # Only compute upper triangle, then symmetrize, i == j already filled
+ print(f" Computing Cov(X_{i+1}, X_{j+1})...", end=" ")
+
+ # Extract marginal parameters
+ mu_i, mu_j = mean_internal[i], mean_internal[j]
+ var_i, var_j = cov_internal[i, i], cov_internal[j, j]
+ cov_ij = cov_internal[i, j]
+
+ # Compute correlation coefficient in internal space
+ rho_internal = cov_ij / np.sqrt(var_i * var_j)
+
+ # Create transformation functions for these parameters
+ transform_func_i = transform_list[i]
+ transform_func_j = transform_list[j]
+
+ # Standardize the variables for bivariate quadrature
+ def standardized_func_i(z):
+ x_internal = mu_i + np.sqrt(var_i) * z
+ return transform_func_i(x_internal, "backward")
+
+ def standardized_func_j(z):
+ x_internal = mu_j + np.sqrt(var_j) * z
+ return transform_func_j(x_internal, "backward")
+
+ # Compute E[f_i(Z_i) * f_j(Z_j)] using bivariate quadrature
+ cross_moment = compute_bivariate_expectation(
+ standardized_func_i, standardized_func_j,
+ rho=rho_internal, n_points=n_points
+ )
+
+ # Covariance: Cov(X,Y) = E[XY] - E[X]E[Y]
+ covariance_outer = cross_moment - mean_outer[i] * mean_outer[j]
+
+ # Store covariance directly
+ outer_cov_matrix[i, j] = covariance_outer
+ outer_cov_matrix[j, i] = covariance_outer # Symmetric
+
+ print(f"{covariance_outer:.6f}")
+
+ return outer_cov_matrix
+
+
+if __name__ == "__main__":
+
+ """
+ Main test function for multivariate transformations.
+ """
+
+ # Dummy Prior Hyperparameter Class with Log Transform
+ class Prior_LogTransform:
+ """
+ Dummy implementation of gamma prior rescaling for testing purposes.
+ Implements log transformation: forward = log(x), backward = exp(x)
+ """
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """Log transformation between positive (external) and unconstrained (internal) space"""
+ if direction == "forward":
+ return np.log(theta) # theta -> log(theta)
+ elif direction == "backward":
+ return np.exp(theta) # log(theta) -> theta
+ elif direction == "forward_jacobian":
+ return 1.0 / theta # d(log(theta))/d(theta) = 1/theta
+ elif direction == "backward_jacobian":
+ return theta # d(exp(theta))/d(theta) = exp(theta) = theta
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ class Prior_LogisticTransform:
+ """
+ Dummy implementation of beta prior rescaling for testing purposes.
+ Implements logistic transformation: forward = logit(x), backward = sigmoid(x)
+ """
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """Logistic transformation between (0,1) and unconstrained space"""
+ if direction == "forward":
+ return np.log(theta / (1 - theta)) # logit(theta)
+ elif direction == "backward":
+ return 1 / (1 + np.exp(-theta)) # sigmoid(theta)
+ elif direction == "forward_jacobian":
+ return 1.0 / (theta * (1 - theta)) # d(logit(theta))/d(theta)
+ elif direction == "backward_jacobian":
+ sig = 1 / (1 + np.exp(-theta))
+ return sig * (1 - sig) # d(sigmoid(theta))/d(theta)
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ class Prior_IdentityTransform:
+ """
+ Dummy implementation of identity prior rescaling for testing purposes.
+ Implements identity transformation: forward = x, backward = x
+ """
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """Identity transformation (no change)"""
+ if direction in ["forward", "backward"]:
+ return theta
+ elif direction in ["forward_jacobian", "backward_jacobian"]:
+ return 1.0
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ def generate_random_covariance_matrix(n_dim=3, condition_number=10.0, random_seed=42):
+ """
+ Generate a random positive definite covariance matrix.
+
+ Parameters
+ ----------
+ n_dim : int
+ Dimension of the covariance matrix
+ condition_number : float
+ Maximum condition number (controls how ill-conditioned the matrix can be)
+ random_seed : int
+ Random seed for reproducibility
+
+ Returns
+ -------
+ ndarray
+ Random positive definite covariance matrix
+ """
+ np.random.seed(random_seed)
+
+ # Generate random eigenvalues between 1/condition_number and 1
+ eigenvals = np.random.uniform(1.0/condition_number, 1.0, n_dim)
+ eigenvals = np.sort(eigenvals)[::-1] # Sort in descending order
+
+ # Generate random orthogonal matrix (eigenvectors)
+ Q, _ = np.linalg.qr(np.random.randn(n_dim, n_dim))
+
+ # Construct covariance matrix: Σ = Q * diag(eigenvals) * Q^T
+ cov_matrix = Q @ np.diag(eigenvals) @ Q.T
+
+ return cov_matrix
+
+ print("=" * 90)
+ print("MULTIVARIATE TRANSFORMATION TEST")
+ print("Testing 3D Gaussian → Transformed Space using Gaussian Quadrature")
+ print("=" * 90)
+
+ # Set parameters
+ n_dim = 3
+ n_quad_points = 30
+
+ # Step 1: Generate random mean vector and covariance matrix
+ print("1. Generating Random 3D Gaussian Distribution")
+ print("-" * 50)
+
+ np.random.seed(42) # For reproducibility
+ mean_internal = np.random.uniform(-1, 1, n_dim)
+ cov_internal = generate_random_covariance_matrix(n_dim, condition_number=5.0)
+
+ print("Internal (Gaussian) Distribution Parameters:")
+ print(f"Mean vector: {mean_internal}")
+ print("Covariance matrix:")
+ print(cov_internal)
+ print(f"Condition number: {np.linalg.cond(cov_internal):.2f}")
+ print()
+
+ # Step 2: Create transformation functions
+ print("2. Setting up Transformation Functions")
+ print("-" * 50)
+
+ prior_log_transform = Prior_LogTransform()
+
+ def log_transform_func(x, direction):
+ return prior_log_transform.rescale_hyperparameters_to_internal(x, direction)
+
+ prior_logistic_transform = Prior_LogisticTransform()
+ def logistic_transform_func(x, direction):
+ return prior_logistic_transform.rescale_hyperparameters_to_internal(x, direction)
+
+ prior_identity_transform = Prior_IdentityTransform()
+ def identity_transform_func(x, direction):
+ return prior_identity_transform.rescale_hyperparameters_to_internal(x, direction)
+
+ # Apply different transforms to each dimension (directly assign transforms to parameters)
+ transforms = [
+ log_transform_func, # Parameter 1: Log transform
+ logistic_transform_func, # Parameter 2: Logistic transform
+ identity_transform_func # Parameter 3: Identity transform
+ ]
+
+ # Step 3: Compute marginal statistics in outer space
+ print("3. Computing Marginal Statistics in Outer Space")
+ print("-" * 50)
+
+ marginal_stats = []
+
+ for i in range(n_dim):
+ # Extract marginal parameters
+ mean_i = mean_internal[i]
+ var_i = cov_internal[i, i]
+
+ # Compute marginal statistics
+ stats = compute_variance_gauss_hermite(
+ mean_i, var_i, transforms[i], n_quad_points
+ )
+
+ marginal_stats.append(stats)
+
+ print(f"Parameter {i+1} ({stats['transform_name']}):")
+ print(f" Internal: μ = {mean_i:.4f}, σ² = {var_i:.4f}")
+ print(f" Outer: μ = {stats['mean']:.4f}, σ² = {stats['variance']:.4f}, σ = {stats['std']:.4f}")
+ print()
+
+ # Step 4: Compute covariance matrix in outer space
+ print("4. Computing Covariance Matrix in Outer Space")
+ print("-" * 50)
+
+ # Internal covariance matrix (for reference)
+ internal_cov_matrix = cov_internal.copy()
+
+ print("Internal (Gaussian) covariance matrix:")
+ print(internal_cov_matrix)
+ print()
+
+ # Create a mock transformation function that applies different transforms to each parameter
+ def mock_vectorized_transform(theta_vec, direction):
+ """Mock transformation function that applies different transforms to each parameter."""
+ result = theta_vec.copy()
+ for i, transform_func in enumerate(transforms):
+ if i < len(result):
+ result[i] = transform_func(theta_vec[i], direction)
+ return result
+
+ # Compute outer covariance matrix using the new function
+ outer_cov_matrix = compute_outer_covariance_matrix(
+ mean_internal, cov_internal, mock_vectorized_transform, n_quad_points
+ )
+
+ print("\nOuter (Transformed) covariance matrix:")
+ print(outer_cov_matrix)
+ print()
+
+ # Convert to correlation matrices for comparison
+ # Internal correlations (for reference)
+ internal_corr_matrix = np.zeros((n_dim, n_dim))
+ for i in range(n_dim):
+ for j in range(n_dim):
+ if i == j:
+ internal_corr_matrix[i, j] = 1.0
+ else:
+ internal_corr_matrix[i, j] = (cov_internal[i, j] /
+ np.sqrt(cov_internal[i, i] * cov_internal[j, j]))
+
+ # Outer correlations (derived from covariance matrix)
+ outer_corr_matrix = np.zeros((n_dim, n_dim))
+ for i in range(n_dim):
+ for j in range(n_dim):
+ if i == j:
+ outer_corr_matrix[i, j] = 1.0
+ else:
+ outer_corr_matrix[i, j] = (outer_cov_matrix[i, j] /
+ np.sqrt(outer_cov_matrix[i, i] * outer_cov_matrix[j, j]))
+
+ print("Derived outer correlation matrix:")
+ print(outer_corr_matrix)
+ print()
+
+ # Step 5: Compare transformations
+ print("5. Transformation Effects Analysis")
+ print("-" * 50)
+
+ print("Comparison of Internal vs Outer Statistics:")
+ print(f"{'Parameter':<12} {'Transform':<25} {'Mean Change':<12} {'Var Change':<12} {'Cov Change':<12}")
+ print("-" * 85)
+
+ for i in range(n_dim):
+ mean_change = abs(marginal_stats[i]['mean'] - mean_internal[i])
+ var_change = abs(marginal_stats[i]['variance'] - cov_internal[i, i])
+
+ # Average covariance change for this parameter
+ cov_changes = []
+ for j in range(n_dim):
+ if i != j:
+ cov_changes.append(abs(outer_cov_matrix[i, j] - internal_cov_matrix[i, j]))
+ avg_cov_change = np.mean(cov_changes) if cov_changes else 0.0
+
+ transform_name = transforms[i].name.split(' ')[0]
+
+ print(f"{i+1:<12} {transform_name:<25} {mean_change:<12.4f} {var_change:<12.4f} {avg_cov_change:<12.4f}")
+
+ print()
+
+ # Additional covariance matrix validation
+ print("Covariance Matrix Validation:")
+ print("-" * 50)
+
+ # Check if outer covariance matrix is positive semidefinite
+ eigenvals = np.linalg.eigvals(outer_cov_matrix)
+ is_pos_def = np.all(eigenvals >= -1e-10) # Allow small numerical errors
+
+ print(f"Outer covariance matrix eigenvalues: {eigenvals}")
+ print(f"Is positive semidefinite: {is_pos_def}")
+
+ # Check symmetry
+ is_symmetric = np.allclose(outer_cov_matrix, outer_cov_matrix.T)
+ print(f"Is symmetric: {is_symmetric}")
+
+ # Check diagonal elements (should be positive variances)
+ diag_elements = np.diag(outer_cov_matrix)
+ all_positive_vars = np.all(diag_elements > 0)
+ print(f"All diagonal elements (variances) positive: {all_positive_vars}")
+ print(f"Outer variances: {diag_elements}")
+
+ print()
+
+ # Step 6: Analytical validation for specific cases
+ print("6. Analytical Validation")
+ print("-" * 50)
+
+ # For log transformation (Parameter 1), we can validate against log-normal theory
+ if transforms[0].name.startswith("Log"): # Log transform
+ mu_1 = mean_internal[0]
+ sigma2_1 = cov_internal[0, 0]
+
+ # Analytical log-normal moments
+ analytical_mean = np.exp(mu_1 + sigma2_1/2)
+ analytical_var = (np.exp(sigma2_1) - 1) * np.exp(2*mu_1 + sigma2_1)
+
+ numerical_mean = marginal_stats[0]['mean']
+ numerical_var = marginal_stats[0]['variance']
+
+ print("Log-normal validation (Parameter 1):")
+ print(f" Analytical mean: {analytical_mean:.6f}")
+ print(f" Numerical mean: {numerical_mean:.6f}")
+ print(f" Relative error: {abs(analytical_mean - numerical_mean)/analytical_mean:.2e}")
+ print(f" Analytical var: {analytical_var:.6f}")
+ print(f" Numerical var: {numerical_var:.6f}")
+ print(f" Relative error: {abs(analytical_var - numerical_var)/analytical_var:.2e}")
+ print()
+
+ # Demonstrate reparametrization functions
+ print("Reparametrization functions analysis:")
+
+ # Create transform function for reparametrization utilities
+ transform_func = transforms[0] # Log transform
+ def reparam_func(x, direction):
+ return transform_func(x, direction)
+
+ # Compute quantiles using reparametrization function
+ percentiles = np.array([0.025, 0.25, 0.5, 0.75, 0.975])
+ quantiles = compute_transformed_quantiles(mu_1, sigma2_1, percentiles, reparam_func)
+
+ print(" Quantiles in outer space:")
+ for p, q in zip(percentiles, quantiles):
+ print(f" {p*100:4.1f}%: {q:.4f}")
+
+ # Compute bounds using reparametrization function
+ (int_lower, int_upper), (orig_lower, orig_upper) = compute_bounds(
+ mu_1, sigma2_1, reparam_func, n_std=3
+ )
+ print(f" 3σ bounds: Internal [{int_lower:.3f}, {int_upper:.3f}] -> Outer [{orig_lower:.3f}, {orig_upper:.3f}]")
+
+ # Compute PDF at a few points to demonstrate reparametrization
+ test_points = [0.5, 1.0, 2.0, 5.0]
+ print(" PDF values at test points:")
+ for x_orig in test_points:
+ x_int = reparam_func(x_orig, "forward")
+ pdf_orig = compute_transformed_pdf(mu_1, sigma2_1, x_int, reparam_func)
+ print(f" x = {x_orig:.1f}: PDF = {pdf_orig:.6f}")
+ print()
+ print()
+
+ # Step 7: Create visualization
+ print("7. Generating Visualization")
+ print("-" * 50)
+
+ try:
+ # Create single figure with 3 rows (one per parameter), 2 columns (internal, outer)
+ fig, axes = plt.subplots(3, 2, figsize=(16, 12))
+
+ # Plot marginal distributions for each parameter
+ for i in range(n_dim):
+ # Get marginal parameters
+ mu_i = mean_internal[i]
+ sigma_i = np.sqrt(cov_internal[i, i])
+ transform_func = transforms[i]
+
+ # Create transform function for reparametrization
+ def param_transform_func(x, direction):
+ return transform_func(x, direction)
+
+ # === INTERNAL DISTRIBUTION PLOT ===
+ ax_int = axes[i, 0] # Row i, column 0 (internal)
+
+ # Create well-spaced internal grid
+ x_internal = np.linspace(mu_i - 4*sigma_i, mu_i + 4*sigma_i, 300)
+ pdf_internal = (1/(sigma_i * np.sqrt(2*np.pi))) * np.exp(-0.5*((x_internal - mu_i)/sigma_i)**2)
+
+ ax_int.plot(x_internal, pdf_internal, 'b-', linewidth=3, label=f'N({mu_i:.2f}, {sigma_i:.2f}²)')
+
+ # Add quantiles for internal distribution
+ internal_percentiles = np.array([0.025, 0.25, 0.5, 0.75, 0.975])
+ from scipy.stats import norm
+ internal_quantiles = norm.ppf(internal_percentiles, loc=mu_i, scale=sigma_i)
+
+ colors_int = ['red', 'orange', 'green', 'orange', 'red']
+ for p, q, color in zip(internal_percentiles, internal_quantiles, colors_int):
+ pdf_val = (1/(sigma_i * np.sqrt(2*np.pi))) * np.exp(-0.5*((q - mu_i)/sigma_i)**2)
+ ax_int.axvline(q, color=color, linestyle='--', alpha=0.7, linewidth=2)
+ if p in [0.025, 0.5, 0.975]: # Label key percentiles
+ ax_int.text(q, pdf_val * 1.05, f'{p:.3f}', rotation=90, ha='center', va='bottom', fontsize=10, fontweight='bold')
+
+ # Set internal plot properties
+ ax_int.set_xlim(mu_i - 4*sigma_i, mu_i + 4*sigma_i)
+ ax_int.set_ylim(0, max(pdf_internal) * 1.15)
+ ax_int.set_xlabel(f'Parameter {i+1} (Internal Scale)', fontsize=12)
+ ax_int.set_ylabel('PDF', fontsize=12)
+ ax_int.set_title(f'Parameter {i+1}: Internal Distribution\n{transform_func.name}', fontsize=14, fontweight='bold')
+ ax_int.legend(fontsize=11)
+ ax_int.grid(True, alpha=0.3)
+
+ # === OUTER DISTRIBUTION PLOT ===
+ ax_out = axes[i, 1] # Row i, column 1 (outer)
+
+ try:
+ # Compute bounds for outer distribution with more generous margins
+ (int_lower, int_upper), (orig_lower, orig_upper) = compute_bounds(
+ mu_i, sigma_i**2, param_transform_func, n_std=4
+ )
+
+ # Add some margin to outer bounds for better visualization
+ orig_range = orig_upper - orig_lower
+ orig_margin = orig_range * 0.1
+ orig_lower_plot = max(orig_lower - orig_margin, 1e-6) if orig_lower > 0 else orig_lower - orig_margin
+ orig_upper_plot = orig_upper + orig_margin
+
+ # Create fine grid in outer space
+ x_outer = np.linspace(orig_lower_plot, orig_upper_plot, 300)
+
+ # Filter out invalid values for certain transformations
+ if "Logistic" in transforms[i].name: # Logistic transform (0,1)
+ x_outer = x_outer[(x_outer > 0.001) & (x_outer < 0.999)]
+ elif "Log" in transforms[i].name: # Log transform (positive)
+ x_outer = x_outer[x_outer > 0.001]
+ # Identity transform needs no filtering - can handle all real values
+
+ # Compute PDF in outer space
+ pdf_outer = []
+ for x in x_outer:
+ try:
+ x_int = param_transform_func(x, "forward")
+ pdf_val = compute_transformed_pdf(mu_i, sigma_i**2, x_int, param_transform_func)
+ pdf_outer.append(pdf_val)
+ except:
+ pdf_outer.append(0.0)
+
+ pdf_outer = np.array(pdf_outer)
+
+ # Plot outer distribution
+ ax_out.plot(x_outer, pdf_outer, 'r-', linewidth=3, label=f'Transformed Distribution')
+
+ # Add quantiles for outer distribution
+ outer_quantiles = compute_transformed_quantiles(mu_i, sigma_i**2, internal_percentiles, param_transform_func)
+
+ colors_out = ['red', 'orange', 'green', 'orange', 'red']
+ for p, q, color in zip(internal_percentiles, outer_quantiles, colors_out):
+ if orig_lower_plot <= q <= orig_upper_plot: # Only plot if within bounds
+ try:
+ x_int_q = param_transform_func(q, "forward")
+ pdf_val_q = compute_transformed_pdf(mu_i, sigma_i**2, x_int_q, param_transform_func)
+ ax_out.axvline(q, color=color, linestyle='--', alpha=0.7, linewidth=2)
+ if p in [0.025, 0.5, 0.975]: # Label key percentiles
+ ax_out.text(q, pdf_val_q * 1.05, f'{p:.3f}', rotation=90, ha='center', va='bottom', fontsize=10, fontweight='bold')
+ except:
+ pass
+
+ # Set outer plot properties with proper limits
+ ax_out.set_xlim(orig_lower_plot, orig_upper_plot)
+ if len(pdf_outer) > 0 and max(pdf_outer) > 0:
+ ax_out.set_ylim(0, max(pdf_outer) * 1.15)
+
+ # Format x-axis nicely for different transformations
+ if "Log" in transforms[i].name: # Log transform
+ ax_out.set_xlabel(f'Parameter {i+1} (Outer Scale: exp)', fontsize=12)
+ elif "Logistic" in transforms[i].name: # Logistic transform
+ ax_out.set_xlabel(f'Parameter {i+1} (Outer Scale: sigmoid)', fontsize=12)
+ elif "Identity" in transforms[i].name: # Identity transform
+ ax_out.set_xlabel(f'Parameter {i+1} (Outer Scale: identity)', fontsize=12)
+
+ ax_out.set_ylabel('PDF', fontsize=12)
+ ax_out.set_title(f'Parameter {i+1}: Outer Distribution\n{transform_func.name}', fontsize=14, fontweight='bold')
+ ax_out.legend(fontsize=11)
+ ax_out.grid(True, alpha=0.3)
+
+ except Exception as e:
+ print(f"Warning: Could not plot outer distribution for parameter {i+1}: {e}")
+ ax_out.text(0.5, 0.5, f'Error plotting\nparameter {i+1}', transform=ax_out.transAxes, ha='center', va='center', fontsize=12)
+ ax_out.set_title(f'Parameter {i+1}: Error', fontsize=14)
+
+ # Add column labels
+ axes[0, 0].text(0.5, 1.15, 'Internal (Gaussian) Scale', transform=axes[0, 0].transAxes,
+ ha='center', va='bottom', fontsize=16, fontweight='bold')
+ axes[0, 1].text(0.5, 1.15, 'Outer (Transformed) Scale', transform=axes[0, 1].transAxes,
+ ha='center', va='bottom', fontsize=16, fontweight='bold')
+
+ # Finalize figure
+ fig.suptitle('Marginal Distributions: Internal vs Outer Scales', fontsize=18, fontweight='bold', y=0.98)
+ fig.tight_layout()
+ fig.subplots_adjust(top=0.92) # Make room for suptitle and column headers
+ fig.savefig('marginal_distributions_comparison.png', dpi=300, bbox_inches='tight')
+
+ # Show figure
+ plt.show()
+
+ print("✓ Marginal distributions comparison saved as 'marginal_distributions_comparison.png'")
+
+ except Exception as e:
+ print(f"⚠ Could not generate plots: {e}")
+
+ print()
+
+ # Step 8: Summary
+ print("8. Summary")
+ print("-" * 50)
+
+ # Calculate total changes in both covariance and correlation structures
+ total_cov_change = np.sum(np.abs(outer_cov_matrix - internal_cov_matrix)) / 2 # Divide by 2 due to symmetry
+ total_corr_change = np.sum(np.abs(outer_corr_matrix - internal_corr_matrix)) / 2 # Divide by 2 due to symmetry
+ print(f"✓ Total covariance structure change: {total_cov_change:.6f}")
+ print(f"✓ Total correlation structure change: {total_corr_change:.4f}")
+
+ print()
+ print("=" * 90)
+ print("MULTIVARIATE TRANSFORMATION TEST COMPLETED!")
+ print("=" * 90)
+
+
+if __name__ == "__main__":
+ test_multivariate_transformation()
\ No newline at end of file
diff --git a/src/dalia/utils/gaussian_quadrature.py b/src/dalia/utils/gaussian_quadrature.py
new file mode 100644
index 00000000..25c2dab4
--- /dev/null
+++ b/src/dalia/utils/gaussian_quadrature.py
@@ -0,0 +1,273 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+from dalia import xp
+
+from scipy.special import roots_hermite
+
+def compute_variance_gauss_hermite(mean_internal, variance_internal, transform, n_points=20):
+ """
+ Compute variance of transformed distribution using Gauss-Hermite quadrature
+
+ For a distribution Y where transform(Y) ~ N(μ, σ²), we want to compute:
+ Var(Y) = E[Y²] - (E[Y])²
+
+ Using Gauss-Hermite quadrature by transforming to standard normal form.
+
+ Theory:
+ If Z ~ N(0,1), then X = μ + σZ ~ N(μ, σ²)
+ So Y = φ⁻¹(X) = φ⁻¹(μ + σZ)
+
+ E[Y] = E[φ⁻¹(X)] = E[φ⁻¹(μ + σZ)]
+ E[Y²] = E[φ⁻¹(X)φ⁻¹(X)] = E[φ⁻¹(μ + σZ) φ⁻¹(μ + σZ)]
+ """
+
+ # Get Gauss-Hermite quadrature points and weights
+ nodes, weights = roots_hermite(n_points)
+
+ # copy nodes and weight to device
+ nodes = xp.array(nodes)
+ weights = xp.array(weights)
+
+ # Transform nodes from Hermite polynomial roots to standard normal
+ # Hermite nodes are for exp(-x²), we want exp(-x²/2)/√(2π)
+ # So we scale by √2: z = √2 * node
+ z_nodes = xp.sqrt(2) * nodes
+
+ # Compute transformed values Y = φ⁻¹(μ + σZ) for each node
+ internal_values = mean_internal + xp.sqrt(variance_internal) * z_nodes
+ y_values = transform(internal_values, direction="backward")
+
+ # Gauss-Hermite weights need to be adjusted for standard normal
+ # Original: ∫ f(x) exp(-x²) dx ≈ Σ w_i f(x_i)
+ # For standard normal: ∫ f(z) (1/√(2π)) exp(-z²/2) dz
+ # After substitution x = z/√2: ∫ f(√2 x) (1/√π) exp(-x²) dx
+ adjusted_weights = weights / xp.sqrt(xp.pi)
+
+ # Compute first and second moments
+ mean_y = xp.sum(adjusted_weights * y_values)
+ second_moment_y = xp.sum(adjusted_weights * y_values**2)
+
+ # Variance = E[Y²] - (E[Y])²
+ variance_y = second_moment_y - mean_y**2
+
+ return {
+ 'mean': mean_y,
+ 'second_moment': second_moment_y,
+ 'variance': variance_y,
+ 'std': xp.sqrt(variance_y)
+ }
+
+
+#################################################################################
+#### just for testing purposes below ##########################################
+# Dummy classes to avoid circular imports during testing
+class DummyConfig:
+ def __init__(self, alpha=2.0, beta=1.0):
+ self.alpha = alpha
+ self.beta = beta
+
+class DummyGammaPriorHyperparameters:
+ def __init__(self, config):
+ self.config = config
+
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """Log transformation: external (positive) <-> internal (unconstrained)"""
+ if direction == "forward":
+ return xp.log(theta)
+ elif direction == "backward":
+ return xp.exp(theta)
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+
+def test_gaussian_quadrature():
+ """
+ Testing Gaussian quadrature function.
+
+ Tests multiple transformation functions and compares results with
+ analytical solutions where available.
+ """
+
+ print("=" * 80)
+ print("GAUSSIAN QUADRATURE TESTS")
+ print("=" * 80)
+
+ # Test parameters
+ test_tolerance = 1e-6
+ n_quad_points = 20
+
+ # Test 1: Identity transformation (should recover original normal moments)
+ print("1. Testing Identity Transformation (X = Y)")
+ print("-" * 50)
+
+ def identity_transform(x, direction):
+ return x # No transformation
+
+ mu = 2.0
+ sigma2 = 1.5
+
+ result = compute_variance_gauss_hermite(mu, sigma2, identity_transform, n_quad_points)
+
+ # For identity, we should recover the original normal distribution moments
+ expected_mean = mu
+ expected_variance = sigma2
+
+ print(f" Expected mean: {expected_mean:.6f}, Got: {result['mean']:.6f}")
+ print(f" Expected var: {expected_variance:.6f}, Got: {result['variance']:.6f}")
+
+ mean_error = abs(result['mean'] - expected_mean)
+ var_error = abs(result['variance'] - expected_variance)
+
+ print(f" Mean error: {mean_error:.2e}")
+ print(f" Var error: {var_error:.2e}")
+
+ assert mean_error < test_tolerance, f"Identity mean test failed: error = {mean_error}"
+ assert var_error < test_tolerance, f"Identity variance test failed: error = {var_error}"
+ print(" ✓ Identity transformation test PASSED")
+ print()
+
+ # Test 2: Log transformation (log-normal distribution)
+ print("2. Testing Log Transformation (Y = exp(X))")
+ print("-" * 50)
+
+ def log_transform(x, direction):
+ if direction == "forward":
+ return xp.log(x)
+ elif direction == "backward":
+ return xp.exp(x)
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ mu = 0.5
+ sigma2 = 0.25
+
+ result = compute_variance_gauss_hermite(mu, sigma2, log_transform, n_quad_points)
+
+ # Analytical moments for log-normal: if log(Y) ~ N(μ, σ²)
+ expected_mean = xp.exp(mu + sigma2/2)
+ expected_variance = (xp.exp(sigma2) - 1) * xp.exp(2*mu + sigma2)
+
+ print(f" Expected mean: {expected_mean:.6f}, Got: {result['mean']:.6f}")
+ print(f" Expected var: {expected_variance:.6f}, Got: {result['variance']:.6f}")
+
+ mean_error = abs(result['mean'] - expected_mean) / expected_mean
+ var_error = abs(result['variance'] - expected_variance) / expected_variance
+
+ print(f" Relative mean error: {mean_error:.2e}")
+ print(f" Relative var error: {var_error:.2e}")
+
+ assert mean_error < 1e-4, f"Log-normal mean test failed: rel error = {mean_error}"
+ assert var_error < 1e-4, f"Log-normal variance test failed: rel error = {var_error}"
+ print(" ✓ Log transformation test PASSED")
+ print()
+
+ # Test 3: Linear transformation (Y = aX + b)
+ print("3. Testing Linear Transformation (Y = aX + b)")
+ print("-" * 50)
+
+ a, b = 3.0, -1.5
+
+ def linear_transform(x, direction):
+ if direction == "forward":
+ return (x - b) / a
+ elif direction == "backward":
+ return a * x + b
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ mu = 1.0
+ sigma2 = 0.8
+
+ result = compute_variance_gauss_hermite(mu, sigma2, linear_transform, n_quad_points)
+
+ # For Y = aX + b where X ~ N(μ, σ²): E[Y] = aμ + b, Var[Y] = a²σ²
+ expected_mean = a * mu + b
+ expected_variance = a**2 * sigma2
+
+ print(f" Transform: Y = {a}X + {b}")
+ print(f" Expected mean: {expected_mean:.6f}, Got: {result['mean']:.6f}")
+ print(f" Expected var: {expected_variance:.6f}, Got: {result['variance']:.6f}")
+
+ mean_error = abs(result['mean'] - expected_mean)
+ var_error = abs(result['variance'] - expected_variance)
+
+ print(f" Mean error: {mean_error:.2e}")
+ print(f" Var error: {var_error:.2e}")
+
+ assert mean_error < test_tolerance, f"Linear mean test failed: error = {mean_error}"
+ assert var_error < test_tolerance, f"Linear variance test failed: error = {var_error}"
+ print(" ✓ Linear transformation test PASSED")
+ print()
+
+ # Test 4: Gamma prior rescaling (again log transformation)
+ print("4. Testing Gamma Prior Rescaling")
+ print("-" * 50)
+
+ config = DummyConfig(alpha=2.0, beta=1.0)
+ gamma_prior = DummyGammaPriorHyperparameters(config=config)
+
+ def gamma_rescale(x, direction):
+ return gamma_prior.rescale_hyperparameters_to_internal(x, direction)
+
+ mu = 0.2
+ sigma2 = 0.3
+
+ result = compute_variance_gauss_hermite(mu, sigma2, gamma_rescale, n_quad_points)
+
+ # This is the same as log-normal since rescale uses exp transformation
+ expected_mean = xp.exp(mu + sigma2/2)
+ expected_variance = (xp.exp(sigma2) - 1) * xp.exp(2*mu + sigma2)
+
+ print(f" Expected mean: {expected_mean:.6f}, Got: {result['mean']:.6f}")
+ print(f" Expected var: {expected_variance:.6f}, Got: {result['variance']:.6f}")
+
+ mean_error = abs(result['mean'] - expected_mean) / expected_mean
+ var_error = abs(result['variance'] - expected_variance) / expected_variance
+
+ print(f" Relative mean error: {mean_error:.2e}")
+ print(f" Relative var error: {var_error:.2e}")
+
+ assert mean_error < 1e-4, f"Gamma rescale mean test failed: rel error = {mean_error}"
+ assert var_error < 1e-4, f"Gamma rescale variance test failed: rel error = {var_error}"
+ print(" ✓ Gamma prior rescaling test PASSED")
+ print()
+
+ # Test 5: Convergence with increasing quadrature points
+ print("5. Testing Convergence with Quadrature Points")
+ print("-" * 50)
+
+ mu_conv = 0.1
+ sigma2_conv = 0.4
+ expected_mean_conv = xp.exp(mu_conv + sigma2_conv/2)
+
+ n_points_list = [5, 10, 15, 20, 30, 50]
+ errors_mean = []
+ errors_var = []
+
+ print(" n_points | Mean | Rel. Mean Error | Rel. Var. Error")
+ print(" ----------|-------------|------------------|-----------------")
+
+ for n in n_points_list:
+ result = compute_variance_gauss_hermite(mu_conv, sigma2_conv, log_transform, n)
+ rel_error_mean = abs(result['mean'] - expected_mean_conv) / expected_mean_conv
+ errors_mean.append(rel_error_mean)
+
+ print(f" {n:8d} | {result['mean']:10.6f} | {rel_error_mean:5.3e} ")
+
+ # Check that errors generally decrease (allowing some numerical noise)
+ improving = sum(errors_mean[i+1] < errors_mean[i] * 1.1 for i in range(len(errors_mean)-1))
+ improvement_rate = improving / (len(errors_mean) - 1)
+
+ print(f" Improvement rate: {improvement_rate:.1%}")
+ assert improvement_rate > 0.6, f"Convergence test failed: improvement rate = {improvement_rate}"
+ print(" ✓ Convergence test PASSED")
+ print()
+
+ # Summary
+ print("=" * 80)
+ print("ALL TESTS PASSED! ✓")
+ print("=" * 80)
+
+
+if __name__ == "__main__":
+ test_gaussian_quadrature()
\ No newline at end of file
diff --git a/src/dalia/utils/gpu_utils.py b/src/dalia/utils/gpu_utils.py
index 90f7d019..f329b590 100644
--- a/src/dalia/utils/gpu_utils.py
+++ b/src/dalia/utils/gpu_utils.py
@@ -1,6 +1,9 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
import inspect
+import os
+
+import psutil
from dalia import NDArray, backend_flags, xp
@@ -100,7 +103,7 @@ def get_device(arr: NDArray) -> NDArray:
def format_size(size_bytes):
- for unit in ['B', 'KB', 'MB', 'GB']:
+ for unit in ["B", "KB", "MB", "GB"]:
if size_bytes < 1024:
return f"{size_bytes:.2f} {unit}"
size_bytes /= 1024
@@ -110,34 +113,41 @@ def format_size(size_bytes):
# query memory usage GPU and free unused memory
def free_unused_gpu_memory() -> int:
"""Free unused memory on the GPU."""
-
+
if backend_flags["cupy_avail"]:
mempool = cp.get_default_memory_pool()
-
- mempool.free_all_blocks()
+
+ mempool.free_all_blocks()
return mempool.total_bytes()
-
- else:
+
+ else:
# return dummy value for numpy
return 1
+
def memory_report() -> int:
- """Free unused memory on the GPU."""
+ """Report the current memory usage.
+
+ Returns
+ -------
+ used_memory : int
+ The current memory usage in bytes.
+ total_memory : int
+ The total memory available in bytes.
+ """
used_memory = 0
total_memory = 0
if backend_flags["cupy_avail"] and backend_flags["array_module"] == "cupy":
- # Get GPU memory usage
+ # Get (GPU) memory usage
mempool = cp.get_default_memory_pool()
used_memory = mempool.used_bytes()
total_memory = mempool.total_bytes()
else:
- # TODO: Implement on the host
- used_memory = -1
- total_memory = -1
+ # Get (CPU) memory usage
+ pid = os.getpid()
+ used_memory = psutil.Process(pid).memory_info().rss
+ total_memory = psutil.virtual_memory().total
return used_memory, total_memory
-
-
-
diff --git a/src/dalia/utils/link_functions.py b/src/dalia/utils/link_functions.py
index 9f335bba..97ddfe84 100644
--- a/src/dalia/utils/link_functions.py
+++ b/src/dalia/utils/link_functions.py
@@ -23,5 +23,9 @@ def scaled_logit(x: NDArray, direction: str) -> NDArray:
return (1.0 / k) * xp.log(x / (1.0 - x))
elif direction == "backward":
return 1 / (1 + xp.exp(-k * x))
+ elif direction == "forward_jacobian":
+ return 1.0 / (k * x * (1.0 - x))
+ elif direction == "backward_jacobian": ### should be 1 / forward_jacobian ...
+ return k * x * (1.0 - x)
else:
raise ValueError(f"Unknown direction: {direction}")
diff --git a/src/dalia/utils/multiprocessing.py b/src/dalia/utils/multiprocessing.py
index 9ceeb6d6..92a1d259 100644
--- a/src/dalia/utils/multiprocessing.py
+++ b/src/dalia/utils/multiprocessing.py
@@ -1,8 +1,10 @@
# Copyright 2024-2025 DALIA authors. All rights reserved.
-import numpy as np
+from typing import Literal
+
from dataclasses import dataclass
-from dalia import ArrayLike, backend_flags, comm_rank
+
+from dalia import ArrayLike, backend_flags, comm_rank, xp
from dalia.utils.gpu_utils import get_array_module_name, get_device, get_host
if backend_flags["mpi_avail"]:
@@ -11,11 +13,13 @@
if backend_flags["cupy_avail"]:
import cupy as cp
+
@dataclass
class DummyCommunicator:
"""Communicator class to handle MPI communication when
MPI is not available.
"""
+
size: int = 1
rank: int = 0
@@ -112,7 +116,9 @@ def allgather(
if backend_flags["mpi_avail"]:
if get_array_module_name(obj) == "cupy" and not backend_flags["mpi_cuda_aware"]:
obj_comm = get_host(obj)
- return get_device(np.concatenate(comm.allgather(obj_comm)))
+ gathered_objs = comm.allgather(obj_comm)
+ # Convert gathered numpy arrays back to cupy arrays
+ return [get_device(arr) for arr in gathered_objs]
else:
return comm.allgather(obj)
@@ -134,9 +140,27 @@ def bcast(
comm (CommunicatorType), optional:
The communication group. Default is MPI.COMM_WORLD.
"""
+ # Need to check data module and MPI capabilities
+ d2h2d_needed : bool = (
+ backend_flags["mpi_avail"] and get_array_module_name(data) == "cupy"
+ )
+
if backend_flags["mpi_avail"]:
- comm.Bcast(data, root=root)
+ if d2h2d_needed:
+ data_comm = get_host(data)
+ else:
+ data_comm = data
+ if data.ndim == 0:
+ comm.Bcast(data_comm, root=root)
+ else:
+ comm.Bcast(data_comm[:], root=root)
+
+ if d2h2d_needed:
+ if data.ndim == 0:
+ data[...] = get_device(data_comm)
+ else:
+ data[:] = get_device(data_comm)
def get_active_comm(
comm,
@@ -209,3 +233,67 @@ def smartsplit(
color_new_group = 0
return active_comm, comm_new_group, color_new_group
+
+def check_vector_consistency(
+ value: ArrayLike,
+ comm,
+ flag: str,
+ verbose: Literal["No", "Minimal", "Full"] = "No",
+ rtol: float = 1e-10,
+):
+ """ Check if all processes have the same value.
+
+ Parameters:
+ -----------
+ value (ArrayLike):
+ The value to check.
+ comm (CommunicatorType), optional:
+ The communication group.
+ flag (str):
+ A string to identify the value being checked in the error message.
+ verbose (str):
+ The level of verbosity for the error message. Choose from 'No', 'Minimal', or 'Full'. Default is 'No'.
+ rtol (float):
+ The relative tolerance for the consistency check. Default is 1e-10.
+
+ Raises:
+ -------
+ ValueError:
+ If the value is not consistent across all processes.
+ """
+ synchronize(comm = comm)
+
+ # A vector might for some models be passed as a scalar, or simply a list
+ # . In these cases we convert it to an array for the consistency check.
+ if value is not None:
+ if not isinstance(value, list):
+ value = xp.array([value])
+ elif not isinstance(value, xp.ndarray):
+ value = xp.array(value)
+
+ value_ref = value.copy()
+
+ bcast(data=value_ref[:], root=0, comm=comm)
+
+ norm_diff = xp.linalg.norm(value - value_ref)
+
+ if norm_diff > rtol:
+ # Print indices and values where value and value_ref differ
+ if verbose == "No":
+ raise ValueError(
+ f"Process {comm.Get_rank()} has a different {flag} than the reference process with a norm of the difference of {norm_diff:.4e}."
+ )
+
+ diff_indices = xp.where(value != value_ref)[0]
+ if verbose == "Minimal":
+ # Only print the first 5 differences, make sure it's 5 or the max number of differences
+ diff_indices = diff_indices[:min(5, len(diff_indices))]
+ elif verbose == "Full":
+ pass
+ else:
+ raise ValueError(
+ f"Invalid verbose option: {verbose}. Choose from 'No', 'Minimal', or 'Full'."
+ )
+
+ for idx in diff_indices:
+ print(f"Process {comm.Get_rank()} difference at index {idx}: {flag}={value_ref[idx]}, value={value[idx]}")
diff --git a/src/dalia/utils/plotting.py b/src/dalia/utils/plotting.py
new file mode 100644
index 00000000..c28b4212
--- /dev/null
+++ b/src/dalia/utils/plotting.py
@@ -0,0 +1,150 @@
+import matplotlib.pyplot as plt
+import numpy as np
+from scipy.stats import norm
+
+from dalia import xp
+from dalia.utils import get_host
+
+def plot_prior_hp(param_name, theta_interval, prior_hp, log=False):
+ """
+ Plot prior distribution of a hyperparameter.
+
+ All priors are in log-scale. Therefore exponeniate unless log=True.
+
+ Parameters
+ ----------
+ param_name : str
+ Name of the hyperparameter.
+ theta_interval: tuple of float
+ Interval (min, max) for plotting the prior.
+ prior_hp : PriorHyperparameters
+ Prior hyperparameter object.
+ log : bool, optional
+ Whether to plot in log-scale or original scale. Default is False.
+
+
+ Note
+ ----
+ If log is True, the plot will be in log-scale. Otherwise, it will be in the original scale.
+
+
+ Returns
+ -------
+ fig, ax : matplotlib Figure and Axes
+ The figure and axes objects containing the plot.
+ """
+
+ if theta_interval[0] == 0:
+ theta_interval = (1e-6, theta_interval[1])
+
+ theta_vals = xp.linspace(theta_interval[0], theta_interval[1], 200)
+ prior_vals = xp.array([prior_hp.evaluate_log_prior(theta) for theta in theta_vals])
+
+ theta_vals = get_host(theta_vals)
+ prior_vals = get_host(prior_vals)
+
+ if log:
+ xlabel = f"{param_name}"
+ ylabel = "Log Prior Density"
+ else:
+ prior_vals = np.exp(prior_vals)
+ xlabel = f"{param_name}"
+ ylabel = "Prior Density"
+
+ fig, ax = plt.subplots(figsize=(8, 5))
+ ax.plot(theta_vals, prior_vals, 'b-', linewidth=2)
+ ax.set_xlabel(xlabel)
+ ax.set_ylabel(ylabel)
+ ax.set_title(f"Prior Distribution of {param_name}")
+ ax.grid(True, alpha=0.3)
+
+ return fig, ax
+
+
+
+
+def plot_marginal_distributions_hp(marginals_hp):
+ """Plot marginal distributions of hyperparameters in both internal and external parametrizations."""
+
+ # Get all hyperparameters
+ hyperparams = marginals_hp['hyperparameters']
+ n_params = len(hyperparams)
+
+ # Create subplot grid: n_params rows, 2 columns (internal left, external right)
+ fig, axes = plt.subplots(n_params, 2, figsize=(15, 5*n_params))
+
+ # Handle case of single parameter
+ if n_params == 1:
+ axes = axes.reshape(1, -1)
+
+ # Quantile colors and labels
+ colors = ['#DEB887', '#DEB887','darkred', '#DEB887', '#DEB887']
+ labels = ['2.5%', '25%', '50%', '75%','97.5%']
+
+ for row, (param_name, param_data) in enumerate(hyperparams.items()):
+ # Get internal parameters
+ mean_internal = param_data['mean_internal']
+ var_internal = param_data['variance_internal']
+ std_internal = np.sqrt(var_internal)
+
+ # Get external parameters
+ mean_external = param_data['mean_external']
+ var_external = param_data['variance_external']
+ theta_external, pdf_external = param_data['pdf_data']
+
+ # Get quantiles
+ quantile_pairs_internal = param_data['quantiles']['internal']['pairs']
+ quantile_pairs_external = param_data['quantiles']['external']['pairs']
+
+ # ===== LEFT PLOT: INTERNAL PARAMETRIZATION =====
+ ax_left = axes[row, 0]
+
+ # Create internal distribution (Gaussian)
+ x_internal = np.linspace(mean_internal - 4*std_internal, mean_internal + 4*std_internal, 100)
+ pdf_internal = norm.pdf(x_internal, loc=mean_internal, scale=std_internal)
+
+ # Plot internal PDF
+ ax_left.plot(x_internal, pdf_internal, 'b-', linewidth=2, label='PDF (Internal)')
+
+ # Mark internal mean
+ ax_left.axvline(mean_internal, color='red', linestyle='--', linewidth=2,
+ label=f'Mean = {mean_internal:.3f}')
+
+ # Mark internal quantiles
+ for i, (prob, q_val) in enumerate(quantile_pairs_internal):
+ if i < len(labels):
+ ax_left.axvline(q_val, color=colors[i], linestyle=':', linewidth=2,
+ label=f'{labels[i]} = {q_val:.3f}')
+
+ ax_left.set_xlabel(f'{param_name} (internal scale)')
+ ax_left.set_ylabel('PDF')
+ ax_left.set_title(f'{param_name}: Internal Distribution (Gaussian)')
+ ax_left.legend()
+ ax_left.grid(True, alpha=0.3)
+
+ # ===== RIGHT PLOT: EXTERNAL PARAMETRIZATION =====
+ ax_right = axes[row, 1]
+
+ # Plot external PDF
+ ax_right.plot(theta_external, pdf_external, 'b-', linewidth=2, label='PDF')
+
+ # Mark external mean
+ ax_right.axvline(mean_external, color='red', linestyle='--', linewidth=2,
+ label=f'Mean = {mean_external:.3f}')
+
+ # Mark external quantiles
+ for i, (prob, q_val) in enumerate(quantile_pairs_external):
+ if i < len(labels):
+ ax_right.axvline(q_val, color=colors[i], linestyle=':', linewidth=2,
+ label=f'{labels[i]} = {q_val:.3f}')
+
+ ax_right.set_xlabel(f'{param_name} ')
+ ax_right.set_ylabel('PDF')
+ ax_right.set_title(f'{param_name}: Marginal Distribution')
+ ax_right.legend()
+ ax_right.grid(True, alpha=0.3)
+
+ # plt.tight_layout()
+ # plt.show()
+
+ return fig, axes
diff --git a/src/dalia/utils/reparametrizations.py b/src/dalia/utils/reparametrizations.py
new file mode 100644
index 00000000..d5924b01
--- /dev/null
+++ b/src/dalia/utils/reparametrizations.py
@@ -0,0 +1,332 @@
+from scipy.stats import norm
+import numpy as np
+
+from dalia import NDArray, xp
+from dalia.utils import get_host, get_device
+
+
+def compute_transformed_quantiles(mean_internal, var_internal, percentiles, transform):
+ """
+ Compute quantiles for a transformed distribution
+
+ Parameters:
+ - original_dist_params: (mean, std) for the internal/transformed distribution
+ - percentiles: array of probability values (0, 1)
+ - transform: TransformationFunction object
+
+ Returns:
+ - quantiles in original scale
+
+ Notes:
+ Computing Quantiles for Transformed Distributions
+
+ Idea:
+ If X ~ f(x) and Y = φ(X), then find quantiles of Y:
+ 1. For a given probability p, find q_p such that P(Y ≤ q_p) = p
+ 2. This is equivalent to P(φ(X) ≤ q_p) = p
+ 3. We suppose φ is bijective and monotonely increasing. Then, if F_X is the CDF of X, we can write:
+ F_Y(q_p) = P(Y ≤ q_p) = P(φ(X) ≤ q_p) = P(X ≤ φ⁻¹(q_p)) = F_X(φ⁻¹(q_p))
+
+ 4. So φ⁻¹(q_p) = F_X⁻¹(p), where F_X⁻¹ is the quantile function of X
+ 5. Therefore: q_p = φ(F_X⁻¹(p))
+ """
+
+ # Step 1: Compute quantiles in internal scale
+ quantiles_np = get_host(percentiles)
+ mean_internal_np = get_host(mean_internal)
+ var_internal_np = get_host(var_internal)
+
+ internal_quantiles = norm.ppf(quantiles_np, loc=mean_internal_np, scale=var_internal_np**0.5)
+
+ # copy qunatiles to device
+ internal_quantiles = get_device(internal_quantiles)
+
+ # Step 2: Transform back to original scale
+ # If φ: original → internal, then original quantiles = φ⁻¹(internal quantiles)
+ original_quantiles = transform(internal_quantiles, direction='backward')
+
+ return original_quantiles
+
+def compute_transformed_pdf(mean_internal, var_internal, x_internal, transform):
+ """
+ Compute PDF of transformed distribution using change of variables
+
+ If Y = φ(X), then f_Y(y) = f_X(φ⁻¹(y)) * |dφ⁻¹/dy|
+ But we want f_X(x) where x is in original scale, so:
+ f_X(x) = f_Y(φ(x)) * |dφ/dx|
+ """
+
+ # PDF in internal scale
+ x_internal_np = get_host(x_internal)
+ mean_internal_np = get_host(mean_internal)
+ var_internal_np = get_host(var_internal)
+ pdf_internal_np = norm.pdf(x_internal_np, loc=mean_internal_np, scale=var_internal_np**0.5)
+
+ # copy pdf_internal to device
+ pdf_internal = get_device(pdf_internal_np)
+
+ # Jacobian: derivative of transformation
+ # Ensure x_internal is treated as array for vectorized operations
+ x_original = transform(x_internal, direction='backward')
+ jacobian = xp.abs(transform(x_original, direction='forward_jacobian'))
+
+ # PDF values in original scale
+ pdf_original = pdf_internal * jacobian
+
+ return x_original, pdf_original
+
+# Automatic bound calculation based on 4 (default) standard deviations in internal scale
+def compute_bounds(mean_internal, var_internal, transform, n_std=4):
+ """
+ Compute plotting bounds based on n standard deviations in internal scale
+ """
+ # Internal scale bounds (±n standard deviations)
+ internal_lower = mean_internal - n_std * var_internal**0.5
+ internal_upper = mean_internal + n_std * var_internal**0.5
+
+ # Transform to original scale
+ original_lower = transform(internal_lower, direction='backward')
+ original_upper = transform(internal_upper, direction='backward')
+
+ return (internal_lower, internal_upper), (original_lower, original_upper)
+
+###################################### TEST ######################################
+if __name__ == "__main__":
+ """
+ Test reparametrization functions using a dummy gamma prior hyperparameter class.
+
+ This test demonstrates how the reparametrization functions work with
+ transformation functions that have forward, backward, and jacobian directions.
+ """
+ import numpy as np
+ import matplotlib.pyplot as plt
+
+ # Dummy Gamma Prior Hyperparameter Class
+ class DummyGammaPrior:
+ """
+ Dummy implementation of gamma prior rescaling for testing purposes.
+ Implements log transformation: forward = log(x), backward = exp(x)
+ """
+ def rescale_hyperparameters_to_internal(self, theta, direction):
+ """Log transformation between positive (external) and unconstrained (internal) space"""
+ if direction == "forward":
+ return xp.log(theta) # theta -> log(theta)
+ elif direction == "backward":
+ return xp.exp(theta) # log(theta) -> theta
+ elif direction == "forward_jacobian":
+ return 1.0 / theta # d(log(theta))/d(theta) = 1/theta
+ elif direction == "backward_jacobian":
+ return theta # d(exp(theta))/d(theta) = exp(theta) = theta
+ else:
+ raise ValueError(f"Unknown direction: {direction}")
+
+ print("=" * 80)
+ print("Testing Reparametrization Functions with Dummy Gamma Prior")
+ print("=" * 80)
+
+ # Create dummy gamma prior instance
+ gamma_prior = DummyGammaPrior()
+
+ # Define transform function compatible with reparametrization functions
+ def transform_func(x, direction):
+ return gamma_prior.rescale_hyperparameters_to_internal(x, direction)
+
+ # Test parameters (internal space: log-normal distribution)
+ mean_internal = 0.5 # mean of log(theta)
+ std_internal = 0.8 # std of log(theta)
+
+ print(f"Internal distribution parameters:")
+ print(f" Mean (log scale): {mean_internal:.3f}")
+ print(f" Std (log scale): {std_internal:.3f}")
+ print()
+
+ # Test 1: Transform validation
+ print("1. Testing transformation consistency:")
+ test_values = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
+
+ print(" Original -> Internal -> Original (round-trip test)")
+ for theta in test_values:
+ # Forward transformation
+ log_theta = transform_func(theta, "forward")
+
+ # Backward transformation
+ theta_recovered = transform_func(log_theta, "backward")
+
+ # Check error
+ error = abs(theta - theta_recovered)
+
+ print(f" {theta:5.1f} -> {log_theta:6.3f} -> {theta_recovered:6.3f}, "
+ f"error = {error:.2e}")
+ print()
+
+ # Test 2: Jacobian validation using finite differences
+ print("2. Testing Jacobian accuracy (forward direction):")
+
+ test_theta_vals = [0.5, 1.0, 2.0, 3.0]
+ eps = 1e-8
+
+ print(" θ | Analytical | Numerical | Error")
+ print(" ------|------------|------------|----------")
+
+ for theta in test_theta_vals:
+ # Analytical jacobian
+ jac_analytical = transform_func(theta, "forward_jacobian")
+
+ # Numerical jacobian using finite differences
+ f_plus = transform_func(theta + eps, "forward")
+ f_minus = transform_func(theta - eps, "forward")
+ jac_numerical = (f_plus - f_minus) / (2 * eps)
+
+ error = abs(jac_analytical - jac_numerical)
+
+ print(f" {theta:4.1f} | {jac_analytical:10.6f} | {jac_numerical:10.6f} | {error:.2e}")
+ print()
+
+ # Test 3: Quantile computation
+ print("3. Testing quantile computation:")
+
+ percentiles = xp.array([0.025, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.975])
+ original_quantiles = compute_transformed_quantiles(
+ mean_internal, std_internal**0.5, percentiles, transform_func
+ )
+
+ print(" Percentile | Quantile (original scale)")
+ print(" -----------|-----------------------")
+ for p, q in zip(percentiles, original_quantiles):
+ print(f" {p:8.2f} | {q:18.6f}")
+ print()
+
+ # Test 4: PDF computation and validation
+ print("4. Testing PDF computation:")
+
+ # Compute bounds for plotting
+ (internal_lower, internal_upper), (original_lower, original_upper) = compute_bounds(
+ mean_internal, std_internal, transform_func, n_std=4
+ )
+
+ print(f" Internal bounds: [{internal_lower:.3f}, {internal_upper:.3f}]")
+ print(f" Original bounds: [{original_lower:.3f}, {original_upper:.3f}]")
+
+ # Test PDF at specific points
+ test_x_original = xp.array([0.5, 1.0, 2.0, 3.0, 5.0])
+
+ print(" x (orig) | PDF (orig) | log(x) | PDF (int)")
+ print(" ---------|------------|----------|----------")
+
+ x_internal = transform_func(test_x_original, "forward")
+ x_original, pdf_orig = compute_transformed_pdf(mean_internal, std_internal**2, x_internal, transform_func)
+ pdf_internal = norm.pdf(x_internal, loc=mean_internal, scale=std_internal)
+
+ for i in range(len(test_x_original)):
+ print(f" {test_x_original[i]:7.2f} | {pdf_orig[i]:10.6f} | {x_internal[i]:8.3f} | {pdf_internal[i]:8.6f}")
+ print()
+
+ # Test 5: Analytical validation for log-normal distribution
+ print("5. Analytical validation (log-normal distribution):")
+
+ # For log-normal distribution, we can compute analytical moments
+ mu = mean_internal
+ sigma = std_internal
+
+ # Analytical log-normal statistics
+ analytical_mean = xp.exp(mu + sigma**2/2)
+ analytical_var = (xp.exp(sigma**2) - 1) * xp.exp(2*mu + sigma**2)
+ analytical_std = xp.sqrt(analytical_var)
+
+ # Numerical verification using quantiles
+ # Mean ≈ 50th percentile for log-normal (approximately)
+ median_quantile = compute_transformed_quantiles(
+ mean_internal, std_internal**0.5, xp.array([0.5]), transform_func
+ )[0]
+
+ print(f" Analytical mean: {analytical_mean:.6f}")
+ print(f" Analytical std: {analytical_std:.6f}")
+ print(f" Median quantile: {median_quantile:.6f}")
+ print(f" Mean/Median ratio: {analytical_mean/median_quantile:.6f}")
+ print(" (Should be > 1 for log-normal due to skewness)")
+ print()
+
+ # Test 6: PDF integration check (numerical verification)
+ print("6. PDF integration check:")
+
+ # Create fine grid for integration -> need to start in original scale for dx to be equidistant
+ x_original = xp.linspace(original_lower, original_upper, 1000)
+ x_internal = transform_func(x_original, "forward")
+ x_original, pdf_values = compute_transformed_pdf(mean_internal, std_internal**2, x_internal, transform_func)
+
+ # Numerical integration using trapezoidal rule
+ dx = x_original[1] - x_original[0]
+ integral = np.trapezoid(pdf_values, dx=dx)
+
+ print(f" Numerical integral of PDF: {integral:.6f}")
+ print(f" Should be close to 1.0, error: {abs(1.0 - integral):.6f}")
+ print()
+
+ # repeat with non-equidistant grid in original scale but equidistant in internal scale
+ print(" Repeating PDF integration with equidistant grid in internal scale:")
+ x_internal = np.linspace(internal_lower, internal_upper, 1000)
+ x_original, pdf_values = compute_transformed_pdf(mean_internal, std_internal**2, x_internal, transform_func)
+
+ integral = np.trapezoid(pdf_values, x=x_original)
+ print(f" Numerical integral of PDF: {integral:.6f}")
+ print(f" Should be close to 1.0, error: {abs(1.0 - integral):.6f}")
+ print()
+
+ # Test 7: Bounds computation for different n_std values
+ print("7. Testing bounds computation:")
+
+ for n_std in [1, 2, 3, 4]:
+ (int_lower, int_upper), (orig_lower, orig_upper) = compute_bounds(
+ mean_internal, std_internal, transform_func, n_std=n_std
+ )
+
+ print(f" {n_std}σ bounds:")
+ print(f" Internal: [{int_lower:7.3f}, {int_upper:7.3f}]")
+ print(f" Original: [{orig_lower:7.3f}, {orig_upper:7.3f}]")
+ print()
+
+ # Test 8: Plotting PDFs in both scales
+ print("8. Plotting PDFs in internal and original scales:")
+
+ # Plot 1: PDF in internal scale (log-scale, normal distribution)
+ x_internal = xp.linspace(internal_lower, internal_upper, 500)
+ pdf_internal = norm.pdf(x_internal, loc=mean_internal, scale=std_internal)
+ # Create visualization
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
+
+ ax1.plot(x_internal, pdf_internal, 'b-', linewidth=2, label='Internal PDF')
+ internal_quantiles = norm.ppf(percentiles, loc=mean_internal, scale=std_internal)
+ for i, (p, q) in enumerate(zip(percentiles, internal_quantiles)):
+ color = 'red' if p in [0.025, 0.975] else 'orange'
+ ax1.axvline(q, color=color, linestyle='--', alpha=0.7)
+ if i % 2 == 0:
+ ax1.text(q, max(pdf_internal)*0.8, f'{p*100:.1f}%', rotation=90, ha='right', va='top')
+
+ ax1.set_title(f'Internal Scale: N(mean = {mean_internal}, std = {std_internal:.3f})')
+ ax1.set_xlabel('θ (internal)')
+ ax1.set_ylabel('PDF')
+ ax1.set_xlim(x_internal[0], x_internal[-1])
+ ax1.grid(True, alpha=0.3)
+ ax1.legend()
+
+ # Plot 2: Original distribution with quantiles
+ x_original = np.linspace(int_lower, int_upper, 1000)
+ x_original, pdf_original = compute_transformed_pdf(mean_internal, std_internal**2, x_internal, transform_func)
+
+ ax2.plot(x_original, pdf_original, 'g-', linewidth=2, label='Original PDF')
+ for i, (p, q) in enumerate(zip(percentiles, original_quantiles)):
+ color = 'red' if p in [0.025, 0.975] else 'orange'
+ ax2.axvline(q, color=color, linestyle='--', alpha=0.7)
+ if i % 2 == 0:
+ ax2.text(q, max(pdf_original)*0.8, f'{p*100:.1f}%', rotation=90, ha='right', va='top')
+
+ ax2.set_title('Original Scale: Log-Normal Distribution')
+ ax2.set_xlabel('θ_outer (original)')
+ ax2.set_ylabel('PDF')
+ ax2.set_xlim(orig_lower, 15)
+ ax2.grid(True, alpha=0.3)
+ ax2.legend()
+ plt.tight_layout()
+ plt.show()
+
+ print("=" * 80)
\ No newline at end of file
diff --git a/src/dalia/utils/scalar_ndarray.py b/src/dalia/utils/scalar_ndarray.py
new file mode 100644
index 00000000..fc2cb287
--- /dev/null
+++ b/src/dalia/utils/scalar_ndarray.py
@@ -0,0 +1,41 @@
+from dalia import xp
+
+
+def ensure_scalar(value: float | xp.ndarray) -> float:
+ """ Ensure that the input value is a scalar float. If the input is a 1-element array, extract the scalar value.
+ If the input is an array with more than 1 element, raise an error.
+
+ Parameters
+ ----------
+ value : float or xp.ndarray
+ The value to ensure is a scalar float.
+
+ Returns
+ -------
+ float
+ The scalar float value.
+
+ Raises
+ ------
+ ValueError
+ If the input is not a float or a 1-element array.
+
+ """
+ if isinstance(value, float):
+ return value
+
+ if isinstance(value, xp.ndarray):
+ # Need to handle the case where the array is 0-dimensional
+ # (i.e., a scalar wrapped in an array)
+ if value.ndim == 0:
+ return float(value.item())
+
+ if value.size > 1:
+ raise ValueError(
+ f"value evaluation returned an array of size {value.size}, expected a scalar."
+ )
+ return float(value[0])
+
+ raise ValueError(
+ f"value evaluation returned an object of type {type(value)}, expected a float or a 1-element array."
+ )
diff --git a/src/dalia/utils/spmatrix_utils.py b/src/dalia/utils/spmatrix_utils.py
index 6fe51491..50f2944d 100644
--- a/src/dalia/utils/spmatrix_utils.py
+++ b/src/dalia/utils/spmatrix_utils.py
@@ -51,9 +51,7 @@ def extract_diagonal(
if a.shape[0] != a.shape[1]:
raise ValueError("The input matrix must be square.")
- diagonal = xp.zeros(a.shape[0])
-
- # if scipy.sparse or xp.ndarray .diagonal() exists
+ # If scipy.sparse or xp.ndarray .diagonal() exists
if not backend_flags["cupy_avail"] or isinstance(a, xp.ndarray):
diagonal = a.diagonal()
else:
@@ -81,4 +79,3 @@ def memory_footprint(
total_memory_gb = total_memory_bytes / (1024**3)
print(f"Total memory footprint of Q_prior: {total_memory_gb:.6f} GB")
-
diff --git a/tests/README.md b/tests/README.md
index cc2fe6c4..9098098c 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -1,3 +1,20 @@
# DALIA testing folder
-Testing suite is in construction and not available at the moment.
-Please refer to the available examples for `Regression` and `Spatio-temporal` models.
+The DALIA testing suite is in construction. Integration tests are not yet available, for now you can refer to the examples provided in the `examples/` directory for testing of the entire pipeline.
+
+
+## How to run tests
+
+The tests can either be run directly using `pytest` or through the provided `runner.sh` script. The `runner.sh` script allows for more convenient selection of test categories and backends.
+
+In a "functionnal" environment, on a cluster with the appropriate modules loaded, and with a working conda environment activated, you can run the tests as follows:
+- Directly using: `./runner.sh`
+- Check available options: `./runner.sh --help`.
+
+## Tests status
+
+| Reference | Status | Reason |
+| --------- | ------ | ------ |
+| `component_integration/solvers/sparse_solvers/sequential/test_selected_inversion()` | Not Implemented | Not Implemented |
+| `component_integration/solvers/sparse_solvers/sequential/test_factorize()` | Limited (cannot check for numerical correctness) | LU decomposition instead of Cholesky due to `scipy` limitations |
+| `component_integration/solvers/structured_solvers/distributed/test_factorize()` | Limited (cannot check for numerical correctness) | Distributed factorization is not numerically equal to sequential reference |
+
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..2a36e059
--- /dev/null
+++ b/tests/__init__.py
@@ -0,0 +1,9 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+from .test_config import ATOLS, RANDOM_SEED, RTOLS
+
+__all__ = [
+ "RTOLS",
+ "ATOLS",
+ "RANDOM_SEED",
+]
diff --git a/tests/component_integration/__init__.py b/tests/component_integration/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/component_integration/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/component_integration/solvers/__init__.py b/tests/component_integration/solvers/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/component_integration/solvers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/component_integration/solvers/conftest.py b/tests/component_integration/solvers/conftest.py
new file mode 100644
index 00000000..959a0542
--- /dev/null
+++ b/tests/component_integration/solvers/conftest.py
@@ -0,0 +1,63 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+NUM_RHS = [
+ pytest.param(1, id="num_rhs=1"),
+ pytest.param(3, id="num_rhs=3"),
+ pytest.param(5, id="num_rhs=5"),
+]
+
+
+@pytest.fixture(params=NUM_RHS)
+def num_rhs(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+@pytest.fixture
+def create_rhs():
+ from .utils import _create_rhs
+
+ return _create_rhs
+
+
+@pytest.fixture
+def reference_cholesky():
+ from .utils import _reference_cholesky
+
+ return _reference_cholesky
+
+
+@pytest.fixture
+def reference_solve():
+ from .utils import _reference_solve
+
+ return _reference_solve
+
+
+@pytest.fixture
+def reference_logdet():
+ from .utils import _reference_logdet
+
+ return _reference_logdet
+
+
+@pytest.fixture
+def reference_inversion():
+ from .utils import _reference_inversion
+
+ return _reference_inversion
+
+
+@pytest.fixture
+def allclose_ndarrays():
+ from .utils import _allclose_ndarrays
+
+ return _allclose_ndarrays
+
+
+@pytest.fixture
+def allclose_floats():
+ from .utils import _allclose_floats
+
+ return _allclose_floats
diff --git a/tests/component_integration/solvers/dense_solvers/__init__.py b/tests/component_integration/solvers/dense_solvers/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/component_integration/solvers/dense_solvers/conftest.py b/tests/component_integration/solvers/dense_solvers/conftest.py
new file mode 100644
index 00000000..2f4d88a7
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/conftest.py
@@ -0,0 +1,50 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+MATRIX_SIZE = [
+ pytest.param(1, id="matrix_size=1"),
+ pytest.param(2, id="matrix_size=2"),
+ pytest.param(10, id="matrix_size=10"),
+ pytest.param(100, id="matrix_size=100"),
+]
+
+
+@pytest.fixture(params=MATRIX_SIZE, autouse=True)
+def matrix_size(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+MATRIX_TYPE = [
+ pytest.param("dense", id="matrix_type=dense"),
+ pytest.param("sparse", id="matrix_type=sparse"),
+]
+
+
+@pytest.fixture(params=MATRIX_TYPE, autouse=True)
+def matrix_type(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+SOLVER_TYPES = [
+ pytest.param("dense", id="solver_type=dense"),
+]
+
+
+@pytest.fixture(params=SOLVER_TYPES)
+def solver_type(request: pytest.FixtureRequest) -> str:
+ return request.param
+
+
+@pytest.fixture
+def create_solver():
+ from .utils import _create_solver
+
+ return _create_solver
+
+
+@pytest.fixture
+def generate_spd_spmatrix():
+ from .utils import _generate_spd_spmatrix
+
+ return _generate_spd_spmatrix
diff --git a/tests/component_integration/solvers/dense_solvers/sequential/test_factorize.py b/tests/component_integration/solvers/dense_solvers/sequential/test_factorize.py
new file mode 100644
index 00000000..dec4413c
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/sequential/test_factorize.py
@@ -0,0 +1,32 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_factorize_correctness(
+ reference_cholesky,
+ allclose_ndarrays,
+ generate_spd_spmatrix,
+ create_solver,
+ solver_type,
+ matrix_size,
+ matrix_type,
+):
+ """Test correctness of the factorization step of the solver."""
+ # Generate test case
+ A = generate_spd_spmatrix(matrix_type, matrix_size)
+
+ # Solver to compare
+ solver = create_solver(solver_type, matrix_size)
+
+ solver.factorize(A)
+
+ # Reference
+ L_ref = reference_cholesky(A)
+
+ # Compare
+ allclose_ndarrays(
+ a_reference=L_ref,
+ b_toverify=solver.L,
+ )
diff --git a/tests/component_integration/solvers/dense_solvers/sequential/test_logdet.py b/tests/component_integration/solvers/dense_solvers/sequential/test_logdet.py
new file mode 100644
index 00000000..39d31f41
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/sequential/test_logdet.py
@@ -0,0 +1,34 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_logdet_correctness(
+ reference_logdet,
+ allclose_floats,
+ generate_spd_spmatrix,
+ create_solver,
+ solver_type,
+ matrix_size,
+ matrix_type,
+):
+ """Test correctness of the log-determinant computation of the solver."""
+ # Generate test case
+ A = generate_spd_spmatrix(matrix_type, matrix_size)
+
+ # Solver to compare
+ solver = create_solver(solver_type, matrix_size)
+
+ solver.factorize(A)
+
+ logdet_solver = solver.logdet()
+
+ # Reference
+ logdet_ref = reference_logdet(A)
+
+ # Compare
+ allclose_floats(
+ a_reference=logdet_ref,
+ b_toverify=logdet_solver,
+ )
diff --git a/tests/component_integration/solvers/dense_solvers/sequential/test_selected_inversion.py b/tests/component_integration/solvers/dense_solvers/sequential/test_selected_inversion.py
new file mode 100644
index 00000000..95dbd722
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/sequential/test_selected_inversion.py
@@ -0,0 +1,49 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import warnings
+
+import pytest
+
+from tests import ATOLS, RTOLS
+
+
+@pytest.mark.mpi_skip()
+def test_selected_inversion_correctness(
+ reference_inversion,
+ allclose_ndarrays,
+ generate_spd_spmatrix,
+ create_solver,
+ solver_type,
+ matrix_size,
+ matrix_type,
+):
+ """Test correctness of the selected inversion step of the solver."""
+ # Generate test case
+ A = generate_spd_spmatrix(matrix_type, matrix_size)
+
+ # Solver to compare
+ solver = create_solver(solver_type, matrix_size)
+
+ solver.factorize(A)
+
+ X_solver = solver.selected_inversion()
+
+ # Reference
+ X_ref = reference_inversion(A)
+
+ # Compare
+ # Tolerance is relaxed due to numerical differences in selected inversion implementation
+ # of the dense solver (uses trsm on L and L.t less precise than xp.linalg.inv())
+ allclose_ndarrays(
+ a_reference=X_ref,
+ b_toverify=X_solver,
+ relaxed_tolerance=True,
+ )
+
+ # Warn that this test only checks callability, not numerical correctness
+ warnings.warn(
+ f"Test passed but numerical accuracy relaxed due to differences in numerical approaches. "
+ f"Relaxed accuracy: rtol<{RTOLS['relaxed']} and atol<{ATOLS['relaxed']}",
+ UserWarning,
+ stacklevel=2,
+ )
diff --git a/tests/component_integration/solvers/dense_solvers/sequential/test_solve.py b/tests/component_integration/solvers/dense_solvers/sequential/test_solve.py
new file mode 100644
index 00000000..c414a606
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/sequential/test_solve.py
@@ -0,0 +1,38 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_solve_correctness(
+ reference_solve,
+ allclose_ndarrays,
+ generate_spd_spmatrix,
+ create_rhs,
+ create_solver,
+ solver_type,
+ matrix_size,
+ matrix_type,
+ num_rhs,
+):
+ """Test correctness of the solve step of the solver."""
+ # Generate test case
+ A = generate_spd_spmatrix(matrix_type, matrix_size)
+
+ b = create_rhs(n_rhs=num_rhs, matrix_size=matrix_size)
+
+ # Solver to compare
+ solver = create_solver(solver_type, matrix_size)
+
+ solver.factorize(A)
+
+ x_solver = solver.solve(b.copy())
+
+ # Reference
+ x_ref = reference_solve(A, b)
+
+ # Compare
+ allclose_ndarrays(
+ a_reference=x_ref,
+ b_toverify=x_solver,
+ )
diff --git a/tests/component_integration/solvers/dense_solvers/utils.py b/tests/component_integration/solvers/dense_solvers/utils.py
new file mode 100644
index 00000000..f86b4665
--- /dev/null
+++ b/tests/component_integration/solvers/dense_solvers/utils.py
@@ -0,0 +1,58 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import numpy as np
+
+from dalia import backend_flags, sp, xp
+from tests import RANDOM_SEED
+
+np.random.seed(RANDOM_SEED)
+
+if backend_flags["cupy_avail"]:
+ import cupy as cp
+
+ cp.random.seed(cp.uint64(RANDOM_SEED))
+
+
+def _create_solver(
+ solver_type: str,
+ matrix_size: int,
+):
+ from dalia.configs.dalia_config import SolverConfig
+ from dalia.solvers import DenseSolver
+
+ config = SolverConfig(type=solver_type)
+
+ if solver_type == "dense":
+ return DenseSolver(config=config, n=matrix_size)
+ else:
+ raise ValueError(f"Unknown solver type: {solver_type}")
+
+
+def _generate_spd_spmatrix(
+ matrix_type: str,
+ n: int,
+):
+ """Returns a random, positive definite, matrix.
+
+ Parameters
+ ----------
+ matrix_type : str
+ Type of the matrix: "sparse" or "dense".
+ n : int
+ Size of the matrix.
+
+ Returns
+ -------
+ A : ArrayLike
+ Random, positive definite, matrix.
+ """
+
+ if matrix_type == "sparse":
+ L = sp.sparse.random(n, n, density=0.5)
+ L = L + sp.sparse.eye(n, format="csr") # Make diagonal entries positive
+ L = sp.sparse.tril(L, format="csr") # lower triangular
+ return L @ L.T # SPD and sparse (but denser than L)
+ else:
+ L = xp.tril(xp.random.rand(n, n))
+ L = L + xp.eye(n) # Make diagonal entries positive
+ return L @ L.T # SPD and dense
diff --git a/tests/component_integration/solvers/sparse_solvers/__init__.py b/tests/component_integration/solvers/sparse_solvers/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/component_integration/solvers/sparse_solvers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/component_integration/solvers/sparse_solvers/conftest.py b/tests/component_integration/solvers/sparse_solvers/conftest.py
new file mode 100644
index 00000000..7a606e1f
--- /dev/null
+++ b/tests/component_integration/solvers/sparse_solvers/conftest.py
@@ -0,0 +1,51 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+MATRIX_SIZE = [
+ pytest.param(1, id="matrix_size=1"),
+ pytest.param(2, id="matrix_size=2"),
+ pytest.param(10, id="matrix_size=10"),
+ pytest.param(100, id="matrix_size=100"),
+]
+
+
+@pytest.fixture(params=MATRIX_SIZE, autouse=True)
+def matrix_size(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+DENSITY = [
+ pytest.param(0.1, id="density=0.1"),
+ pytest.param(0.5, id="density=0.5"),
+ pytest.param(1.0, id="density=1.0"),
+]
+
+
+@pytest.fixture(params=DENSITY, autouse=True)
+def density(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+SOLVERS_TYPES = [
+ pytest.param("scipy", id="solver_type=scipy"),
+]
+
+
+@pytest.fixture(params=SOLVERS_TYPES)
+def solver_type(request: pytest.FixtureRequest) -> str:
+ return request.param
+
+
+@pytest.fixture
+def create_solver():
+ from .utils import _create_solver
+
+ return _create_solver
+
+
+@pytest.fixture
+def generate_spd_spmatrix():
+ from .utils import _generate_spd_spmatrix
+
+ return _generate_spd_spmatrix
diff --git a/tests/component_integration/solvers/sparse_solvers/sequential/test_factorize.py b/tests/component_integration/solvers/sparse_solvers/sequential/test_factorize.py
new file mode 100644
index 00000000..93cf3c66
--- /dev/null
+++ b/tests/component_integration/solvers/sparse_solvers/sequential/test_factorize.py
@@ -0,0 +1,47 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import warnings
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_factorize_correctness(
+ generate_spd_spmatrix,
+ allclose_ndarrays,
+ create_solver,
+ solver_type,
+ matrix_size,
+ density,
+):
+ A = generate_spd_spmatrix(matrix_size, density)
+
+ solver = create_solver(solver_type)
+
+ # Test that factorization is callable without errors
+ solver.factorize(A)
+
+ # Reconstruct the matrix from its LU factors and compare
+ A_recovered = solver.LU_factor.L @ solver.LU_factor.U
+ A_dense = A.toarray() if hasattr(A, "toarray") else A
+ A_recovered_dense = A_recovered.toarray() if hasattr(A_recovered, "toarray") else A_recovered
+
+ try:
+ allclose_ndarrays(
+ A_dense,
+ A_recovered_dense,
+ relaxed_tolerance=False,
+ )
+ except AssertionError as e:
+ allclose_ndarrays(
+ A_dense,
+ A_recovered_dense,
+ relaxed_tolerance=True,
+ )
+ warnings.warn(
+ f"Test passed for sparse_solver.{solver_type}.factorize() within relaxed tolerance."
+ "This is likely due to numerical innacuray in retrieving the matrix from its LU factors."
+ "Correctness is further tested through `solve()` and `logdet()` tests.",
+ UserWarning,
+ stacklevel=2,
+ )
diff --git a/tests/component_integration/solvers/sparse_solvers/sequential/test_logdet.py b/tests/component_integration/solvers/sparse_solvers/sequential/test_logdet.py
new file mode 100644
index 00000000..2f23f4d6
--- /dev/null
+++ b/tests/component_integration/solvers/sparse_solvers/sequential/test_logdet.py
@@ -0,0 +1,30 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_logdet_correctness(
+ reference_logdet,
+ allclose_floats,
+ generate_spd_spmatrix,
+ create_solver,
+ solver_type,
+ matrix_size,
+ density,
+):
+ A = generate_spd_spmatrix(matrix_size, density)
+
+ solver = create_solver(solver_type)
+
+ # Test that factorization is callable without errors
+ solver.factorize(A)
+
+ logdet_solver = solver.logdet()
+
+ logdet_ref = reference_logdet(A)
+
+ allclose_floats(
+ a_reference=logdet_ref,
+ b_toverify=logdet_solver,
+ )
diff --git a/tests/component_integration/solvers/sparse_solvers/sequential/test_solve.py b/tests/component_integration/solvers/sparse_solvers/sequential/test_solve.py
new file mode 100644
index 00000000..e2d12e5d
--- /dev/null
+++ b/tests/component_integration/solvers/sparse_solvers/sequential/test_solve.py
@@ -0,0 +1,37 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_solve_correctness(
+ reference_solve,
+ allclose_ndarrays,
+ generate_spd_spmatrix,
+ create_rhs,
+ create_solver,
+ solver_type,
+ matrix_size,
+ density,
+ num_rhs,
+):
+ A = generate_spd_spmatrix(matrix_size, density)
+
+ b = create_rhs(
+ n_rhs=num_rhs,
+ matrix_size=matrix_size,
+ )
+
+ solver = create_solver(solver_type)
+
+ # Test that factorization is callable without errors
+ solver.factorize(A)
+
+ x_solver = solver.solve(rhs=b.copy())
+
+ x_ref = reference_solve(A, b)
+
+ allclose_ndarrays(
+ a_reference=x_ref,
+ b_toverify=x_solver,
+ )
diff --git a/tests/component_integration/solvers/sparse_solvers/utils.py b/tests/component_integration/solvers/sparse_solvers/utils.py
new file mode 100644
index 00000000..63afdd1b
--- /dev/null
+++ b/tests/component_integration/solvers/sparse_solvers/utils.py
@@ -0,0 +1,52 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import numpy as np
+
+from dalia import backend_flags, sp, xp
+from tests import RANDOM_SEED
+
+np.random.seed(RANDOM_SEED)
+
+if backend_flags["cupy_avail"]:
+ import cupy as cp
+
+ cp.random.seed(cp.uint64(RANDOM_SEED))
+
+
+def _create_solver(
+ solver_type: str,
+):
+ from dalia.configs.dalia_config import SolverConfig
+ from dalia.solvers import SparseSolver
+
+ config = SolverConfig(type=solver_type)
+
+ if solver_type == "scipy":
+ return SparseSolver(config=config)
+ else:
+ raise ValueError(f"Unknown solver type: {solver_type}")
+
+
+def _generate_spd_spmatrix(
+ n: int,
+ density: float,
+):
+ """Returns a random, positive definite, sparse matrix.
+
+ Parameters
+ ----------
+ n : int
+ Size of the matrix.
+ density : float
+ Density of the matrix.
+
+ Returns
+ -------
+ A : ArrayLike
+ Random, positive definite, sparse matrix.
+ """
+ L = sp.sparse.random(n, n, density=density, data_rvs=xp.random.randn)
+ L = L + n * sp.sparse.eye(n) # Make diagonal dominant
+ L = sp.sparse.tril(L) # lower triangular
+
+ return L @ L.T # SPD and sparse (but denser than L)
diff --git a/tests/component_integration/solvers/structured_solvers/__init__.py b/tests/component_integration/solvers/structured_solvers/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/component_integration/solvers/structured_solvers/conftest.py b/tests/component_integration/solvers/structured_solvers/conftest.py
new file mode 100644
index 00000000..d68be698
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/conftest.py
@@ -0,0 +1,62 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+from tests.structured_solvers_utils import (
+ _allclose_dense_structured,
+ _create_pobt,
+ _create_pobta,
+ _create_solver,
+)
+
+DIAGONAL_BLOCKSIZE = [
+ pytest.param(2, id="diagonal_blocksize=2"),
+ pytest.param(3, id="diagonal_blocksize=3"),
+]
+
+
+@pytest.fixture(params=DIAGONAL_BLOCKSIZE, autouse=True)
+def diagonal_blocksize(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+ARROWHEAD_BLOCKSIZE = [
+ pytest.param(0, id="arrowhead_blocksize=0"),
+ pytest.param(2, id="arrowhead_blocksize=2"),
+ pytest.param(3, id="arrowhead_blocksize=3"),
+]
+
+
+@pytest.fixture(params=ARROWHEAD_BLOCKSIZE, autouse=True)
+def arrowhead_blocksize(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+SOLVERS_TYPES = [
+ pytest.param("serinv", id="solvers_types=serinv"),
+]
+
+
+@pytest.fixture(params=SOLVERS_TYPES)
+def solver_type(request: pytest.FixtureRequest) -> str:
+ return request.param
+
+
+@pytest.fixture
+def create_solver():
+ return _create_solver
+
+
+@pytest.fixture
+def create_pobta():
+ return _create_pobta
+
+
+@pytest.fixture
+def create_pobt():
+ return _create_pobt
+
+
+@pytest.fixture
+def allclose_dense_structured():
+ return _allclose_dense_structured
diff --git a/tests/component_integration/solvers/structured_solvers/distributed/conftest.py b/tests/component_integration/solvers/structured_solvers/distributed/conftest.py
new file mode 100644
index 00000000..524cf6e7
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/distributed/conftest.py
@@ -0,0 +1,25 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+N_DIAG_BLOCKS_PER_PROCESS = [
+ pytest.param(3, id="n_diag_blocks=3"),
+ pytest.param(5, id="n_diag_blocks=5"),
+ pytest.param(10, id="n_diag_blocks=10"),
+]
+
+
+@pytest.fixture(params=N_DIAG_BLOCKS_PER_PROCESS, autouse=True)
+def n_diag_blocks_per_process(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+NON_UNIFORM_PARTITION_SIZES = [
+ pytest.param(True, id="non_uniform_partition=True"),
+ pytest.param(False, id="non_uniform_partition=False"),
+]
+
+
+@pytest.fixture(params=NON_UNIFORM_PARTITION_SIZES, autouse=True)
+def non_uniform_partition(request: pytest.FixtureRequest) -> bool:
+ return request.param
diff --git a/tests/component_integration/solvers/structured_solvers/distributed/test_factorize.py b/tests/component_integration/solvers/structured_solvers/distributed/test_factorize.py
new file mode 100644
index 00000000..fa3b9cf1
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/distributed/test_factorize.py
@@ -0,0 +1,47 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import warnings
+
+import pytest
+
+
+@pytest.mark.mpi(min_size=2)
+def test_factorize_correctness(
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks_per_process,
+ arrowhead_blocksize,
+ non_uniform_partition,
+):
+ """Test Distributed Cholesky decomposition correctness against NumPy/CuPy dense reference."""
+ import mpi4py.MPI as MPI
+
+ # Generate test matrix based on sparsity pattern
+ n_diag_blocks = n_diag_blocks_per_process * MPI.COMM_WORLD.Get_size() + (
+ 1 if non_uniform_partition else 0
+ )
+
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Create solver
+ solver = create_solver(
+ solver_type, diagonal_blocksize, n_diag_blocks, arrowhead_blocksize
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Warn that this test only checks callability, not numerical correctness
+ warnings.warn(
+ f"Test passed but only verified `structured_solvers.distributed.{solver_type}.factorize()` is callable. "
+ "Cannot verify against reference Cholesky as distributed Cholesky factorization is not equal to sequential one."
+ "Correctness is still tested through `solve()`, `logdet()`, and `selected_inversion()` tests.",
+ UserWarning,
+ stacklevel=2,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/distributed/test_logdet.py b/tests/component_integration/solvers/structured_solvers/distributed/test_logdet.py
new file mode 100644
index 00000000..75703d11
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/distributed/test_logdet.py
@@ -0,0 +1,53 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi(min_size=2)
+def test_logdet_correctness(
+ reference_logdet,
+ allclose_floats,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks_per_process,
+ arrowhead_blocksize,
+ non_uniform_partition,
+):
+ """Test logdet correctness against NumPy/CuPy reference."""
+ import mpi4py.MPI as MPI
+
+ # Generate test matrix based on sparsity pattern
+ n_diag_blocks = n_diag_blocks_per_process * MPI.COMM_WORLD.Get_size() + (
+ 1 if non_uniform_partition else 0
+ )
+
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Create solver
+ solver = create_solver(
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+ distributed=True,
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Compute logdet using solver
+ logdet_solver = solver.logdet(sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Compute reference
+ logdet_ref = reference_logdet(A)
+
+ allclose_floats(
+ a_reference=logdet_ref,
+ b_toverify=logdet_solver,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/distributed/test_selected_inversion.py b/tests/component_integration/solvers/structured_solvers/distributed/test_selected_inversion.py
new file mode 100644
index 00000000..1bf1567e
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/distributed/test_selected_inversion.py
@@ -0,0 +1,69 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi(min_size=2)
+def test_selected_inversion_correctness(
+ reference_inversion,
+ allclose_dense_structured,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks_per_process,
+ arrowhead_blocksize,
+ non_uniform_partition,
+):
+ """Test Distributed Selected Inversion correctness against NumPy/CuPy dense reference."""
+ import mpi4py.MPI as MPI
+
+ # Generate test matrix based on sparsity pattern
+ n_diag_blocks = n_diag_blocks_per_process * MPI.COMM_WORLD.Get_size() + (
+ 1 if non_uniform_partition else 0
+ )
+
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Create solver
+ solver = create_solver(
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+ distributed=True,
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Run solver selected inversion
+ solver.selected_inversion(sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Get the computed selected inverse in sparse format
+ A_selinv_solver = solver._structured_to_spmatrix(
+ A,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ symmetrize=True,
+ )
+
+ # Reference dense inversion
+ A_inv_ref = reference_inversion(A)
+
+ # Assert correctness within sparsity pattern
+ allclose_dense_structured(
+ A_reference=A_inv_ref,
+ B_toverify=(
+ A_selinv_solver.toarray()
+ if hasattr(A_selinv_solver, "toarray")
+ else A_selinv_solver
+ ),
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ assert_upper_triangle=True,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/distributed/test_solve.py b/tests/component_integration/solvers/structured_solvers/distributed/test_solve.py
new file mode 100644
index 00000000..af99f0b6
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/distributed/test_solve.py
@@ -0,0 +1,62 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi(min_size=2)
+def test_solve_correctness(
+ reference_solve,
+ allclose_ndarrays,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ create_rhs,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks_per_process,
+ arrowhead_blocksize,
+ num_rhs,
+ non_uniform_partition,
+):
+ """Test Triangular Solve correctness against NumPy/CuPy reference."""
+ import mpi4py.MPI as MPI
+
+ # Generate test matrix based on sparsity pattern
+ n_diag_blocks = n_diag_blocks_per_process * MPI.COMM_WORLD.Get_size() + (
+ 1 if non_uniform_partition else 0
+ )
+
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Generate rhs
+ b = create_rhs(
+ n_rhs=num_rhs,
+ matrix_size=n_diag_blocks * diagonal_blocksize + arrowhead_blocksize,
+ )
+
+ # Compute reference
+ x_ref = reference_solve(A, b.copy())
+
+ # Create solver
+ solver = create_solver(
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+ distributed=True,
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Run solver solve
+ x_solver = solver.solve(rhs=b, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Verify results
+ allclose_ndarrays(
+ a_reference=x_ref,
+ b_toverify=x_solver,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/sequential/conftest.py b/tests/component_integration/solvers/structured_solvers/sequential/conftest.py
new file mode 100644
index 00000000..76ea69a6
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/sequential/conftest.py
@@ -0,0 +1,15 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+N_DIAG_BLOCKS = [
+ pytest.param(1, id="n_diag_blocks=1"),
+ pytest.param(2, id="n_diag_blocks=2"),
+ pytest.param(3, id="n_diag_blocks=3"),
+ pytest.param(4, id="n_diag_blocks=4"),
+]
+
+
+@pytest.fixture(params=N_DIAG_BLOCKS, autouse=True)
+def n_diag_blocks(request: pytest.FixtureRequest) -> int:
+ return request.param
diff --git a/tests/component_integration/solvers/structured_solvers/sequential/test_factorize.py b/tests/component_integration/solvers/structured_solvers/sequential/test_factorize.py
new file mode 100644
index 00000000..bd647cd9
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/sequential/test_factorize.py
@@ -0,0 +1,50 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_factorize_correctness(
+ allclose_dense_structured,
+ reference_cholesky,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+):
+ """Test Cholesky decomposition correctness against NumPy/CuPy dense reference."""
+ # Generate test matrix based on sparsity pattern
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Create solver
+ solver = create_solver(
+ solver_type, diagonal_blocksize, n_diag_blocks, arrowhead_blocksize
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Compute reference
+ L_ref = reference_cholesky(A)
+
+ L_solver = solver._structured_to_spmatrix(
+ A,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ symmetrize=False,
+ )
+
+ L_solver_dense = L_solver.toarray() if hasattr(L_solver, "toarray") else L_solver
+
+ allclose_dense_structured(
+ A_reference=L_ref,
+ B_toverify=L_solver_dense,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/sequential/test_logdet.py b/tests/component_integration/solvers/structured_solvers/sequential/test_logdet.py
new file mode 100644
index 00000000..cd756de8
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/sequential/test_logdet.py
@@ -0,0 +1,42 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_logdet_correctness(
+ reference_logdet,
+ allclose_floats,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+):
+ """Test Cholesky decomposition correctness against NumPy/CuPy reference."""
+ # Generate test matrix based on sparsity pattern
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Create solver
+ solver = create_solver(
+ solver_type, diagonal_blocksize, n_diag_blocks, arrowhead_blocksize
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Compute logdet using solver
+ logdet_solver = solver.logdet(sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Compute reference
+ logdet_ref = reference_logdet(A)
+
+ allclose_floats(
+ a_reference=logdet_ref,
+ b_toverify=logdet_solver,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/sequential/test_selected_inversion.py b/tests/component_integration/solvers/structured_solvers/sequential/test_selected_inversion.py
new file mode 100644
index 00000000..1baab51e
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/sequential/test_selected_inversion.py
@@ -0,0 +1,58 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_selected_inversion_correctness(
+ reference_inversion,
+ allclose_dense_structured,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+):
+ """Test Cholesky decomposition correctness against NumPy/CuPy dense reference."""
+ # Generate test matrix based on sparsity pattern
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Create solver
+ solver = create_solver(
+ solver_type, diagonal_blocksize, n_diag_blocks, arrowhead_blocksize
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Run solver selected inversion
+ solver.selected_inversion(sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Get the computed selected inverse in sparse format
+ A_selinv_solver = solver._structured_to_spmatrix(
+ A,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ symmetrize=True,
+ )
+
+ # Reference dense inversion
+ A_inv_ref = reference_inversion(A)
+
+ # Assert correctness within sparsity pattern
+ allclose_dense_structured(
+ A_reference=A_inv_ref,
+ B_toverify=(
+ A_selinv_solver.toarray()
+ if hasattr(A_selinv_solver, "toarray")
+ else A_selinv_solver
+ ),
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ assert_upper_triangle=True,
+ )
diff --git a/tests/component_integration/solvers/structured_solvers/sequential/test_solve.py b/tests/component_integration/solvers/structured_solvers/sequential/test_solve.py
new file mode 100644
index 00000000..aabaf809
--- /dev/null
+++ b/tests/component_integration/solvers/structured_solvers/sequential/test_solve.py
@@ -0,0 +1,51 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_solve_correctness(
+ reference_solve,
+ allclose_ndarrays,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ create_rhs,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+ num_rhs,
+):
+ """Test Cholesky decomposition correctness against NumPy/CuPy reference."""
+ # Generate test matrix based on sparsity pattern
+ if arrowhead_blocksize > 0: # bta
+ A = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Generate rhs
+ b = create_rhs(
+ n_rhs=num_rhs,
+ matrix_size=n_diag_blocks * diagonal_blocksize + arrowhead_blocksize,
+ )
+
+ # Compute reference
+ x_ref = reference_solve(A, b.copy())
+
+ # Create solver
+ solver = create_solver(
+ solver_type, diagonal_blocksize, n_diag_blocks, arrowhead_blocksize
+ )
+
+ # Run solver factorize
+ solver.factorize(A, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Run solver solve
+ x_solver = solver.solve(rhs=b, sparsity="bta" if arrowhead_blocksize > 0 else "bt")
+
+ # Verify results
+ allclose_ndarrays(
+ a_reference=x_ref,
+ b_toverify=x_solver,
+ )
diff --git a/tests/component_integration/solvers/utils.py b/tests/component_integration/solvers/utils.py
new file mode 100644
index 00000000..c2d50eca
--- /dev/null
+++ b/tests/component_integration/solvers/utils.py
@@ -0,0 +1,204 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import numpy as np
+
+from dalia import backend_flags, xp, sp
+from tests import ATOLS, RANDOM_SEED, RTOLS
+
+np.random.seed(RANDOM_SEED)
+
+if backend_flags["cupy_avail"]:
+ import cupy as cp
+
+ cp.random.seed(cp.uint64(RANDOM_SEED))
+
+
+def _to_ndarray(A):
+ """Convert input to ndarray.
+
+ Parameters
+ ----------
+ A : ArrayLike
+ Input array.
+
+ Returns
+ -------
+ ndarray
+ Converted ndarray.
+ """
+ A_dense = A.toarray() if hasattr(A, "toarray") else A
+ A_dense = xp.asarray(A_dense)
+ return A_dense
+
+
+def _create_rhs(n_rhs: int, matrix_size: int):
+ """Returns a random right-hand side.
+
+ Parameters
+ ----------
+ n_rhs : int
+ Number of right-hand sides.
+ diagonal_blocksize : int
+ Size of the diagonal blocks.
+ arrowhead_blocksize : int
+ Size of the arrowhead blocks.
+ n_diag_blocks : int
+ Number of diagonal blocks.
+
+ Returns
+ -------
+ B : ArrayLike
+ Random right-hand side.
+ """
+
+ B = xp.random.rand(matrix_size, n_rhs)
+
+ return B
+
+
+def _reference_cholesky(A):
+ """Compute reference Cholesky decomposition using NumPy/CuPy.
+
+ Parameters
+ ----------
+ A : ArrayLike
+ Input matrix to decompose.
+
+ Returns
+ -------
+ L : ndarray
+ Lower triangular Cholesky factor.
+ """
+ return xp.linalg.cholesky(_to_ndarray(A))
+
+
+def _reference_solve(A, rhs):
+ """Solve linear system using NumPy/CuPy.
+
+ The reference solution is computed using:
+ 1. Cholesky decomposition A = L L^T
+ 2. Solve for y: L y = rhs
+ 3. Solve for x: L^T x = y
+
+ Parameters
+ ----------
+ A : ArrayLike
+ System matrix.
+ rhs : ArrayLike
+ Right-hand side.
+
+ Returns
+ -------
+ x : ndarray
+ Solution vector.
+ """
+ L = xp.linalg.cholesky(_to_ndarray(A))
+ y = sp.linalg.solve_triangular(L, _to_ndarray(rhs), lower=True)
+ x = sp.linalg.solve_triangular(L, y, trans="T", lower=True)
+ return x
+
+
+def _reference_logdet(A):
+ """Compute log determinant using NumPy/CuPy Cholesky.
+
+ The log determinant is computed using:
+ 1. Cholesky decomposition A = L L^T
+ 2. logdet(A) = 2 * sum(log(diag(L)))
+
+ Parameters
+ ----------
+ A : ArrayLike
+ Input matrix.
+
+ Returns
+ -------
+ logdet : float
+ Log determinant of the matrix.
+ """
+ L = xp.linalg.cholesky(_to_ndarray(A))
+ return 2.0 * xp.sum(xp.log(xp.diag(L)))
+
+
+def _reference_inversion(A):
+ """Compute matrix inverse using NumPy/CuPy.
+
+ The reference inverse is computed using:
+ 1. Cholesky decomposition A = L L^T
+ 2. Solve for L_inv: L L_inv = I
+ 3. Compute A_inv = L_inv^T L_inv
+
+ Parameters
+ ----------
+ A : ArrayLike
+ Input matrix.
+
+ Returns
+ -------
+ A_inv : ndarray
+ Inverse matrix.
+ """
+ L = xp.linalg.cholesky(_to_ndarray(A))
+ L_inv = sp.linalg.solve_triangular(
+ L, xp.eye(L.shape[0]), lower=True, overwrite_b=False
+ )
+ A_inv = L_inv.T @ L_inv
+ return A_inv
+
+
+def _allclose_ndarrays(
+ a_reference: np.ndarray,
+ b_toverify: np.ndarray,
+ relaxed_tolerance: bool = False,
+):
+ """Check correctness of two ndarrays.
+
+ Parameters
+ ----------
+ A_reference : ndarray
+ Reference vector.
+ B_toverify : ndarray
+ Vector to verify.
+ relaxed_tolerance : bool, optional
+ Whether to use relaxed tolerance for comparison, by default False.
+
+ Raises
+ ------
+ AssertionError
+ If the vectors are not close enough.
+ """
+
+ assert xp.allclose(
+ a_reference,
+ b_toverify,
+ rtol=RTOLS["relaxed"] if relaxed_tolerance else RTOLS["strict"],
+ atol=ATOLS["relaxed"] if relaxed_tolerance else ATOLS["strict"],
+ )
+
+
+def _allclose_floats(
+ a_reference: float,
+ b_toverify: float,
+ relaxed_tolerance: bool = False,
+):
+ """Check correctness of two floats.
+
+ Parameters
+ ----------
+ a_reference : float
+ Reference float.
+ b_toverify : float
+ Float to verify.
+ relaxed_tolerance : bool, optional
+ Whether to use relaxed tolerance for comparison, by default False.
+
+ Raises
+ ------
+ AssertionError
+ If the floats are not close enough.
+ """
+ assert xp.isclose(
+ a_reference,
+ b_toverify,
+ rtol=RTOLS["relaxed"] if relaxed_tolerance else RTOLS["strict"],
+ atol=ATOLS["relaxed"] if relaxed_tolerance else ATOLS["strict"],
+ )
diff --git a/tests/integration/gst/itest.py b/tests/integration/gst/itest.py
new file mode 100644
index 00000000..a45fecde
--- /dev/null
+++ b/tests/integration/gst/itest.py
@@ -0,0 +1,111 @@
+from pathlib import Path
+import numpy as np
+
+from dalia.configs import likelihood_config, dalia_config, submodels_config
+from dalia.core.model import Model
+from dalia.core.dalia import DALIA
+from dalia.utils import print_msg, get_host
+from dalia.submodels import RegressionSubModel, SpatioTemporalSubModel
+
+SCRIPT_DIR = Path(__file__).resolve()
+DALIA_DIR = SCRIPT_DIR.parent.parent.parent.parent
+EXAMPLE_PATH = DALIA_DIR / "examples" / "gst_medium"
+
+X_TOL = 1e-1
+THETA_TOL = 1e1
+TYPICAL_N_ITER = 26
+
+def gst_itest():
+ spatio_temporal_dict = {
+ "type": "spatio_temporal",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_spatio_temporal",
+ "spatial_domain_dimension": 2,
+ "r_s": 0.0,
+ "r_t": 2.2,
+ "sigma_st": 1.3,
+ "manifold": "sphere",
+ "ph_s": {"type": "penalized_complexity", "alpha": 0.01, "u": 0.5},
+ "ph_t": {"type": "penalized_complexity", "alpha": 0.01, "u": 5},
+ "ph_st": {"type": "penalized_complexity", "alpha": 0.01, "u": 3},
+ }
+ spatio_temporal = SpatioTemporalSubModel(
+ config=submodels_config.parse_config(spatio_temporal_dict),
+ )
+ # . Regression submodel
+ regression_dict = {
+ "type": "regression",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_regression",
+ "n_fixed_effects": 6,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression = RegressionSubModel(
+ config=submodels_config.parse_config(regression_dict),
+ )
+ # Configurations of the likelihood
+ likelihood_dict = {
+ "type": "gaussian",
+ "prec_o": 4,
+ "prior_hyperparameters": {"type": "gaussian", "mean": 1.4, "precision": 0.5},
+ }
+ # Creation of the model by combining the submodels and the likelihood
+ model = Model(
+ submodels=[regression, spatio_temporal],
+ likelihood_config=likelihood_config.parse_config(likelihood_dict),
+ )
+ # Configurations of DALIA
+ dalia_dict = {
+ "solver": {
+ "type": "serinv",
+ "min_processes": 1,
+ },
+ "minimize": {
+ "max_iter": 100,
+ "gtol": 1e-3,
+ "disp": True,
+ },
+ "f_reduction_tol": 1e-4,
+ "theta_reduction_tol": 1e-4,
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "simulation_dir": f"{EXAMPLE_PATH}",
+ }
+ dalia = DALIA(
+ model=model,
+ config=dalia_config.parse_config(dalia_dict),
+ )
+ results = dalia.run()
+
+ # Check iterations behavior
+ success_msg : str = "success"
+ if results["optimization_iterations"] > TYPICAL_N_ITER:
+ success_msg = "warning_more_iters_than_typical"
+ elif results["optimization_iterations"] < TYPICAL_N_ITER:
+ success_msg = "success_less_iters_than_typical"
+
+ # Compare hyperparameters
+ theta_ref = np.array(np.load(f"{EXAMPLE_PATH}/reference_outputs/theta_ref.npy"))
+ print(f"theta_ref: {theta_ref}")
+ print(f"theta_dalia: {get_host(results['theta_internal'])}")
+ print_msg(
+ "Norm (theta - theta_ref): ",
+ f"{np.linalg.norm(get_host(results['theta_internal']) - theta_ref):.4e}",
+ )
+ if np.linalg.norm(get_host(results["theta_internal"]) - theta_ref) > THETA_TOL:
+ return "theta_tol_exceeded"
+
+ # Compare latent parameters
+ x_ref = np.array(np.load(f"{EXAMPLE_PATH}/reference_outputs/x_ref.npy"))
+ print(f"x_ref: {x_ref}")
+ print(f"x_dalia: {get_host(results['x'])}")
+ print_msg(
+ "Norm (x - x_ref): ",
+ f"{np.linalg.norm(get_host(results['x']) - x_ref):.4e}",
+ )
+ if np.linalg.norm(get_host(results["x"]) - x_ref) > X_TOL:
+ return "x_tol_exceeded"
+
+ return success_msg
+
+if __name__ == "__main__":
+ gst_itest()
\ No newline at end of file
diff --git a/tests/integration/gstcoreg2/itest.py b/tests/integration/gstcoreg2/itest.py
new file mode 100644
index 00000000..d075a822
--- /dev/null
+++ b/tests/integration/gstcoreg2/itest.py
@@ -0,0 +1,234 @@
+from pathlib import Path
+import numpy as np
+
+from dalia.configs import (
+ likelihood_config,
+ models_config,
+ dalia_config,
+ submodels_config,
+)
+from dalia.core.model import Model
+from dalia.core.dalia import DALIA
+from dalia.models import CoregionalModel
+from dalia.submodels import RegressionSubModel, SpatioTemporalSubModel
+from dalia.utils import print_msg, get_host
+
+SCRIPT_DIR = Path(__file__).resolve()
+DALIA_DIR = SCRIPT_DIR.parent.parent.parent.parent
+EXAMPLE_PATH = DALIA_DIR / "examples" / "gst_coreg2_small"
+
+X_TOL = 1e3 # 1.5e+2
+THETA_TOL = 1e2 # 9.e+1
+TYPICAL_N_ITER = 80 # On Fritz: 86
+
+def gstcoreg2_itest():
+ nv = 2
+ ns = 354
+ nt = 12
+ nb = 2
+
+ theta_ref_file = (
+ f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy"
+ )
+ theta_ref = np.load(theta_ref_file)
+ perturbation = [
+ 0.18197867,
+ -0.12551227,
+ 0.19998896,
+ 0.17226796,
+ 0.14656176,
+ -0.11864931,
+ 0.17817371,
+ -0.13006157,
+ 0.19308036,
+ ]
+ theta_initial = theta_ref + np.array(perturbation)
+
+ # Configurations of the submodels for the first model
+ # . Spatio-temporal submodel 1
+ spatio_temporal_1_dict = {
+ "type": "spatio_temporal",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/model_1/inputs_spatio_temporal",
+ "spatial_domain_dimension": 2,
+ "r_s": theta_initial[0],
+ "r_t": theta_initial[1],
+ "sigma_st": 0.0,
+ "manifold": "plane",
+ "ph_s": {
+ "type": "gaussian",
+ "mean": theta_ref[0],
+ "precision": 0.5,
+ },
+ "ph_t": {
+ "type": "gaussian",
+ "mean": theta_ref[1],
+ "precision": 0.5,
+ },
+ "ph_st": {
+ "type": "gaussian",
+ "mean": 0.0,
+ "precision": 0.5,
+ },
+ }
+ spatio_temporal_1 = SpatioTemporalSubModel(
+ config=submodels_config.parse_config(spatio_temporal_1_dict),
+ )
+ # . Regression submodel 1
+ regression_1_dict = {
+ "type": "regression",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/model_1/inputs_regression",
+ "n_fixed_effects": 1,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression_1 = RegressionSubModel(
+ config=submodels_config.parse_config(regression_1_dict),
+ )
+ # . Likelihood submodel 1
+ likelihood_1_dict = {
+ "type": "gaussian",
+ "prec_o": theta_initial[2],
+ "prior_hyperparameters": {
+ "type": "gaussian",
+ "mean": theta_initial[2],
+ "precision": 0.5,
+ },
+ }
+ # Creation of the first model by combining the submodels and the likelihood
+ model_1 = Model(
+ submodels=[regression_1, spatio_temporal_1],
+ likelihood_config=likelihood_config.parse_config(likelihood_1_dict),
+ )
+
+ # Configurations of the submodels for the second model
+ # . Spatio-temporal submodel 2
+ spatio_temporal_2_dict = {
+ "type": "spatio_temporal",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/model_2/inputs_spatio_temporal",
+ "spatial_domain_dimension": 2,
+ "r_s": theta_initial[3],
+ "r_t": theta_initial[4],
+ "sigma_st": 0.0,
+ "manifold": "plane",
+ "ph_s": {
+ "type": "gaussian",
+ "mean": theta_ref[3],
+ "precision": 0.5,
+ },
+ "ph_t": {
+ "type": "gaussian",
+ "mean": theta_ref[4],
+ "precision": 0.5,
+ },
+ "ph_st": {
+ "type": "gaussian",
+ "mean": 0.0,
+ "precision": 0.5,
+ },
+ }
+ spatio_temporal_2 = SpatioTemporalSubModel(
+ config=submodels_config.parse_config(spatio_temporal_2_dict),
+ )
+ # . Regression submodel 2
+ regression_2_dict = {
+ "type": "regression",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/model_2/inputs_regression",
+ "n_fixed_effects": 1,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression_2 = RegressionSubModel(
+ config=submodels_config.parse_config(regression_2_dict),
+ )
+ # . Likelihood submodel 2
+ likelihood_2_dict = {
+ "type": "gaussian",
+ "prec_o": theta_initial[5],
+ "prior_hyperparameters": {
+ "type": "gaussian",
+ "mean": theta_ref[5],
+ "precision": 0.5,
+ },
+ }
+ # Creation of the second model by combining the submodels and the likelihood
+ model_2 = Model(
+ submodels=[spatio_temporal_2, regression_2],
+ likelihood_config=likelihood_config.parse_config(likelihood_2_dict),
+ )
+ # Creation of the coregional model by combining the models
+ coreg_dict = {
+ "type": "coregional",
+ "n_models": 2,
+ "sigmas": [theta_initial[6], theta_initial[7]],
+ "lambdas": [theta_initial[8]],
+ "ph_sigmas": [
+ {"type": "gaussian", "mean": theta_ref[6], "precision": 0.5},
+ {"type": "gaussian", "mean": theta_ref[7], "precision": 0.5},
+ ],
+ "ph_lambdas": [
+ {"type": "gaussian", "mean": 0.0, "precision": 0.5},
+ ],
+ }
+ coreg_model = CoregionalModel(
+ models=[model_1, model_2],
+ coregional_model_config=models_config.parse_config(coreg_dict),
+ )
+ # Configurations of DALIA
+ dalia_dict = {
+ "solver": {
+ "type": "serinv",
+ "min_processes": 1,
+ },
+ "minimize": {
+ "max_iter": 100,
+ "gtol": 1e-3,
+ "disp": True,
+ "maxcor": len(coreg_model.theta_external),
+ },
+ "f_reduction_tol": 1e-3,
+ "theta_reduction_tol": 1e-4,
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "eps_hessian_f": 5 * 1e-3,
+ "simulation_dir": f"{EXAMPLE_PATH}",
+ }
+ dalia = DALIA(
+ model=coreg_model,
+ config=dalia_config.parse_config(dalia_dict),
+ )
+ # Run the optimization
+ results = dalia.run()
+
+ # Check iterations behavior
+ success_msg : str = "success"
+ if results["optimization_iterations"] > TYPICAL_N_ITER:
+ success_msg = "warning_more_iters_than_typical"
+ elif results["optimization_iterations"] < TYPICAL_N_ITER:
+ success_msg = "success_less_iters_than_typical"
+
+ # Compare hyperparameters
+ theta_ref = np.load(f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/theta_ref.npy")
+ print(f"theta_ref: {theta_ref}")
+ print(f"theta_dalia: {get_host(results['theta_internal'])}")
+ print_msg(
+ "Norm (theta - theta_ref): ",
+ f"{np.linalg.norm(get_host(results['theta_internal']) - theta_ref):.4e}",
+ )
+ if np.linalg.norm(get_host(results["theta_internal"]) - theta_ref) > THETA_TOL:
+ return "theta_tol_exceeded"
+
+ # Compare latent parameters
+ x_ref = np.load(f"{EXAMPLE_PATH}/inputs_nv{nv}_ns{ns}_nt{nt}_nb{nb}/reference_outputs/x_ref.npy")
+ x_ref = x_ref[dalia.model.permutation_latent_variables]
+ print(f"x_ref: {x_ref}")
+ print(f"x_dalia: {get_host(results['x'])}")
+ print_msg(
+ "Norm (x - x_ref): ",
+ f"{np.linalg.norm(get_host(results['x']) - x_ref):.4e}",
+ )
+ if np.linalg.norm(get_host(results["x"]) - x_ref) > X_TOL:
+ return "x_tol_exceeded"
+
+ return success_msg
+
+if __name__ == "__main__":
+ gstcoreg2_itest()
diff --git a/tests/integration/par1/itest.py b/tests/integration/par1/itest.py
new file mode 100644
index 00000000..fb42f224
--- /dev/null
+++ b/tests/integration/par1/itest.py
@@ -0,0 +1,110 @@
+from pathlib import Path
+import numpy as np
+
+from dalia import xp
+from dalia.configs import likelihood_config, dalia_config, submodels_config
+from dalia.core.model import Model
+from dalia.core.dalia import DALIA
+from dalia.utils import print_msg, get_host
+from dalia.submodels import RegressionSubModel, AR1SubModel
+
+SCRIPT_DIR = Path(__file__).resolve()
+DALIA_DIR = SCRIPT_DIR.parent.parent.parent.parent
+EXAMPLE_PATH = DALIA_DIR / "examples" / "p_ar1"
+
+X_TOL = 1e2
+THETA_TOL = 2e-2
+TYPICAL_N_ITER = 12
+
+def par1_itest():
+ # load reference output
+ theta_original = np.load(f"{EXAMPLE_PATH}/reference_outputs/theta_original.npy")
+ x_original = np.load(f"{EXAMPLE_PATH}/reference_outputs/x_original.npy")
+
+ ar1_dict = {
+ "type": "ar1",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_ar1",
+ "phi": 0.45, # has to be between 0 and 1
+ "tau": 0.5, # precision
+ "ph_phi": {"type": "beta", "alpha": 5.0, "beta": 1.0},
+ "ph_tau": {"type": "gamma", "alpha": 2.0, "beta": 0.5},
+ }
+ ar1 = AR1SubModel(
+ config=submodels_config.parse_config(ar1_dict),
+ )
+
+ # Configurations of the regression submodel
+ regression_dict = {
+ "type": "regression",
+ "input_dir": f"{EXAMPLE_PATH}/inputs_regression",
+ "n_fixed_effects": 1,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression = RegressionSubModel(
+ config=submodels_config.parse_config(regression_dict),
+ )
+
+ likelihood_dict = {
+ "type": "poisson",
+ "input_dir": f"{EXAMPLE_PATH}",
+ }
+
+ model = Model(
+ submodels=[ar1, regression],
+ likelihood_config=likelihood_config.parse_config(likelihood_dict),
+ )
+ # Configurations of DALIA
+ dalia_dict = {
+ "solver": {"type": "dense"},
+ "minimize": {
+ "max_iter": 100,
+ "gtol": 1e-3,
+ "disp": True,
+ "maxcor": len(model.theta_external),
+ },
+ "f_reduction_tol": 1e-3,
+ "theta_reduction_tol": 1e-4,
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "simulation_dir": f"{EXAMPLE_PATH}",
+ }
+
+ dalia = DALIA(
+ model=model,
+ config=dalia_config.parse_config(dalia_dict),
+ )
+ results = dalia.run()
+
+ # Check iterations behavior
+ success_msg : str = "success"
+ if results["optimization_iterations"] > TYPICAL_N_ITER:
+ success_msg = "warning_more_iters_than_typical"
+ elif results["optimization_iterations"] < TYPICAL_N_ITER:
+ success_msg = "success_less_iters_than_typical"
+
+ # Compare hyperparameters
+ theta_user = get_host(results["theta"])
+ print("theta_ref: ", theta_original)
+ print("theta user: ", theta_user)
+ print_msg(
+ "Norm (theta - theta_ref): ",
+ f"{np.linalg.norm(theta_user - theta_original):.4e}",
+ )
+ if np.linalg.norm(theta_user - theta_original) > THETA_TOL:
+ return "theta_tol_exceeded"
+
+ # Compare latent parameters
+ print(f"x_ref: {x_original}")
+ print(f"x_dalia: {get_host(results['x'])}")
+ print_msg(
+ "Norm (x - x_ref): ",
+ f"{np.linalg.norm(get_host(results['x']) - x_original):.4e}",
+ )
+ if np.linalg.norm(get_host(results["x"]) - x_original) > X_TOL:
+ return "x_tol_exceeded"
+
+ return success_msg
+
+if __name__ == "__main__":
+ par1_itest()
diff --git a/tests/integration/pr/itest.py b/tests/integration/pr/itest.py
new file mode 100644
index 00000000..95d18902
--- /dev/null
+++ b/tests/integration/pr/itest.py
@@ -0,0 +1,69 @@
+from pathlib import Path
+import numpy as np
+
+from dalia.configs import likelihood_config, dalia_config, submodels_config
+from dalia.core.model import Model
+from dalia.core.dalia import DALIA
+from dalia.utils import print_msg, get_host
+from dalia.submodels import RegressionSubModel
+
+SCRIPT_DIR = Path(__file__).resolve()
+DALIA_DIR = SCRIPT_DIR.parent.parent.parent.parent
+EXAMPLE_PATH = DALIA_DIR / "examples" / "pr"
+
+X_TOL = 1e-5
+
+def pr_itest():
+ # Configurations of the regression submodel
+ regression_dict = {
+ "type": "regression",
+ "input_dir": f"{EXAMPLE_PATH}/inputs",
+ "n_fixed_effects": 6,
+ "fixed_effects_prior_precision": 0.001,
+ }
+ regression = RegressionSubModel(
+ config=submodels_config.parse_config(regression_dict),
+ )
+ # Likelihood
+ likelihood_dict = {
+ "type": "poisson",
+ "input_dir": f"{EXAMPLE_PATH}",
+ }
+ model = Model(
+ submodels=[regression],
+ likelihood_config=likelihood_config.parse_config(likelihood_dict),
+ )
+ # Configurations of DALIA
+ dalia_dict = {
+ "solver": {"type": "dense"},
+ "minimize": {
+ "max_iter": 100,
+ "gtol": 1e-1,
+ "disp": True,
+ },
+ "inner_iteration_max_iter": 50,
+ "eps_inner_iteration": 1e-3,
+ "eps_gradient_f": 1e-3,
+ "simulation_dir": f"{EXAMPLE_PATH}",
+ }
+ dalia = DALIA(
+ model=model,
+ config=dalia_config.parse_config(dalia_dict),
+ )
+ results = dalia.run()
+
+ # Compare latent parameters
+ x_ref = np.load(f"{EXAMPLE_PATH}/reference_outputs/x_ref.npy")
+ print(f"x_ref: {x_ref}")
+ print(f"x_dalia: {get_host(results['x'])}")
+ print_msg(
+ "Norm (x - x_ref): ",
+ f"{np.linalg.norm(get_host(results['x']) - x_ref):.4e}",
+ )
+ if np.linalg.norm(get_host(results["x"]) - x_ref) > X_TOL:
+ return "x_tol_exceeded"
+
+ return "success"
+
+if __name__ == "__main__":
+ pr_itest()
\ No newline at end of file
diff --git a/tests/integration/readme.md b/tests/integration/readme.md
new file mode 100644
index 00000000..76c7f5cc
--- /dev/null
+++ b/tests/integration/readme.md
@@ -0,0 +1,25 @@
+# How to run the integration tests
+
+## On Daint
+I recomend getting on an interactive session where you can directly run the integration tests for all configurtions (in particular types of parallelization).
+```bash
+srun --pty --partition=debug --account=xxxx bash
+```
+
+Then, you can run the integration tests sequentially using:
+```bash
+python runner.py
+```
+
+Further configurations are available in the `runner.py` file.
+
+## On Fritz
+I recomend getting on an interactive session where you can directly run the integration tests for all configurtions (in particular types of parallelization).
+```bash
+salloc -N 1 --partition=spr2tb --time=00:30:00
+```
+
+Then, you can run the integration tests sequentially using:
+```bash
+srun python runner.py
+```
\ No newline at end of file
diff --git a/tests/integration/runner.py b/tests/integration/runner.py
new file mode 100644
index 00000000..84e4c506
--- /dev/null
+++ b/tests/integration/runner.py
@@ -0,0 +1,32 @@
+import os
+
+from gst.itest import gst_itest
+from gstcoreg2.itest import gstcoreg2_itest
+from par1.itest import par1_itest
+from pr.itest import pr_itest
+
+# run_test_scripts = {
+# "gst/itest.py": ["seq", "par_f", "par_s"],
+# "gcoreg/itest.py": ["seq", "par_f", "par_s"],
+# "par1/itest.py": ["seq", "par_f"],
+# "pr/itest.py": ["seq"],
+# }
+
+itest_calls = {
+ gst_itest: ["seq"],
+ gstcoreg2_itest: ["seq"],
+ par1_itest: ["seq"],
+ pr_itest: ["seq"],
+}
+
+os.environ["ARRAY_MODULE"] = "cupy" # "numpy" or "cupy"
+
+if __name__ == "__main__":
+
+ for itest, modes in itest_calls.items():
+ print(f"{itest.__name__} in mode `{modes[0]}` returned: {itest()}")
+
+ # for mode in modes:
+ # print(f"Running {itest.__name__} in {mode} mode...")
+ # command = f"ARRAY_MODULE={mode} python tests/integration/{script}"
+ # os.system(command)
\ No newline at end of file
diff --git a/tests/runner.sh b/tests/runner.sh
new file mode 100755
index 00000000..cfe123e7
--- /dev/null
+++ b/tests/runner.sh
@@ -0,0 +1,659 @@
+#!/bin/bash
+
+# DALIA Test Runner Script
+# Automatically detects available backends and runs appropriate tests
+# Compatible with bash, zsh, tcsh, and csh shells
+
+# =============================================================================
+# Shell Detection and Compatibility Functions
+# =============================================================================
+
+detect_shell() {
+ # Detect the current shell type
+ if [ -n "$BASH_VERSION" ]; then
+ echo "bash"
+ elif [ -n "$ZSH_VERSION" ]; then
+ echo "zsh"
+ elif [ -n "$tcsh" ]; then
+ echo "tcsh"
+ elif [ -n "$version" ]; then
+ echo "tcsh" # tcsh sets $version variable
+ else
+ # Fallback: check $SHELL variable
+ case "$SHELL" in
+ *bash*) echo "bash" ;;
+ *zsh*) echo "zsh" ;;
+ *tcsh*) echo "tcsh" ;;
+ *csh*) echo "csh" ;;
+ *) echo "unknown" ;;
+ esac
+ fi
+}
+
+set_env_var() {
+ # Set environment variable using appropriate shell syntax
+ local var_name="$1"
+ local var_value="$2"
+ local shell_type=$(detect_shell)
+
+ case "$shell_type" in
+ bash|zsh|sh)
+ export "$var_name=$var_value"
+ ;;
+ tcsh|csh)
+ setenv "$var_name" "$var_value"
+ ;;
+ *)
+ # Try both methods as fallback
+ export "$var_name=$var_value" 2>/dev/null || setenv "$var_name" "$var_value" 2>/dev/null
+ ;;
+ esac
+}
+
+print_message() {
+ # Print formatted message with consistent styling
+ local message="$1"
+ local type="$2" # INFO, SUCCESS, WARNING, ERROR
+
+ case "$type" in
+ SUCCESS) echo "✓ $message" ;;
+ WARNING) echo "⚠ $message" ;;
+ ERROR) echo "✗ $message" ;;
+ *) echo "• $message" ;;
+ esac
+}
+
+# =============================================================================
+# Argument Parsing Functions
+# =============================================================================
+
+show_help() {
+ # Display help information
+ echo "DALIA Test Runner Script"
+ echo "========================"
+ echo
+ echo "Usage: $0 [OPTIONS]"
+ echo
+ echo "OPTIONS:"
+ echo " --unit Run only unit tests (unit/ directory)"
+ echo " --component-integration Run only component integration tests (component_integration/ directory)"
+ echo " --cpu Run only CPU backend tests (NumPy)"
+ echo " --gpu Run only GPU backend tests (CuPy)"
+ echo " --mpi Run only MPI distributed tests"
+ echo " --yes Skip confirmation prompt (for automated testing)"
+ echo " --help, -h Show this help message"
+ echo
+ echo "EXAMPLES:"
+ echo " $0 Run all available tests"
+ echo " $0 --unit --cpu Run unit tests on CPU backend only"
+ echo " $0 --mpi --gpu --yes Run MPI tests on GPU backend without confirmation"
+ echo " $0 --component-integration Run component integration tests on all available backends"
+ echo
+}
+
+parse_arguments() {
+ # Parse command-line arguments and set global flags
+ RUN_UNIT=0
+ RUN_COMPONENT_INTEGRATION=0
+ RUN_CPU=0
+ RUN_GPU=0
+ RUN_MPI=0
+ AUTO_CONFIRM=0
+ RUN_ALL=1 # Default to running all tests
+ BACKEND_SPECIFIED=0 # Track if any backend flag was specified
+
+ while [ $# -gt 0 ]; do
+ case "$1" in
+ --unit)
+ RUN_UNIT=1
+ RUN_ALL=0
+ ;;
+ --component-integration)
+ RUN_COMPONENT_INTEGRATION=1
+ RUN_ALL=0
+ ;;
+ --cpu)
+ RUN_CPU=1
+ RUN_ALL=0
+ BACKEND_SPECIFIED=1
+ ;;
+ --gpu)
+ RUN_GPU=1
+ RUN_ALL=0
+ BACKEND_SPECIFIED=1
+ ;;
+ --mpi)
+ RUN_MPI=1
+ RUN_ALL=0
+ BACKEND_SPECIFIED=1
+ ;;
+ --yes)
+ AUTO_CONFIRM=1
+ ;;
+ --help|-h)
+ show_help
+ exit 0
+ ;;
+ *)
+ echo "Unknown option: $1"
+ echo "Use --help for usage information."
+ exit 1
+ ;;
+ esac
+ shift
+ done
+
+ # If no specific test type was selected, run all
+ if [ $RUN_ALL -eq 1 ]; then
+ RUN_UNIT=1
+ RUN_COMPONENT_INTEGRATION=1
+ RUN_CPU=1
+ RUN_GPU=1
+ RUN_MPI=1
+ # If test directories were specified but no backends, run on all backends
+ elif [ $BACKEND_SPECIFIED -eq 0 ] && [ $RUN_ALL -eq 0 ]; then
+ RUN_CPU=1
+ RUN_GPU=1
+ RUN_MPI=1
+ fi
+}
+
+# =============================================================================
+# Backend Detection Functions
+# =============================================================================
+
+check_gpu_availability() {
+ # Check if NVIDIA GPU is available via nvidia-smi
+ if command -v nvidia-smi >/dev/null 2>&1; then
+ if nvidia-smi >/dev/null 2>&1; then
+ local gpu_count=$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits | head -1)
+ print_message "GPU detected: $gpu_count NVIDIA GPU(s) available" "SUCCESS"
+ return 0
+ else
+ print_message "nvidia-smi found but not working properly" "WARNING"
+ return 1
+ fi
+ else
+ print_message "nvidia-smi not found - no NVIDIA GPU detected" "INFO"
+ return 1
+ fi
+}
+
+check_cupy_installation() {
+ # Check if CuPy is installed and working in the current environment
+ print_message "Checking CuPy installation and functionality..." "INFO"
+
+ # Try to import and test CuPy
+ python -c "
+import sys
+try:
+ import cupy as cp
+ # Test basic CUDA operation
+ test_array = cp.array([1, 2, 3])
+ result = cp.sum(test_array)
+ print('CuPy test successful: sum([1,2,3]) =', result.get())
+ sys.exit(0)
+except ImportError:
+ print('CuPy not installed')
+ sys.exit(1)
+except Exception as e:
+ print('CuPy installed but not working:', str(e))
+ sys.exit(2)
+" 2>/dev/null
+
+ local cupy_status=$?
+ case $cupy_status in
+ 0)
+ print_message "CuPy is installed and working correctly" "SUCCESS"
+ return 0
+ ;;
+ 1)
+ print_message "CuPy is not installed" "WARNING"
+ return 1
+ ;;
+ 2)
+ print_message "CuPy is installed but not functioning (check CUDA drivers)" "WARNING"
+ return 1
+ ;;
+ *)
+ print_message "Unable to test CuPy installation" "ERROR"
+ return 1
+ ;;
+ esac
+}
+
+check_mpi_installation() {
+ # Check if MPI is available (srun/mpirun/mpiexec and mpi4py)
+ print_message "Checking MPI installation and functionality..." "INFO"
+
+ # Check for srun, mpirun or mpiexec command
+ if ! command -v srun >/dev/null 2>&1 && ! command -v mpirun >/dev/null 2>&1 && ! command -v mpiexec >/dev/null 2>&1; then
+ print_message "srun/mpirun/mpiexec not found - MPI not available" "WARNING"
+ return 1
+ fi
+
+ # Determine which MPI launcher to use (priority: srun, mpirun, mpiexec)
+ if command -v srun >/dev/null 2>&1; then
+ MPI_LAUNCHER="srun"
+ elif command -v mpirun >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpirun"
+ elif command -v mpiexec >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpiexec"
+ fi
+
+ # Try to import and test mpi4py
+ $MPI_LAUNCHER -n 1 python -c "
+import sys
+try:
+ import mpi4py
+ from mpi4py import MPI
+ # Test basic MPI functionality
+ comm = MPI.COMM_WORLD
+ rank = comm.Get_rank()
+ size = comm.Get_size()
+ print('mpi4py test successful: rank', rank, 'of', size)
+ sys.exit(0)
+except ImportError:
+ print('mpi4py not installed')
+ sys.exit(1)
+except Exception as e:
+ print('mpi4py installed but not working:', str(e))
+ sys.exit(2)
+" 2>/dev/null
+
+ local mpi_status=$?
+ case $mpi_status in
+ 0)
+ print_message "MPI is installed and working correctly" "SUCCESS"
+ return 0
+ ;;
+ 1)
+ print_message "mpi4py is not installed" "WARNING"
+ return 1
+ ;;
+ 2)
+ print_message "mpi4py is installed but not functioning" "WARNING"
+ return 1
+ ;;
+ *)
+ print_message "Unable to test MPI installation" "ERROR"
+ return 1
+ ;;
+ esac
+}
+
+determine_available_backends() {
+ # Determine which backends are available for testing
+ local backends=""
+ local has_gpu=0
+ local has_mpi=0
+
+ print_message "Detecting available backends..." "INFO" >&2
+ echo >&2
+
+ # NumPy is always available
+ print_message "NumPy backend: Always available" "SUCCESS" >&2
+ backends="numpy"
+
+ # Check for GPU + CuPy
+ if check_gpu_availability >&2; then
+ if check_cupy_installation >&2; then
+ print_message "CuPy backend: Available" "SUCCESS" >&2
+ backends="$backends cupy"
+ has_gpu=1
+ else
+ print_message "CuPy backend: Unavailable (CuPy not working)" "WARNING" >&2
+ fi
+ else
+ print_message "CuPy backend: Unavailable (No GPU detected)" "WARNING" >&2
+ fi
+
+ # Check for MPI
+ if check_mpi_installation >&2; then
+ print_message "MPI backend: Available" "SUCCESS" >&2
+ backends="$backends mpi"
+ has_mpi=1
+
+ # Check MPI + GPU combination
+ if [ $has_gpu -eq 1 ]; then
+ print_message "MPI + GPU backend: Available" "SUCCESS" >&2
+ backends="$backends mpi-gpu"
+ else
+ print_message "MPI + GPU backend: Unavailable (No GPU)" "WARNING" >&2
+ fi
+ else
+ print_message "MPI backend: Unavailable (MPI not working)" "WARNING" >&2
+ fi
+
+ echo "$backends"
+}
+
+# =============================================================================
+# Test Execution Functions
+# =============================================================================
+
+choose_tests_to_run() {
+ # Display which tests will be run based on available backends
+ local backends="$1"
+ local mpi_available=0
+
+ echo "Test execution plan:"
+
+ # Show test directories
+ local test_dirs=$(get_test_directories)
+ if [ -n "$test_dirs" ]; then
+ echo " Test directories: $test_dirs"
+ else
+ echo " Test directories: all"
+ fi
+
+ # Check if MPI is actually available
+ if echo "$backends" | grep -q "mpi"; then
+ mpi_available=1
+ fi
+
+ # Show backends
+ if [ $RUN_CPU -eq 1 ] && echo "$backends" | grep -q "numpy"; then
+ if [ $RUN_MPI -eq 1 ] && [ $mpi_available -eq 1 ]; then
+ echo " - CPU backend (NumPy) - Serial and MPI tests (2 processes)"
+ else
+ echo " - CPU backend (NumPy) - Serial tests"
+ fi
+ fi
+
+ if [ $RUN_GPU -eq 1 ] && echo "$backends" | grep -q "cupy"; then
+ if [ $RUN_MPI -eq 1 ] && [ $mpi_available -eq 1 ]; then
+ echo " - GPU backend (CuPy) - Serial and MPI tests (2 processes)"
+ else
+ echo " - GPU backend (CuPy) - Serial tests"
+ fi
+ fi
+
+ if [ $RUN_MPI -eq 1 ] && [ $RUN_CPU -eq 0 ] && [ $RUN_GPU -eq 0 ]; then
+ if [ $mpi_available -eq 1 ]; then
+ echo " - CPU backend (NumPy) - MPI tests only (2 processes)"
+ if echo "$backends" | grep -q "cupy"; then
+ echo " - GPU backend (CuPy) - MPI tests only (2 processes)"
+ fi
+ else
+ echo " - MPI tests requested but MPI is not available"
+ fi
+ fi
+}
+
+get_test_directories() {
+ # Determine which test directories to run based on command-line arguments
+ local test_dirs=""
+
+ if [ $RUN_UNIT -eq 1 ] && [ $RUN_COMPONENT_INTEGRATION -eq 1 ]; then
+ test_dirs="unit/ component_integration/"
+ elif [ $RUN_UNIT -eq 1 ]; then
+ test_dirs="unit/"
+ elif [ $RUN_COMPONENT_INTEGRATION -eq 1 ]; then
+ test_dirs="component_integration/"
+ fi
+
+ echo "$test_dirs"
+}
+
+run_numpy_tests() {
+ # Run tests with NumPy backend
+ local test_dirs="$1" # Optional: specific test directories to run
+
+ echo "Running testing suite on CPU backend (NumPy)..."
+ echo "==============================================="
+
+ set_env_var "ARRAY_MODULE" "numpy"
+
+ # Determine which MPI launcher to use (priority: srun, mpirun, mpiexec)
+ if command -v srun >/dev/null 2>&1; then
+ MPI_LAUNCHER="srun -n 1"
+ elif command -v mpirun >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpirun -n 1"
+ elif command -v mpiexec >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpiexec -n 1"
+ else # No MPI available
+ MPI_LAUNCHER=""
+ fi
+
+ # Run pytest
+ if command -v pytest >/dev/null 2>&1; then
+ if [ -n "$test_dirs" ]; then
+ $MPI_LAUNCHER pytest $test_dirs -v
+ else
+ $MPI_LAUNCHER pytest . -v
+ fi
+ local exit_code=$?
+ return $exit_code
+ else
+ print_message "pytest not found - cannot run tests" "ERROR"
+ return 1
+ fi
+}
+
+run_cupy_tests() {
+ # Run tests with CuPy backend
+ local test_dirs="$1" # Optional: specific test directories to run
+
+ echo "Running testing suite on GPU backend (CuPy)..."
+ echo "==============================================="
+
+ set_env_var "ARRAY_MODULE" "cupy"
+
+ # Determine which MPI launcher to use (priority: srun, mpirun, mpiexec)
+ if command -v srun >/dev/null 2>&1; then
+ MPI_LAUNCHER="srun -n 1"
+ elif command -v mpirun >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpirun -n 1"
+ elif command -v mpiexec >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpiexec -n 1"
+ else # No MPI available
+ MPI_LAUNCHER=""
+ fi
+
+ # Run pytest
+ if command -v pytest >/dev/null 2>&1; then
+ if [ -n "$test_dirs" ]; then
+ $MPI_LAUNCHER pytest $test_dirs -v
+ else
+ $MPI_LAUNCHER pytest . -v
+ fi
+ local exit_code=$?
+ return $exit_code
+ else
+ print_message "pytest not found - cannot run tests" "ERROR"
+ return 1
+ fi
+}
+
+run_mpi_numpy_tests() {
+ # Run MPI tests with NumPy backend
+ local test_dirs="$1" # Optional: specific test directories to run
+
+ echo "Running MPI testing suite on CPU backend (NumPy)..."
+ echo "===================================================="
+
+ set_env_var "ARRAY_MODULE" "numpy"
+
+ # Determine which MPI launcher to use (priority: srun, mpirun, mpiexec)
+ if command -v srun >/dev/null 2>&1; then
+ MPI_LAUNCHER="srun"
+ elif command -v mpirun >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpirun"
+ elif command -v mpiexec >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpiexec"
+ fi
+
+ # Run with 2 processes using the MPI launcher determined during backend detection
+ echo "Running MPI tests with 2 processes using $MPI_LAUNCHER..."
+ if command -v pytest >/dev/null 2>&1; then
+ if [ -n "$test_dirs" ]; then
+ $MPI_LAUNCHER -n 2 pytest --with-mpi $test_dirs -v
+ else
+ $MPI_LAUNCHER -n 2 pytest --with-mpi . -v
+ fi
+ if [ $? -ne 0 ]; then
+ print_message "MPI tests with 2 processes failed" "ERROR"
+ return 1
+ else
+ print_message "MPI tests with 2 processes completed successfully" "SUCCESS"
+ return 0
+ fi
+ else
+ print_message "pytest not found - cannot run MPI tests" "ERROR"
+ return 1
+ fi
+}
+
+run_mpi_cupy_tests() {
+ # Run MPI tests with CuPy backend
+ local test_dirs="$1" # Optional: specific test directories to run
+
+ echo "Running MPI testing suite on GPU backend (CuPy)..."
+ echo "==================================================="
+
+ set_env_var "ARRAY_MODULE" "cupy"
+
+ # Determine which MPI launcher to use (priority: srun, mpirun, mpiexec)
+ if command -v srun >/dev/null 2>&1; then
+ MPI_LAUNCHER="srun"
+ elif command -v mpirun >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpirun"
+ elif command -v mpiexec >/dev/null 2>&1; then
+ MPI_LAUNCHER="mpiexec"
+ fi
+
+ # Run with 2 processes using the MPI launcher determined during backend detection
+ echo "Running MPI + GPU tests with 2 processes using $MPI_LAUNCHER..."
+ if command -v pytest >/dev/null 2>&1; then
+ if [ -n "$test_dirs" ]; then
+ $MPI_LAUNCHER -n 2 pytest --with-mpi $test_dirs -v
+ else
+ $MPI_LAUNCHER -n 2 pytest --with-mpi . -v
+ fi
+ if [ $? -ne 0 ]; then
+ print_message "MPI + GPU tests with 2 processes failed" "ERROR"
+ return 1
+ else
+ print_message "MPI + GPU tests with 2 processes completed successfully" "SUCCESS"
+ return 0
+ fi
+ else
+ print_message "pytest not found - cannot run MPI tests" "ERROR"
+ return 1
+ fi
+}
+
+# =============================================================================
+# Main Execution Logic
+# =============================================================================
+
+main() {
+ # Main function that orchestrates the test execution
+ local shell_type=$(detect_shell)
+
+ # Parse command-line arguments
+ parse_arguments "$@"
+
+ echo "=============================================="
+ echo " DALIA Tests Runner"
+ echo "=============================================="
+ echo "Detected shell: $shell_type"
+ echo "Current directory: $(pwd)"
+ echo
+
+ # Detect available backends
+ local available_backends=$(determine_available_backends)
+
+ # Show test plan
+ choose_tests_to_run "$available_backends"
+
+ # Ask for confirmation unless --yes was specified
+ if [ $AUTO_CONFIRM -eq 0 ]; then
+ echo -n "Proceed with test execution? [y/N]: "
+ read confirmation
+ case "$confirmation" in
+ [yY]|[yY][eE][sS])
+ echo "Starting tests..."
+ ;;
+ *)
+ echo "Test execution cancelled by user."
+ exit 0
+ ;;
+ esac
+ else
+ echo "Auto-confirmation enabled. Starting tests..."
+ fi
+
+ echo
+ local overall_success=0
+ local test_dirs=$(get_test_directories)
+
+ # Run serial tests if requested
+ if [ $RUN_CPU -eq 1 ] || [ $RUN_GPU -eq 1 ]; then
+ echo "===== SERIAL TESTS ====="
+
+ # Run NumPy tests if CPU backend is requested
+ if [ $RUN_CPU -eq 1 ] && echo "$available_backends" | grep -q "numpy"; then
+ if ! run_numpy_tests "$test_dirs"; then
+ overall_success=1
+ fi
+ echo
+ fi
+
+ # Run CuPy tests if GPU backend is requested and available
+ if [ $RUN_GPU -eq 1 ] && echo "$available_backends" | grep -q "cupy"; then
+ if ! run_cupy_tests "$test_dirs"; then
+ overall_success=1
+ fi
+ echo
+ elif [ $RUN_GPU -eq 1 ]; then
+ print_message "GPU backend requested but not available" "WARNING"
+ fi
+ fi
+
+ # Run MPI tests if requested and available
+ if [ $RUN_MPI -eq 1 ]; then
+ if echo "$available_backends" | grep -q "mpi"; then
+ echo "===== MPI TESTS ====="
+
+ # Run MPI + NumPy tests (always run MPI tests on CPU if MPI is available)
+ if ! run_mpi_numpy_tests "$test_dirs"; then
+ overall_success=1
+ fi
+ echo
+
+ # Run MPI + CuPy tests if GPU is also available and requested
+ if [ $RUN_GPU -eq 1 ] && echo "$available_backends" | grep -q "mpi-gpu"; then
+ if ! run_mpi_cupy_tests "$test_dirs"; then
+ overall_success=1
+ fi
+ echo
+ fi
+ else
+ print_message "MPI backend requested but not available - skipping MPI tests" "WARNING"
+ fi
+ fi
+
+ # Final summary
+ echo "=============================================="
+ if [ $overall_success -eq 0 ]; then
+ echo " All tests completed successfully!"
+ else
+ echo " Some tests failed - check output above"
+ fi
+ echo "=============================================="
+
+ exit $overall_success
+}
+
+# =============================================================================
+# Script Entry Point
+# =============================================================================
+
+# Check if script is being sourced or executed
+if [ "${BASH_SOURCE[0]}" = "${0}" ] 2>/dev/null || [ "${(%):-%x}" = "${0}" ] 2>/dev/null; then
+ # Script is being executed directly
+ main "$@"
+fi
+
diff --git a/tests/structured_solvers_utils.py b/tests/structured_solvers_utils.py
new file mode 100644
index 00000000..f8a82aa8
--- /dev/null
+++ b/tests/structured_solvers_utils.py
@@ -0,0 +1,265 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import numpy as np
+
+from dalia import ArrayLike, backend_flags, xp
+from tests import ATOLS, RANDOM_SEED, RTOLS
+
+np.random.seed(RANDOM_SEED)
+
+if backend_flags["cupy_avail"]:
+ import cupy as cp
+
+ cp.random.seed(cp.uint64(RANDOM_SEED))
+
+
+def _create_solver(
+ solver_type: str,
+ diagonal_blocksize: int,
+ n_diag_blocks: int,
+ arrowhead_blocksize: int = 0,
+ distributed: bool = False,
+):
+ from dalia.configs.dalia_config import SolverConfig
+
+ config = SolverConfig(type=solver_type)
+
+ if solver_type == "serinv":
+ if not distributed:
+ from dalia.solvers import SerinvSolver
+
+ return SerinvSolver(
+ config=config,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
+ else:
+ import mpi4py.MPI as MPI
+
+ from dalia.solvers import DistSerinvSolver
+
+ return DistSerinvSolver(
+ config=config,
+ diagonal_blocksize=diagonal_blocksize,
+ arrowhead_blocksize=arrowhead_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ comm=MPI.COMM_WORLD,
+ nccl_comm=None,
+ )
+ else:
+ raise ValueError(f"Unknown solver type: {solver_type}")
+
+
+def _create_pobta(
+ diagonal_blocksize: int,
+ arrowhead_blocksize: int,
+ n_diag_blocks: int,
+):
+ """Returns a random, positive definite, block tridiagonal arrowhead matrix.
+
+ Parameters
+ ----------
+ diagonal_blocksize : int
+ Size of the diagonal blocks.
+ arrowhead_blocksize : int
+ Size of the arrowhead blocks.
+ n_diag_blocks : int
+ Number of diagonal blocks.
+
+ Returns
+ -------
+ A : ArrayLike
+ Random, positive definite, block tridiagonal arrowhead matrix.
+ """
+
+ A = xp.zeros(
+ (
+ diagonal_blocksize * n_diag_blocks + arrowhead_blocksize,
+ diagonal_blocksize * n_diag_blocks + arrowhead_blocksize,
+ ),
+ dtype=xp.float64,
+ )
+
+ # Fill the arrowhead blocks
+ A[-arrowhead_blocksize:, :-arrowhead_blocksize] = xp.random.rand(
+ arrowhead_blocksize, n_diag_blocks * diagonal_blocksize
+ )
+ A[-arrowhead_blocksize:, -arrowhead_blocksize:] = xp.random.rand(
+ arrowhead_blocksize, arrowhead_blocksize
+ )
+
+ # Fill the diagonal blocks
+ for i in range(n_diag_blocks):
+ A[
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ] = xp.random.rand(diagonal_blocksize, diagonal_blocksize)
+
+ # Fill the off-diagonal blocks
+ if i < n_diag_blocks - 1:
+ A[
+ (i + 1) * diagonal_blocksize : (i + 2) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ] = xp.random.rand(diagonal_blocksize, diagonal_blocksize)
+
+ # Make the matrix diagonally dominant (hence, if symmetric, SPD)
+ for i in range(A.shape[0]):
+ A[i, i] += 1 + xp.sum(A[i, :])
+
+ # Make symmetric
+ A = A + A.T
+
+ return A
+
+
+def _create_pobt(
+ diagonal_blocksize: int,
+ n_diag_blocks: int,
+):
+ """Returns a random, positive definite, block tridiagonal matrix.
+
+ Parameters
+ ----------
+ diagonal_blocksize : int
+ Size of the diagonal blocks.
+ n_diag_blocks : int
+ Number of diagonal blocks.
+
+ Returns
+ -------
+ A : ArrayLike
+ Random, positive definite, block tridiagonal matrix.
+ """
+
+ A = xp.zeros(
+ (
+ diagonal_blocksize * n_diag_blocks,
+ diagonal_blocksize * n_diag_blocks,
+ ),
+ dtype=xp.float64,
+ )
+
+ # Fill the diagonal blocks
+ for i in range(n_diag_blocks):
+ A[
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ] = xp.random.rand(diagonal_blocksize, diagonal_blocksize)
+
+ # Fill the off-diagonal blocks
+ if i < n_diag_blocks - 1:
+ A[
+ (i + 1) * diagonal_blocksize : (i + 2) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ] = xp.random.rand(diagonal_blocksize, diagonal_blocksize)
+
+ # Make symmetric
+ A = A + A.T
+
+ # Make positive definite by adding scaled identity
+ A = A + (xp.max(xp.abs(A)) + 1.0) * xp.eye(A.shape[0])
+
+ return A
+
+
+def _allclose_dense_structured(
+ A_reference: ArrayLike,
+ B_toverify: ArrayLike,
+ diagonal_blocksize: int,
+ n_diag_blocks: int,
+ arrowhead_blocksize: int = 0,
+ assert_upper_triangle: bool = False,
+):
+ """Check block-wise correctness of two structured matrices in dense storage format.
+
+ Parameters
+ ----------
+ A_reference : ArrayLike
+ First structured matrix to compare.
+ B_toverify : ArrayLike
+ Second structured matrix to compare.
+
+ diagonal_blocksize : int
+ Size of the diagonal blocks.
+ n_diag_blocks : int
+ Number of diagonal blocks.
+ arrowhead_blocksize : int, optional
+ Size of the arrowhead blocks, by default 0.
+ assert_upper_triangle : bool, optional
+ Whether to assert the upper triangle blocks as well, by default False.
+
+ Raises
+ ------
+ AssertionError
+ If any of the corresponding blocks are not close enough.
+ """
+ if arrowhead_blocksize > 0:
+ # Lower arrow blocks
+ assert xp.allclose(
+ A_reference[-arrowhead_blocksize:, :-arrowhead_blocksize],
+ B_toverify[-arrowhead_blocksize:, :-arrowhead_blocksize],
+ rtol=RTOLS["strict"],
+ atol=ATOLS["strict"],
+ )
+ if assert_upper_triangle:
+ # Upper arrow blocks
+ assert xp.allclose(
+ A_reference[:-arrowhead_blocksize, -arrowhead_blocksize:],
+ B_toverify[:-arrowhead_blocksize, -arrowhead_blocksize:],
+ rtol=RTOLS["strict"],
+ atol=ATOLS["strict"],
+ )
+
+ # Tip of the arrowhead
+ assert xp.allclose(
+ A_reference[-arrowhead_blocksize:, -arrowhead_blocksize:],
+ B_toverify[-arrowhead_blocksize:, -arrowhead_blocksize:],
+ rtol=RTOLS["strict"],
+ atol=ATOLS["strict"],
+ )
+
+ # Check the diagonal blocks
+ for i in range(n_diag_blocks):
+ assert xp.allclose(
+ A_reference[
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ],
+ B_toverify[
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ],
+ rtol=RTOLS["strict"],
+ atol=ATOLS["strict"],
+ )
+
+ # Check the off-diagonal (lower) blocks
+ if i < n_diag_blocks - 1:
+ assert xp.allclose(
+ A_reference[
+ (i + 1) * diagonal_blocksize : (i + 2) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ],
+ B_toverify[
+ (i + 1) * diagonal_blocksize : (i + 2) * diagonal_blocksize,
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ ],
+ rtol=RTOLS["strict"],
+ atol=ATOLS["strict"],
+ )
+
+ if assert_upper_triangle:
+ # Check the off-diagonal (upper) blocks
+ assert xp.allclose(
+ A_reference[
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ (i + 1) * diagonal_blocksize : (i + 2) * diagonal_blocksize,
+ ],
+ B_toverify[
+ i * diagonal_blocksize : (i + 1) * diagonal_blocksize,
+ (i + 1) * diagonal_blocksize : (i + 2) * diagonal_blocksize,
+ ],
+ rtol=RTOLS["strict"],
+ atol=ATOLS["strict"],
+ )
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 00000000..9ec74d3a
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,21 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+from typing import Dict
+
+RTOLS: Dict[str, float] = {
+ "strict": 1e-14,
+ "relaxed": 1e-10,
+}
+
+ATOLS: Dict[str, float] = {
+ "strict": 1e-16,
+ "relaxed": 1e-12,
+}
+
+RANDOM_SEED = 63
+
+__all__ = [
+ "RTOLS",
+ "ATOLS",
+ "RANDOM_SEED",
+]
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/unit/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/unit/solvers/__init__.py b/tests/unit/solvers/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/unit/solvers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/unit/solvers/structured_solvers/__init__.py b/tests/unit/solvers/structured_solvers/__init__.py
new file mode 100644
index 00000000..7dd9c468
--- /dev/null
+++ b/tests/unit/solvers/structured_solvers/__init__.py
@@ -0,0 +1 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
diff --git a/tests/unit/solvers/structured_solvers/conftest.py b/tests/unit/solvers/structured_solvers/conftest.py
new file mode 100644
index 00000000..d68be698
--- /dev/null
+++ b/tests/unit/solvers/structured_solvers/conftest.py
@@ -0,0 +1,62 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+from tests.structured_solvers_utils import (
+ _allclose_dense_structured,
+ _create_pobt,
+ _create_pobta,
+ _create_solver,
+)
+
+DIAGONAL_BLOCKSIZE = [
+ pytest.param(2, id="diagonal_blocksize=2"),
+ pytest.param(3, id="diagonal_blocksize=3"),
+]
+
+
+@pytest.fixture(params=DIAGONAL_BLOCKSIZE, autouse=True)
+def diagonal_blocksize(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+ARROWHEAD_BLOCKSIZE = [
+ pytest.param(0, id="arrowhead_blocksize=0"),
+ pytest.param(2, id="arrowhead_blocksize=2"),
+ pytest.param(3, id="arrowhead_blocksize=3"),
+]
+
+
+@pytest.fixture(params=ARROWHEAD_BLOCKSIZE, autouse=True)
+def arrowhead_blocksize(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+SOLVERS_TYPES = [
+ pytest.param("serinv", id="solvers_types=serinv"),
+]
+
+
+@pytest.fixture(params=SOLVERS_TYPES)
+def solver_type(request: pytest.FixtureRequest) -> str:
+ return request.param
+
+
+@pytest.fixture
+def create_solver():
+ return _create_solver
+
+
+@pytest.fixture
+def create_pobta():
+ return _create_pobta
+
+
+@pytest.fixture
+def create_pobt():
+ return _create_pobt
+
+
+@pytest.fixture
+def allclose_dense_structured():
+ return _allclose_dense_structured
diff --git a/tests/unit/solvers/structured_solvers/distributed/conftest.py b/tests/unit/solvers/structured_solvers/distributed/conftest.py
new file mode 100644
index 00000000..524cf6e7
--- /dev/null
+++ b/tests/unit/solvers/structured_solvers/distributed/conftest.py
@@ -0,0 +1,25 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+N_DIAG_BLOCKS_PER_PROCESS = [
+ pytest.param(3, id="n_diag_blocks=3"),
+ pytest.param(5, id="n_diag_blocks=5"),
+ pytest.param(10, id="n_diag_blocks=10"),
+]
+
+
+@pytest.fixture(params=N_DIAG_BLOCKS_PER_PROCESS, autouse=True)
+def n_diag_blocks_per_process(request: pytest.FixtureRequest) -> int:
+ return request.param
+
+
+NON_UNIFORM_PARTITION_SIZES = [
+ pytest.param(True, id="non_uniform_partition=True"),
+ pytest.param(False, id="non_uniform_partition=False"),
+]
+
+
+@pytest.fixture(params=NON_UNIFORM_PARTITION_SIZES, autouse=True)
+def non_uniform_partition(request: pytest.FixtureRequest) -> bool:
+ return request.param
diff --git a/tests/unit/solvers/structured_solvers/distributed/test_spmatrix_mapping.py b/tests/unit/solvers/structured_solvers/distributed/test_spmatrix_mapping.py
new file mode 100644
index 00000000..da3a8caf
--- /dev/null
+++ b/tests/unit/solvers/structured_solvers/distributed/test_spmatrix_mapping.py
@@ -0,0 +1,90 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import copy
+
+import pytest
+
+
+@pytest.mark.mpi(min_size=2)
+def test_spmatrix_mapping(
+ allclose_dense_structured,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks_per_process,
+ arrowhead_blocksize,
+ non_uniform_partition,
+):
+ """Test the mapping functions from structured to spmatrix and back."""
+ import mpi4py.MPI as MPI
+
+ # Generate test matrix based on sparsity pattern
+ n_diag_blocks = n_diag_blocks_per_process * MPI.COMM_WORLD.Get_size() + (
+ 1 if non_uniform_partition else 0
+ )
+
+ if arrowhead_blocksize > 0: # bta
+ A_initial = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A_initial = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Make deep copies of A_initial to keep as references
+ A_reference = copy.deepcopy(A_initial)
+ A_pattern = copy.deepcopy(A_initial)
+
+ # Create solver
+ solver = create_solver(
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+ distributed=True,
+ )
+
+ # Run mapping functions
+ for _ in range(2):
+ # Run twice to test for potential JIT caching issues
+ solver._spmatrix_to_structured(
+ A_initial,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ )
+
+ A_solver = solver._structured_to_spmatrix(
+ A_pattern,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ symmetrize=True,
+ )
+
+ # Verify that the initial matrix is identical to the mapped one
+ allclose_dense_structured(
+ A_reference=(
+ A_reference.toarray() if hasattr(A_reference, "toarray") else A_reference
+ ),
+ B_toverify=A_solver.toarray() if hasattr(A_solver, "toarray") else A_solver,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
+
+ # Verify that the 2 matrices given to the mapping functions are untouched
+ allclose_dense_structured(
+ A_reference=(
+ A_reference.toarray() if hasattr(A_reference, "toarray") else A_reference
+ ),
+ B_toverify=A_initial.toarray() if hasattr(A_initial, "toarray") else A_initial,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
+
+ allclose_dense_structured(
+ A_reference=(
+ A_reference.toarray() if hasattr(A_reference, "toarray") else A_reference
+ ),
+ B_toverify=A_pattern.toarray() if hasattr(A_pattern, "toarray") else A_pattern,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
diff --git a/tests/unit/solvers/structured_solvers/sequential/conftest.py b/tests/unit/solvers/structured_solvers/sequential/conftest.py
new file mode 100644
index 00000000..76ea69a6
--- /dev/null
+++ b/tests/unit/solvers/structured_solvers/sequential/conftest.py
@@ -0,0 +1,15 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import pytest
+
+N_DIAG_BLOCKS = [
+ pytest.param(1, id="n_diag_blocks=1"),
+ pytest.param(2, id="n_diag_blocks=2"),
+ pytest.param(3, id="n_diag_blocks=3"),
+ pytest.param(4, id="n_diag_blocks=4"),
+]
+
+
+@pytest.fixture(params=N_DIAG_BLOCKS, autouse=True)
+def n_diag_blocks(request: pytest.FixtureRequest) -> int:
+ return request.param
diff --git a/tests/unit/solvers/structured_solvers/sequential/test_spmatrix_mapping.py b/tests/unit/solvers/structured_solvers/sequential/test_spmatrix_mapping.py
new file mode 100644
index 00000000..a068ac74
--- /dev/null
+++ b/tests/unit/solvers/structured_solvers/sequential/test_spmatrix_mapping.py
@@ -0,0 +1,79 @@
+# Copyright 2024-2025 DALIA authors. All rights reserved.
+
+import copy
+
+import pytest
+
+
+@pytest.mark.mpi_skip()
+def test_spmatrix_mapping(
+ allclose_dense_structured,
+ create_solver,
+ create_pobta,
+ create_pobt,
+ solver_type,
+ diagonal_blocksize,
+ n_diag_blocks,
+ arrowhead_blocksize,
+):
+ """Test the mapping functions from structured to spmatrix and back."""
+ # Generate test matrix based on sparsity pattern
+ if arrowhead_blocksize > 0: # bta
+ A_initial = create_pobta(diagonal_blocksize, arrowhead_blocksize, n_diag_blocks)
+ else: # bt
+ A_initial = create_pobt(diagonal_blocksize, n_diag_blocks)
+
+ # Make deep copies of A_initial to keep as references
+ A_reference = copy.deepcopy(A_initial)
+ A_pattern = copy.deepcopy(A_initial)
+
+ # Create solver
+ solver = create_solver(
+ solver_type, diagonal_blocksize, n_diag_blocks, arrowhead_blocksize
+ )
+
+ # Run mapping functions
+ for _ in range(2):
+ # Run twice to test for potential JIT caching issues
+ solver._spmatrix_to_structured(
+ A_initial,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ )
+
+ A_solver = solver._structured_to_spmatrix(
+ A_pattern,
+ sparsity="bta" if arrowhead_blocksize > 0 else "bt",
+ symmetrize=True,
+ )
+
+ # Verify that the initial matrix is identical to the mapped one
+ allclose_dense_structured(
+ A_reference=(
+ A_reference.toarray() if hasattr(A_reference, "toarray") else A_reference
+ ),
+ B_toverify=A_solver.toarray() if hasattr(A_solver, "toarray") else A_solver,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
+
+ # Verify that the 2 matrices given to the mapping functions are untouched
+ allclose_dense_structured(
+ A_reference=(
+ A_reference.toarray() if hasattr(A_reference, "toarray") else A_reference
+ ),
+ B_toverify=A_initial.toarray() if hasattr(A_initial, "toarray") else A_initial,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )
+
+ allclose_dense_structured(
+ A_reference=(
+ A_reference.toarray() if hasattr(A_reference, "toarray") else A_reference
+ ),
+ B_toverify=A_pattern.toarray() if hasattr(A_pattern, "toarray") else A_pattern,
+ diagonal_blocksize=diagonal_blocksize,
+ n_diag_blocks=n_diag_blocks,
+ arrowhead_blocksize=arrowhead_blocksize,
+ )