This library is in the early stages of development.
PyTorch versions are not yet supported.
You can choose between the following installation methods:
```bash
# From PyPI (recommended)
pip install np-family

# Latest version (from current branch; dev-v4)
pip install git+https://github.com/yuneg11/Neural-Process-Family@dev-v4

# Specific release (from tag; 0.0.1.dev0)
pip install git+https://github.com/yuneg11/Neural-Process-Family@0.0.1.dev0
```

Then, you can use the library as follows:
```python
from npf.jax.models import CNP

cnp = CNP(y_dim=1)
```

You should handle the other logic (training, evaluation, etc.) yourself.
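Since training and evaluation logic is left to the user, a common first step is splitting each batch into context and target sets via masks. The sketch below is a generic illustration in plain NumPy, assuming the `[batch, point, dim]` / `[batch, point]` shape convention described later in this README; the helper `split_context_target` is our own name, not part of the library:

```python
import numpy as np

def split_context_target(mask, num_context, rng):
    """Randomly mark `num_context` valid points per batch element as context.

    mask: [batch, point], 1 = valid point, 0 = padding.
    Returns context and target masks, each of shape [batch, point].
    """
    mask_ctx = np.zeros_like(mask)
    for b in range(mask.shape[0]):
        valid = np.flatnonzero(mask[b])
        ctx_idx = rng.choice(valid, size=num_context, replace=False)
        mask_ctx[b, ctx_idx] = 1
    # Targets are the remaining valid points.
    mask_tar = mask * (1 - mask_ctx)
    return mask_ctx, mask_tar

rng = np.random.default_rng(0)
x = rng.uniform(-2.0, 2.0, size=(4, 10, 1))   # [batch, point, 1]
y = np.sin(x)                                  # [batch, point, 1]
mask = np.ones((4, 10), dtype=int)             # [batch, point]
mask_ctx, mask_tar = split_context_target(mask, num_context=3, rng=rng)
```

Context and target masks partition the valid points, so `mask_ctx + mask_tar == mask` everywhere.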
```bash
# Dependencies
pip install rich nxcl==0.0.3.dev3

# And ML frameworks (JAX, PyTorch)
# ex) pip install jax
```
```bash
# Clone the repository
git clone https://github.com/yuneg11/Neural-Process-Family npf
cd npf
```

Then, you can run an experiment, for example:

```bash
python scripts/jax/train.py -f configs/gp/rbf/inf/anp.yaml -lr 0.0001 --model.train_kwargs.num_latents 30
```

The output will be saved under the `outs/` directory.
Details will be added in the future.
```bash
python -m npf.jax.data.save \
    --root <dataset-root> \
    --dataset <dataset-name>
```

- `<dataset-root>`: The root path to save the dataset. Default: `./datasets/`
- `<dataset-name>`: The name of the dataset to save. See the sections below for available datasets.
You should install `torch` and `torchvision` to download the datasets.
You can find the details on the download page.
For example,

```bash
# CUDA 11.3
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
```

- MNIST: `mnist`
- CIFAR10: `cifar10`
- CIFAR100: `cifar100`
- CelebA: `celeba`
- SVHN: `svhn`
You should install `numba` and `wget` to simulate or download the datasets.

```bash
pip install numba wget
```

- Lotka-Volterra: `lotka_volterra`

TODO: See `npf.jax.data.save:save_lotka_volterra` for more detailed options.
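For reference, the Lotka-Volterra (predator-prey) dynamics behind such a dataset can be sketched with simple Euler integration in NumPy. This is a generic illustration of the equations with commonly used example parameters, not the library's actual simulation code (which may use stochastic dynamics and different settings):

```python
import numpy as np

def simulate_lotka_volterra(x0, y0, alpha=1.0, beta=0.1, delta=0.075,
                            gamma=1.5, dt=0.01, num_steps=2000):
    """Euler integration of the Lotka-Volterra ODEs:

        dx/dt = alpha * x - beta * x * y   (prey)
        dy/dt = delta * x * y - gamma * y  (predator)

    Returns two arrays of shape [num_steps + 1].
    """
    xs = np.empty(num_steps + 1)
    ys = np.empty(num_steps + 1)
    xs[0], ys[0] = x0, y0
    for t in range(num_steps):
        x, y = xs[t], ys[t]
        xs[t + 1] = x + dt * (alpha * x - beta * x * y)
        ys[t + 1] = y + dt * (delta * x * y - gamma * y)
    return xs, ys

prey, predator = simulate_lotka_volterra(x0=10.0, y0=5.0)
```

With a small step size the populations oscillate around the equilibrium `(gamma / delta, alpha / beta)` and remain positive.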
- CNP: Conditional Neural Process
- NP: Neural Process
- CANP: Conditional Attentive Neural Process
- ANP: Attentive Neural Process
- BNP: Bootstrapping Neural Process
- BANP: Bootstrapping Attentive Neural Process
- NeuBNP: Neural Bootstrapping Neural Process
- NeuBANP: Neural Bootstrapping Attentive Neural Process
- ConvCNP: Convolutional Conditional Neural Process
- ConvNP: Convolutional Neural Process
```bash
python scripts/jax/train.py -f <config-file> [additional-options]
```

You can use your own config file or use the provided config files in the `configs` directory.
For example, the following command will train an ANP model with a learning rate of 0.0001 for 100 epochs:

```bash
python scripts/jax/train.py -f configs/gp/rbf/inf/anp.yaml \
    -lr 0.0001 \
    --train.num_epochs 100
```

You can see the help of the config file by using the following command:
```bash
python scripts/jax/train.py -f <config-file> --help
```

```bash
# From a trained model directory
python scripts/jax/test.py -d <model-output-dir> [additional-options]

# From a new config file and a trained model checkpoint
python scripts/jax/test.py -f <config-file> -c <checkpoint-file-path> [additional-options]
```

You can directly test the trained model by specifying the output directory. For example:

```bash
python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh
```

where `outs/CNP/Train/RBF/Inf/220704-181313-vweh` is the output directory of the trained model.
You can also replace or add the test-specific configs from the config file using the `-tf` / `--test-config-file` option.
For example:
```bash
python scripts/jax/test.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh \
    -tf configs/gp/robust/matern.yaml
```

```bash
# From a trained model directory
python scripts/jax/test_bo.py -d <model-output-dir> [additional-options]

# From a new config file and a trained model checkpoint
python scripts/jax/test_bo.py -f <config-file> -c <checkpoint-file-path> [additional-options]
```

Similar to the test script above, you can directly test the trained model by specifying the output directory.
For example:

```bash
python scripts/jax/test_bo.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh
```

You can also replace or add the BO-specific configs from the config file using the `-bf` / `--bo-config-file` option.
For example:

```bash
python scripts/jax/test_bo.py -d outs/CNP/Train/RBF/Inf/220704-181313-vweh \
    -bf configs/gp/rbf/bo_config.yaml
```

- 1D regression (`x`: `[B, P, 1]`, `y`: `[B, P, 1]`, `mask`: `[B, P]`)
  - Gaussian processes, etc.
- 2D image (`x`: `[B, P, P, 2]`, `y`: `[B, P, P, (1 or 3)]`, `mask`: `[B, P, P]`)
  - Image completion, super resolution, etc.
- Bayesian optimization (`x`: `[B, P, D]`, `y`: `[B, P, 1]`, `mask`: `[B, P]`)
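As a concrete illustration of the 1D-regression convention, the snippet below samples batched functions from an RBF-kernel Gaussian process prior and packs them into the `x`: `[B, P, 1]`, `y`: `[B, P, 1]`, `mask`: `[B, P]` layout. This is a generic NumPy sketch, not the library's own synthetic-GP dataset code:

```python
import numpy as np

def sample_gp_batch(batch=16, points=50, length_scale=0.5, jitter=1e-4, seed=0):
    """Sample y ~ GP(0, k_RBF) at random x locations, batched.

    Returns x: [B, P, 1], y: [B, P, 1], mask: [B, P].
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(-2.0, 2.0, size=(batch, points, 1))
    # Pairwise differences and RBF kernel matrices: [B, P, P]
    diff = x[:, :, None, :] - x[:, None, :, :]
    cov = np.exp(-0.5 * np.sum(diff**2, axis=-1) / length_scale**2)
    cov += jitter * np.eye(points)  # jitter for numerical stability
    # Batched Cholesky, then correlate i.i.d. normals: [B, P, 1]
    chol = np.linalg.cholesky(cov)
    y = chol @ rng.standard_normal((batch, points, 1))
    mask = np.ones((batch, points), dtype=int)
    return x, y, mask

x, y, mask = sample_gp_batch()
```

All points are valid here, so the mask is all ones; padding variable-length batches is where the mask axis earns its keep.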
- `x`: `[batch, *data_specific_dims, data_dim]`
- `y`: `[batch, *data_specific_dims, data_dim]`
- `mask`: `[batch, *data_specific_dims]`
- `outs`: `[batch, *model_specific_dims, *data_specific_dims, data_dim]`
- At `CNP` 1D regression: `x`: `[batch, point, 1]`, `y`: `[batch, point, 1]`, `mask`: `[batch, point]`, `outs`: `[batch, point, 1]`
- At `NP` 1D regression: `x`: `[batch, point, 1]`, `y`: `[batch, point, 1]`, `mask`: `[batch, point]`, `outs`: `[batch, latent, point, 1]`
- At `CNP` 2D image regression: `x`: `[batch, height, width, 2]`, `y`: `[batch, height, width, 1 or 3]`, `mask`: `[batch, height, width]`, `outs`: `[batch, height, width, 1 or 3]`
- At `NP` 2D image regression: `x`: `[batch, height, width, 2]`, `y`: `[batch, height, width, 1 or 3]`, `mask`: `[batch, height, width]`, `outs`: `[batch, latent, height, width, 1 or 3]`
- At `BNP` 1D regression: `x`: `[batch, point, 1]`, `y`: `[batch, point, 1]`, `mask`: `[batch, point]`, `outs`: `[batch, sample, point, 1]`
- At `BNP` 2D image regression: `x`: `[batch, height, width, 2]`, `y`: `[batch, height, width, 1 or 3]`, `mask`: `[batch, height, width]`, `outs`: `[batch, sample, height, width, 1 or 3]`
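The extra `latent` / `sample` axis in the outputs of latent-variable models is typically reduced with a log-mean-exp over per-sample likelihoods. The sketch below shows this reduction in NumPy, assuming Gaussian outputs with fixed unit variance for brevity; the library's actual likelihood computation may differ:

```python
import numpy as np

def masked_gaussian_ll(mean, y, mask):
    """Masked Gaussian log-likelihood with unit variance.

    mean: [batch, sample, point, 1] (per-latent-sample predictions)
    y:    [batch, point, 1]
    mask: [batch, point]  (1 = valid point, 0 = padding)
    Returns per-batch log-likelihood [batch], reduced over the sample
    axis with log-mean-exp.
    """
    # Per-point log N(y | mean, 1), summed over data_dim: [batch, sample, point]
    log_prob = -0.5 * (np.log(2 * np.pi) + (y[:, None] - mean) ** 2).sum(-1)
    # Zero out padded points, sum over points: [batch, sample]
    ll = (log_prob * mask[:, None]).sum(-1)
    # Numerically stable log-mean-exp over the sample axis: [batch]
    m = ll.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(ll - m).mean(axis=1, keepdims=True))).squeeze(1)

rng = np.random.default_rng(0)
mean = rng.normal(size=(2, 5, 10, 1))  # 5 latent samples per batch element
y = rng.normal(size=(2, 10, 1))
mask = np.ones((2, 10))
ll = masked_gaussian_ll(mean, y, mask)
```

Masked-out points contribute nothing: with an all-zero mask the result is exactly zero.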
- We used Cloud TPUs supported by Google’s TPU Research Cloud (TRC).
- Synthetic GP dataset code is based on juho-lee/bnp.
- SetConv modules are based on wesselb/gabriel-convcnp.