diff --git a/.gitignore b/.gitignore index 5fb6284..939e37f 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,7 @@ __pycache__/ .eggs/ # Jupyter -.ipynb_checkpoints/ \ No newline at end of file +.ipynb_checkpoints/ + +# Mac File System +.DS_Store \ No newline at end of file diff --git a/model/__init__.py b/__init__.py similarity index 100% rename from model/__init__.py rename to __init__.py diff --git a/devtools/create_env.sh b/devtools/create_env.sh index 12deec1..0b7ae8a 100644 --- a/devtools/create_env.sh +++ b/devtools/create_env.sh @@ -1,10 +1,12 @@ -# Developed by Kevin A. Spiekermann +# Developed by Xiaorui Dong and Kevin A. Spiekermann # This script does the following tasks: # - creates the conda -# - prompts user for desired CUDA version # - installs PyTorch with specified CUDA version in the environment # - installs torch torch-geometric in the environment +SCRIPT_DIR=$(dirname $0) + +CONDA_ENV_NAME="GeoMol" # get OS type unameOut="$(uname -s)" @@ -17,66 +19,17 @@ case "${unameOut}" in esac echo "Running ${machine}..." +if [ "$machine" != "MacOS" ]; then + # Prompt the user to input their desired CUDA version or 'cpu' + echo "Please input your desired CUDA version in the format xx.xx (e.g., 10.2, 12.3) or 'cpu' for no CUDA available:" + read cuda_input -# request user to select one of the supported CUDA versions -# source: https://pytorch.org/get-started/locally/ -PS3='Please enter 1, 2, 3, or 4 to specify the desired CUDA version from the options above: ' -options=("9.2" "10.1" "10.2" "cpu" "Quit") -select opt in "${options[@]}" -do - case $opt in - "9.2") - CUDA="cudatoolkit=9.2" - CUDA_VERSION="cu92" - break - ;; - "10.1") - CUDA="cudatoolkit=10.1" - CUDA_VERSION="cu101" - break - ;; - "10.2") - CUDA="cudatoolkit=10.2" - CUDA_VERSION="cu102" - break - ;; - "cpu") - # "cpuonly" works for Linux and Windows - CUDA="cpuonly" - # Mac does not use "cpuonly" - if [ $machine == "Mac" ] - then - CUDA=" " - fi - CUDA_VERSION="cpu" - break - ;; - "Quit") - exit - ;; - *) echo "invalid option $REPLY";; - esac -done - -echo "Creating conda environment..." -echo "Running: conda env create -f environment.yml" -conda env create -f devtools/environment.yml +if [ "$machine" == "MacOS" ] && [ "$(uname -m)" == "arm64" ]; then -# activate the environment to install torch-geometric -source activate GeoMol + $SHELL $SCRIPT_DIR/install_pyg_macos_arm64.sh -n $CONDA_ENV_NAME -echo "Installing PyTorch with requested CUDA version..." -echo "Running: conda install pytorch torchvision $CUDA -c pytorch" -conda install pytorch torchvision $CUDA -c pytorch +else -echo "Installing torch-geometric..." -echo "Using CUDA version: $CUDA_VERSION" -# get PyTorch version -TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") -echo "Using PyTorch version: $TORCH_VERSION" + source $SCRIPT_DIR/install_pyg.sh -n $CONDA_ENV_NAME -c $cuda_input -pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html -pip install torch-geometric +fi diff --git a/devtools/environment.yml b/devtools/environment.yml index 02962b2..6750a7f 100644 --- a/devtools/environment.yml +++ b/devtools/environment.yml @@ -1,146 +1,10 @@ name: GeoMol channels: - - rdkit - - anaconda - - conda-forge - defaults + - conda-forge dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=4.5=1_gnu - - argon2-cffi=20.1.0=py37h5e8e339_2 - - async_generator=1.10=py_0 - - attrs=21.2.0=pyhd8ed1ab_0 - - backcall=0.2.0=pyh9f0ad1d_0 - - backports=1.0=py_2 - - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 - - blas=1.0=mkl - - bleach=3.3.0=pyh44b312d_0 - - bzip2=1.0.8=h7b6447c_0 - - ca-certificates=2021.5.30=ha878542_0 - - cairo=1.16.0=hf32fb01_1 - - certifi=2021.5.30=py37h89c1867_0 - - cffi=1.14.5=py37hc58025e_0 - - cycler=0.10.0=py37_0 - - dbus=1.13.18=hb2f20db_0 - - decorator=4.4.2=pyhd3eb1b0_0 - - defusedxml=0.7.1=pyhd8ed1ab_0 - - entrypoints=0.3=pyhd8ed1ab_1003 - - expat=2.4.1=h2531618_2 - - fontconfig=2.13.1=h6c09931_0 - - freetype=2.10.4=h5ab3b9f_0 - - glib=2.68.2=h36276a3_0 - - gst-plugins-base=1.14.0=h8213a91_2 - - gstreamer=1.14.0=h28cd5cc_2 - - icu=58.2=he6710b0_3 - - importlib-metadata=4.6.0=py37h89c1867_0 - - intel-openmp=2021.2.0=h06a4308_610 - - ipykernel=5.5.5=py37h085eea5_0 - - ipython=7.25.0=py37h085eea5_1 - - ipython_genutils=0.2.0=py_1 - - ipywidgets=7.6.3=pyhd3deb0d_0 - - jedi=0.18.0=py37h89c1867_2 - - jinja2=3.0.1=pyhd8ed1ab_0 - - jpeg=9b=h024ee3a_2 - - jsonschema=3.2.0=pyhd8ed1ab_3 - - jupyter=1.0.0=py37h89c1867_6 - - jupyter_client=6.1.12=pyhd8ed1ab_0 - - jupyter_console=6.4.0=pyhd8ed1ab_0 - - jupyter_core=4.7.1=py37h89c1867_0 - - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 - - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1 - - kiwisolver=1.3.1=py37h2531618_0 - - lcms2=2.12=h3be6417_0 - - ld_impl_linux-64=2.35.1=h7274673_9 - - libboost=1.73.0=h3ff78a5_11 - - libffi=3.3=he6710b0_2 - - libgcc-ng=9.3.0=h5101ec6_17 - - libgfortran-ng=7.5.0=ha8ba4b0_17 - - libgfortran4=7.5.0=ha8ba4b0_17 - - libgomp=9.3.0=h5101ec6_17 - - libpng=1.6.37=hbc83047_0 - - libsodium=1.0.18=h36c2ea0_1 - - libstdcxx-ng=9.3.0=hd4cf53a_17 - - libtiff=4.2.0=h85742a9_0 - - libuuid=1.0.3=h1bed415_2 - - libwebp-base=1.2.0=h27cfd23_0 - - libxcb=1.14=h7b6447c_0 - - libxml2=2.9.12=h03d6c58_0 - - lz4-c=1.9.3=h2531618_0 - - markupsafe=2.0.1=py37h5e8e339_0 - - matplotlib=3.3.4=py37h06a4308_0 - - matplotlib-base=3.3.4=py37h62a2d02_0 - - matplotlib-inline=0.1.2=pyhd8ed1ab_2 - - mistune=0.8.4=py37h5e8e339_1004 - - mkl=2021.2.0=h06a4308_296 - - mkl-service=2.3.0=py37h27cfd23_1 - - mkl_fft=1.3.0=py37h42c9631_2 - - mkl_random=1.2.1=py37ha9443f7_2 - - nbclient=0.5.3=pyhd8ed1ab_0 - - nbconvert=6.1.0=py37h89c1867_0 - - nbformat=5.1.3=pyhd8ed1ab_0 - - ncurses=6.2=he6710b0_1 - - nest-asyncio=1.5.1=pyhd8ed1ab_0 - - networkx=2.5.1=pyhd3eb1b0_0 - - notebook=6.4.0=pyha770c72_0 - - numpy=1.20.2=py37h2d18471_0 - - numpy-base=1.20.2=py37hfae3a4d_0 - - olefile=0.46=py37_0 - - openssl=1.1.1k=h7f98852_0 - - packaging=20.9=pyh44b312d_0 - - pandas=1.2.5=py37h295c915_0 - - pandoc=2.14.0.3=h7f98852_0 - - pandocfilters=1.4.2=py_1 - - parso=0.8.2=pyhd8ed1ab_0 - - pcre=8.45=h295c915_0 - - pexpect=4.8.0=pyh9f0ad1d_2 - - pickleshare=0.7.5=py_1003 - - pillow=8.2.0=py37he98fc37_0 - - pip=21.1.3=py37h06a4308_0 - - pixman=0.40.0=h7b6447c_0 - - pot=0.7.0=py37h3340039_0 - - prometheus_client=0.11.0=pyhd8ed1ab_0 - - prompt-toolkit=3.0.19=pyha770c72_0 - - prompt_toolkit=3.0.19=hd8ed1ab_0 - - ptyprocess=0.7.0=pyhd3deb0d_0 - - py-boost=1.73.0=py37ha9443f7_11 - - py3dmol=0.9.1=pyhd8ed1ab_0 - - pycparser=2.20=pyh9f0ad1d_2 - - pygments=2.9.0=pyhd8ed1ab_0 - - pyparsing=2.4.7=pyhd3eb1b0_0 - - pyqt=5.9.2=py37h05f1152_2 - - pyrsistent=0.17.3=py37h5e8e339_2 - - python=3.7.10=h12debd9_4 - - python-dateutil=2.8.1=pyhd3eb1b0_0 - - python_abi=3.7=2_cp37m - - pytz=2021.1=pyhd3eb1b0_0 - - pyyaml=5.3.1=py37h7b6447c_1 - - pyzmq=22.1.0=py37h336d617_0 - - qt=5.9.7=h5867ecd_1 - - qtconsole=5.1.1=pyhd8ed1ab_0 - - qtpy=1.9.0=py_0 - - rdkit=2020.09.1.0=py37hd50e099_1 - - readline=8.1=h27cfd23_0 - - scipy=1.6.2=py37had2a1c9_1 - - seaborn=0.11.1=pyhd3eb1b0_0 - - send2trash=1.7.1=pyhd8ed1ab_0 - - setuptools=52.0.0=py37h06a4308_0 - - sip=4.19.8=py37hf484d3e_0 - - six=1.16.0=pyhd3eb1b0_0 - - sqlite=3.36.0=hc218d9a_0 - - terminado=0.10.1=py37h89c1867_0 - - testpath=0.5.0=pyhd8ed1ab_0 - - tk=8.6.10=hbc83047_0 - - tornado=6.1=py37h27cfd23_0 - - tqdm=4.61.1=pyhd3eb1b0_1 - - traitlets=5.0.5=py_0 - - typing_extensions=3.10.0.0=pyha770c72_0 - - wcwidth=0.2.5=pyh9f0ad1d_2 - - webencodings=0.5.1=py_1 - - wheel=0.36.2=pyhd3eb1b0_0 - - widgetsnbextension=3.5.1=py37h89c1867_4 - - xz=5.2.5=h7b6447c_0 - - yaml=0.2.5=h7b6447c_0 - - zeromq=4.3.4=h9c3ff4c_0 - - zipp=3.4.1=pyhd8ed1ab_0 - - zlib=1.2.11=h7b6447c_3 - - zstd=1.4.9=haebb681_0 + - rdkit >=2020.03.2 + - networkx + - pot + - yaml + - pyyaml diff --git a/devtools/environment_reproduce.yml b/devtools/environment_reproduce.yml new file mode 100644 index 0000000..02962b2 --- /dev/null +++ b/devtools/environment_reproduce.yml @@ -0,0 +1,146 @@ +name: GeoMol +channels: + - rdkit + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=4.5=1_gnu + - argon2-cffi=20.1.0=py37h5e8e339_2 + - async_generator=1.10=py_0 + - attrs=21.2.0=pyhd8ed1ab_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=py_2 + - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 + - blas=1.0=mkl + - bleach=3.3.0=pyh44b312d_0 + - bzip2=1.0.8=h7b6447c_0 + - ca-certificates=2021.5.30=ha878542_0 + - cairo=1.16.0=hf32fb01_1 + - certifi=2021.5.30=py37h89c1867_0 + - cffi=1.14.5=py37hc58025e_0 + - cycler=0.10.0=py37_0 + - dbus=1.13.18=hb2f20db_0 + - decorator=4.4.2=pyhd3eb1b0_0 + - defusedxml=0.7.1=pyhd8ed1ab_0 + - entrypoints=0.3=pyhd8ed1ab_1003 + - expat=2.4.1=h2531618_2 + - fontconfig=2.13.1=h6c09931_0 + - freetype=2.10.4=h5ab3b9f_0 + - glib=2.68.2=h36276a3_0 + - gst-plugins-base=1.14.0=h8213a91_2 + - gstreamer=1.14.0=h28cd5cc_2 + - icu=58.2=he6710b0_3 + - importlib-metadata=4.6.0=py37h89c1867_0 + - intel-openmp=2021.2.0=h06a4308_610 + - ipykernel=5.5.5=py37h085eea5_0 + - ipython=7.25.0=py37h085eea5_1 + - ipython_genutils=0.2.0=py_1 + - ipywidgets=7.6.3=pyhd3deb0d_0 + - jedi=0.18.0=py37h89c1867_2 + - jinja2=3.0.1=pyhd8ed1ab_0 + - jpeg=9b=h024ee3a_2 + - jsonschema=3.2.0=pyhd8ed1ab_3 + - jupyter=1.0.0=py37h89c1867_6 + - jupyter_client=6.1.12=pyhd8ed1ab_0 + - jupyter_console=6.4.0=pyhd8ed1ab_0 + - jupyter_core=4.7.1=py37h89c1867_0 + - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 + - jupyterlab_widgets=1.0.0=pyhd8ed1ab_1 + - kiwisolver=1.3.1=py37h2531618_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.35.1=h7274673_9 + - libboost=1.73.0=h3ff78a5_11 + - libffi=3.3=he6710b0_2 + - libgcc-ng=9.3.0=h5101ec6_17 + - libgfortran-ng=7.5.0=ha8ba4b0_17 + - libgfortran4=7.5.0=ha8ba4b0_17 + - libgomp=9.3.0=h5101ec6_17 + - libpng=1.6.37=hbc83047_0 + - libsodium=1.0.18=h36c2ea0_1 + - libstdcxx-ng=9.3.0=hd4cf53a_17 + - libtiff=4.2.0=h85742a9_0 + - libuuid=1.0.3=h1bed415_2 + - libwebp-base=1.2.0=h27cfd23_0 + - libxcb=1.14=h7b6447c_0 + - libxml2=2.9.12=h03d6c58_0 + - lz4-c=1.9.3=h2531618_0 + - markupsafe=2.0.1=py37h5e8e339_0 + - matplotlib=3.3.4=py37h06a4308_0 + - matplotlib-base=3.3.4=py37h62a2d02_0 + - matplotlib-inline=0.1.2=pyhd8ed1ab_2 + - mistune=0.8.4=py37h5e8e339_1004 + - mkl=2021.2.0=h06a4308_296 + - mkl-service=2.3.0=py37h27cfd23_1 + - mkl_fft=1.3.0=py37h42c9631_2 + - mkl_random=1.2.1=py37ha9443f7_2 + - nbclient=0.5.3=pyhd8ed1ab_0 + - nbconvert=6.1.0=py37h89c1867_0 + - nbformat=5.1.3=pyhd8ed1ab_0 + - ncurses=6.2=he6710b0_1 + - nest-asyncio=1.5.1=pyhd8ed1ab_0 + - networkx=2.5.1=pyhd3eb1b0_0 + - notebook=6.4.0=pyha770c72_0 + - numpy=1.20.2=py37h2d18471_0 + - numpy-base=1.20.2=py37hfae3a4d_0 + - olefile=0.46=py37_0 + - openssl=1.1.1k=h7f98852_0 + - packaging=20.9=pyh44b312d_0 + - pandas=1.2.5=py37h295c915_0 + - pandoc=2.14.0.3=h7f98852_0 + - pandocfilters=1.4.2=py_1 + - parso=0.8.2=pyhd8ed1ab_0 + - pcre=8.45=h295c915_0 + - pexpect=4.8.0=pyh9f0ad1d_2 + - pickleshare=0.7.5=py_1003 + - pillow=8.2.0=py37he98fc37_0 + - pip=21.1.3=py37h06a4308_0 + - pixman=0.40.0=h7b6447c_0 + - pot=0.7.0=py37h3340039_0 + - prometheus_client=0.11.0=pyhd8ed1ab_0 + - prompt-toolkit=3.0.19=pyha770c72_0 + - prompt_toolkit=3.0.19=hd8ed1ab_0 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - py-boost=1.73.0=py37ha9443f7_11 + - py3dmol=0.9.1=pyhd8ed1ab_0 + - pycparser=2.20=pyh9f0ad1d_2 + - pygments=2.9.0=pyhd8ed1ab_0 + - pyparsing=2.4.7=pyhd3eb1b0_0 + - pyqt=5.9.2=py37h05f1152_2 + - pyrsistent=0.17.3=py37h5e8e339_2 + - python=3.7.10=h12debd9_4 + - python-dateutil=2.8.1=pyhd3eb1b0_0 + - python_abi=3.7=2_cp37m + - pytz=2021.1=pyhd3eb1b0_0 + - pyyaml=5.3.1=py37h7b6447c_1 + - pyzmq=22.1.0=py37h336d617_0 + - qt=5.9.7=h5867ecd_1 + - qtconsole=5.1.1=pyhd8ed1ab_0 + - qtpy=1.9.0=py_0 + - rdkit=2020.09.1.0=py37hd50e099_1 + - readline=8.1=h27cfd23_0 + - scipy=1.6.2=py37had2a1c9_1 + - seaborn=0.11.1=pyhd3eb1b0_0 + - send2trash=1.7.1=pyhd8ed1ab_0 + - setuptools=52.0.0=py37h06a4308_0 + - sip=4.19.8=py37hf484d3e_0 + - six=1.16.0=pyhd3eb1b0_0 + - sqlite=3.36.0=hc218d9a_0 + - terminado=0.10.1=py37h89c1867_0 + - testpath=0.5.0=pyhd8ed1ab_0 + - tk=8.6.10=hbc83047_0 + - tornado=6.1=py37h27cfd23_0 + - tqdm=4.61.1=pyhd3eb1b0_1 + - traitlets=5.0.5=py_0 + - typing_extensions=3.10.0.0=pyha770c72_0 + - wcwidth=0.2.5=pyh9f0ad1d_2 + - webencodings=0.5.1=py_1 + - wheel=0.36.2=pyhd3eb1b0_0 + - widgetsnbextension=3.5.1=py37h89c1867_4 + - xz=5.2.5=h7b6447c_0 + - yaml=0.2.5=h7b6447c_0 + - zeromq=4.3.4=h9c3ff4c_0 + - zipp=3.4.1=pyhd8ed1ab_0 + - zlib=1.2.11=h7b6447c_3 + - zstd=1.4.9=haebb681_0 diff --git a/devtools/initialize_conda.sh b/devtools/initialize_conda.sh new file mode 100644 index 0000000..f44d76a --- /dev/null +++ b/devtools/initialize_conda.sh @@ -0,0 +1,19 @@ +echo "Initializing Conda..." + +if which mamba > /dev/null; then + conda_bin="mamba" + echo "As mamba is available, using mamba by default..." +else + conda_bin="conda" +fi + +conda_base_dir=$(dirname $(dirname $CONDA_EXE)) + +if [ "$conda_bin" = "mamba" ]; then + source "$conda_base_dir/etc/profile.d/conda.sh" + source "$conda_base_dir/etc/profile.d/mamba.sh" +else + source "$conda_base_dir/etc/profile.d/conda.sh" +fi + +export conda_bin diff --git a/devtools/install_pyg.sh b/devtools/install_pyg.sh new file mode 100644 index 0000000..f79653f --- /dev/null +++ b/devtools/install_pyg.sh @@ -0,0 +1,122 @@ +# A script to install Pytorch geometric on normal platform +# Author: Xiaorui Dong +# Inspired by this https://medium.com/@jgbrasier/installing-pytorch-geometric-on-mac-m1-with-accelerated-gpu-support-2e7118535c50 + +CONDA_ENV_NAME="GeoMol" +PYTHON_VERSION="3.12" +CUDA_VERSION="cpu" +SCRIPT_DIR=$(dirname $0) # Assume the other scripts are available in the same directory as this file + +# Function to display usage +usage() { + echo "Usage: $0 [-n ] [--name ] [-v ] [--python-version ] [- ] [--cuda-version ]" + exit 1 +} + +# Parse short options (-n and -v) +while getopts ":n:v:c:" opt; do + case ${opt} in + n ) + CONDA_ENV_NAME=$OPTARG + ;; + v ) + PYTHON_VERSION=$OPTARG + ;; + c ) + CUDA_VERSION=$OPTARG + ;; + \? ) + usage + ;; + esac +done + +# Remove the processed options from the parameters +shift $((OPTIND -1)) + +# Parse long options (--name and --version) +for arg in "$@"; do + case $arg in + --name=*) + CONDA_ENV_NAME="${arg#*=}" + shift # Remove --name from processing + ;; + --python_version=*) + PYTHON_VERSION="${arg#*=}" + shift # Remove --version from processing + ;; + --cuda_version=*) + CUDA_VERSION="${arg#*=}" + shift # Remove --version from processing + ;; + *) + usage + ;; + esac +done + +# parse cuda +# Using regex to capture the major and minor version numbers for detailed matching +if [[ "$(uname)" != 'Darwin' ]]; then + if [[ $CUDA_VERSION =~ ^([0-9]+)\.([0-9]+)(\.([0-9]+))?$ ]]; then + major_version="${BASH_REMATCH[1]}" + minor_version="${BASH_REMATCH[2]}" + cuda_version_formatted="${major_version}.${minor_version}" + + # Construct the CUDA and CUDA_VERSION variables based on input + CUDA="cudatoolkit=$cuda_version_formatted" + CUDA_VERSION="cu${major_version}${minor_version}" + elif [ "$cuda_input" == "cpu" ]; then + # For CPU-only selection + CUDA="cpuonly" + CUDA_VERSION="cpu" + else + echo "Invalid input. Please ensure you enter a valid CUDA version in the format xx.xx or 'cpu'." + exit 1 + fi +else + CUDA="cpuonly" + CUDA_VERSION="cpu" +fi +echo "You selected CUDA version: $CUDA_VERSION ($CUDA)" + +source $SCRIPT_DIR/initialize_conda.sh + +if conda env list | grep -qw $CONDA_ENV_NAME; then + $conda_bin activate $CONDA_ENV_NAME +else + $conda_bin create -n $CONDA_ENV_NAME python=$PYTHON_VERSION -y + $conda_bin activate $CONDA_ENV_NAME +fi + +# check Python version +PYTHON_VERSION=$(python --version) +echo "Using Python version: $PYTHON_VERSION" + +# install PyTorch +echo "Installing PyTorch with requested CUDA version $CUDA_VERSION..." +# echo "Running: conda install pytorch torchvision $CUDA -c pytorch -y" +# $conda_bin install pytorch torchvision $CUDA -c pytorch -y +echo "Running: pip install torch torchvision" +pip install torch torchvision --index-url https://download.pytorch.org/whl/$CUDA_VERSION + +# get PyTorch version +TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") +if [ -n $TORCH_VERSION ]; then + echo "Using PyTorch version: $TORCH_VERSION" +else + echo "Cannot find a matched PyTorch version with $CUDA_VERSION for Python $PYTHON_VERSION. Exit." + # echo "Removing the installed environment" + # source deactivate + # $conda_bin env remove -n $environmentName + exit 1 +fi + +# install torch_geometric +echo "Installing torch-geometric..." +echo "Using CUDA version: $CUDA_VERSION" +pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VERSION}.html +pip install torch-geometric + +# install other package +$conda_bin env update -f $SCRIPT_DIR/environment.yml -n $CONDA_ENV_NAME \ No newline at end of file diff --git a/devtools/install_pyg_macos_arm64.sh b/devtools/install_pyg_macos_arm64.sh new file mode 100644 index 0000000..b13c7b5 --- /dev/null +++ b/devtools/install_pyg_macos_arm64.sh @@ -0,0 +1,82 @@ +# A script to install Pytorch geometric on Mchip MacOS +# Author: Xiaorui Dong +# Inspired by this https://medium.com/@jgbrasier/installing-pytorch-geometric-on-mac-m1-with-accelerated-gpu-support-2e7118535c50 + +CONDA_ENV_NAME="GeoMol" +PYTHON_VERSION="3.12" +SCRIPT_DIR=$(dirname $0) # Assume the other scripts are available in the same directory as this file + +# Function to display usage +usage() { + echo "Usage: $0 [-n ] [--name ] [-v ] [--version ]" + exit 1 +} + +# Parse short options (-n and -v) +while getopts ":n:v:" opt; do + case ${opt} in + n ) + CONDA_ENV_NAME=$OPTARG + ;; + v ) + PYTHON_VERSION=$OPTARG + ;; + \? ) + usage + ;; + esac +done + +# Remove the processed options from the parameters +shift $((OPTIND -1)) + +# Parse long options (--name and --version) +for arg in "$@"; do + case $arg in + --name=*) + CONDA_ENV_NAME="${arg#*=}" + shift # Remove --name from processing + ;; + --version=*) + PYTHON_VERSION="${arg#*=}" + shift # Remove --version from processing + ;; + *) + # Handle unrecognized options + usage + ;; + esac +done + +source $SCRIPT_DIR/initialize_conda.sh + +if conda env list | grep -qw $CONDA_ENV_NAME; then + $conda_bin activate $CONDA_ENV_NAME +else + $conda_bin create -n $CONDA_ENV_NAME python=$PYTHON_VERSION -y + $conda_bin activate $CONDA_ENV_NAME +fi + +PYTHON_VERSION=$(python --version) +echo "Using Python version: $PYTHON_VERSION" + +# make sure compiler are correctly installed +$conda_bin install -y clang_osx-arm64 clangxx_osx-arm64 gfortran_osx-arm64 + +os_version=$(sw_vers -productVersion) + +# install PyTorch and pytorch_geometric with the correct compiler +echo "Installing PyTorch..." +MACOSX_DEPLOYMENT_TARGET=$os_version CC=clang CXX=clang++ python -m pip --no-cache-dir install torch torchvision +TORCH_VERSION=$(python -c "import torch; print(torch.__version__)") + +MACOSX_DEPLOYMENT_TARGET=$os_version CC=clang CXX=clang++ \ +python -m pip --no-cache-dir install torch_scatter torch_sparse torch_cluster torch_spline_conv \ +-f https://data.pyg.org/whl/torch-${TORCH_VERSION}+cpu.html + +MACOSX_DEPLOYMENT_TARGET=$os_version CC=clang CXX=clang++ \ +python -m pip --no-cache-dir install torch-geometric + +# install other packages +$conda_bin env update -f $SCRIPT_DIR/environment.yml -n $CONDA_ENV_NAME +$conda_bin install nomkl diff --git a/generate_confs.py b/generate_confs.py index c17c1ea..850652c 100644 --- a/generate_confs.py +++ b/generate_confs.py @@ -9,10 +9,9 @@ import torch import yaml -from model.model import GeoMol -from model.featurization import featurize_mol_from_smiles -from torch_geometric.data import Batch -from model.inference import construct_conformers +from geomol.model import GeoMol +from geomol.featurization import featurize_mol_from_smiles, from_data_list +from geomol.inference import construct_conformers parser = ArgumentParser() @@ -45,17 +44,17 @@ conformer_dict = {} for smi, n_confs in tqdm(test_data.values): - + # create data object (skip smiles rdkit can't handle) tg_data = featurize_mol_from_smiles(smi, dataset=dataset) if not tg_data: print(f'failed to featurize SMILES: {smi}') continue - + # generate model predictions - data = Batch.from_data_list([tg_data]) - model(data, inference=True, n_model_confs=n_confs*2) - + data = from_data_list([tg_data]) + model(data, inference=True, n_model_confs=n_confs * 2) + # set coords n_atoms = tg_data.x.size(0) model_coords = construct_conformers(data, model) @@ -73,9 +72,9 @@ except Exception as e: pass mols.append(mol) - + conformer_dict[smi] = mols - + # save to file if args.out: with open(f'{args.out}', 'wb') as f: diff --git a/model/GNN.py b/geomol/GNN.py similarity index 95% rename from model/GNN.py rename to geomol/GNN.py index d332c18..466931d 100644 --- a/model/GNN.py +++ b/geomol/GNN.py @@ -13,11 +13,12 @@ class MLP(nn.Module): Inputs: in_dim (int): number of features contained in the input layer. - out_dim (int): number of features input and output from each hidden layer, + out_dim (int): number of features input and output from each hidden layer, including the output layer. num_layers (int): number of layers in the network activation (torch function): activation function to be used during the hidden layers """ + def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), layer_norm=False, batch_norm=False): super(MLP, self).__init__() self.layers = nn.ModuleList() @@ -30,11 +31,13 @@ def __init__(self, in_dim, out_dim, num_layers, activation=torch.nn.ReLU(), laye self.layers.append(nn.Linear(in_dim, h_dim)) else: self.layers.append(nn.Linear(h_dim, h_dim)) - if layer_norm: self.layers.append(nn.LayerNorm(h_dim)) - if batch_norm: self.layers.append(nn.BatchNorm1d(h_dim)) + if layer_norm: + self.layers.append(nn.LayerNorm(h_dim)) + if batch_norm: + self.layers.append(nn.BatchNorm1d(h_dim)) self.layers.append(activation) self.layers.append(nn.Linear(h_dim, out_dim)) - + def forward(self, x): for i in range(len(self.layers)): x = self.layers[i](x) @@ -46,6 +49,7 @@ class MetaLayer(torch.nn.Module): `"Relational Inductive Biases, Deep Learning, and Graph Networks" `_ paper. """ + def __init__(self, edge_model=None, node_model=None): super(MetaLayer, self).__init__() self.edge_model = edge_model diff --git a/geomol/__init__.py b/geomol/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/cycle_utils.py b/geomol/cycle_utils.py similarity index 97% rename from model/cycle_utils.py rename to geomol/cycle_utils.py index 213cf27..005ed41 100644 --- a/model/cycle_utils.py +++ b/geomol/cycle_utils.py @@ -48,7 +48,7 @@ def align_coords_Kabsch(p_cycle_coords, q_cycle_coords, p_mask, q_mask=None): H = torch.matmul(p_cycle_coords_centered.permute(0, 1, 3, 2), q_cycle_coords_centered.unsqueeze(0)) u, s, v = torch.svd(H) d = torch.sign(torch.det(torch.matmul(v, u.permute(0, 1, 3, 2)))) - R_1 = torch.diag_embed(torch.ones([p_cycle_coords.size(0), q_cycle_coords.size(0), 3])) + R_1 = torch.diag_embed(torch.ones([p_cycle_coords.size(0), q_cycle_coords.size(0), 3], device=u.device)) R_1[:, :, 2, 2] = d R = torch.matmul(v, torch.matmul(R_1, u.permute(0, 1, 3, 2))) b = q_cycle_coords[:, q_mask].mean(dim=1) - torch.matmul(R, p_cycle_coords[:, :, p_mask].mean(dim=2).unsqueeze( diff --git a/geomol/featurization.py b/geomol/featurization.py new file mode 100644 index 0000000..4357e2f --- /dev/null +++ b/geomol/featurization.py @@ -0,0 +1,375 @@ +from rdkit import Chem +from rdkit.Chem.rdchem import ChiralType, HybridizationType +from rdkit.Chem.rdchem import BondType as BT + +import glob +import os.path as osp +import pickle +import random +from typing import Optional +from packaging import version + +import numpy as np +import torch +import torch.nn.functional as F +import torch_geometric as tg +from torch_geometric.data import Batch, Data, DataLoader, Dataset +from torch_scatter import scatter + +from geomol.utils import get_dihedral_pairs + +tg_version_ge_2 = version.parse(tg.__version__) > version.parse('2.0.0') + +bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} +chirality = {ChiralType.CHI_TETRAHEDRAL_CW: -1., + ChiralType.CHI_TETRAHEDRAL_CCW: 1., + ChiralType.CHI_UNSPECIFIED: 0, + ChiralType.CHI_OTHER: 0} +dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]') + +qm9_types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} +drugs_types = {'H': 0, 'Li': 1, 'B': 2, 'C': 3, 'N': 4, 'O': 5, 'F': 6, 'Na': 7, 'Mg': 8, 'Al': 9, 'Si': 10, + 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, + 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, + 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} +dataset_types = {'qm9': qm9_types, 'drugs': drugs_types} + + +def one_k_encoding(value, choices): + """ + Creates a one-hot encoding with an extra category for uncommon values. + :param value: The value for which the encoding should be one. + :param choices: A list of possible values. + :return: A one-hot encoding of the :code:`value` in a list of length :code:`len(choices) + 1`. + If :code:`value` is not in :code:`choices`, then the final element in the encoding is 1. + """ + encoding = [0] * (len(choices) + 1) + index = choices.index(value) if value in choices else -1 + encoding[index] = 1 + + return encoding + + +class geom_confs(Dataset): + + dataset = '' + + def __init__(self, + root, + split_path, + mode, + transform=None, + pre_transform=None, + max_confs=10): + super().__init__(root, transform, pre_transform) + + self.root = root + self.split_idx = 0 if mode == 'train' else 1 if mode == 'val' else 2 + self.split = np.load(split_path, allow_pickle=True)[self.split_idx] + self.bonds = bonds + + self.dihedral_pairs = {} # for memoization + all_files = sorted(glob.glob(osp.join(self.root, '*.pickle'))) + self.pickle_files = [f for i, f in enumerate(all_files) + if i in self.split] + self.max_confs = max_confs + self.types = dataset_types[self.dataset] + + def len(self): + # return len(self.pickle_files) # should we change this to an integer for random sampling? + return 10000 if self.split_idx == 0 else 1000 + + def get(self, idx): + data = None + while not data: + pickle_file = random.choice(self.pickle_files) + mol_dic = self.open_pickle(pickle_file) + data = self.featurize_mol(mol_dic) + + if idx in self.dihedral_pairs: + data.edge_index_dihedral_pairs = self.dihedral_pairs[idx] + else: + data.edge_index_dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data) + + return data + + def open_pickle(self, mol_path): + with open(mol_path, "rb") as f: + dic = pickle.load(f) + return dic + + def featurize_mol(self, mol_dic): + confs, name = mol_dic['conformers'], mol_dic["smiles"] + random.shuffle(confs) # shuffle confs + + # filter mols rdkit can't intrinsically handle + try: + canonical_smi = Chem.MolToSmiles(Chem.MolFromSmiles(name)) + except Exception: + return None + + # skip conformers without dihedrals + if _check_mol(confs[0]['rd_mol'], smiles=name) is None: + return None + + n_atom = confs[0]['rd_mol'].GetNumAtoms() + pos = torch.zeros([self.max_confs, n_atom, 3]) + pos_mask = torch.zeros(self.max_confs, dtype=torch.int64) + k = 0 + for conf in confs: + mol = conf['rd_mol'] + + # skip mols with atoms with more than 4 neighbors for now + n_neighbors = [len(a.GetNeighbors()) for a in mol.GetAtoms()] + if np.max(n_neighbors) > 4: + continue + + # filter for conformers that may have reacted + try: + conf_canonical_smi = Chem.MolToSmiles(Chem.RemoveHs(mol)) + except Exception: + continue + + if conf_canonical_smi != canonical_smi: + continue + + pos[k] = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float) + pos_mask[k] = 1 + k += 1 + correct_mol = mol + if k == self.max_confs: + break + + # return None if no non-reactive conformers were found + if k == 0: + return None + + x, z, edge_index, edge_attr, neighbor_dict, chiral_tag \ + = _mol_to_features(correct_mol, self.dataset) + + data = Data(x=x, z=z, pos=[pos], + edge_index=edge_index, edge_attr=edge_attr, + neighbors=neighbor_dict, + chiral_tag=chiral_tag, + name=name, mol=correct_mol, + boltzmann_weight=conf['boltzmannweight'], + degeneracy=conf['degeneracy'], + pos_mask=pos_mask) + return data + + +class qm9_confs(geom_confs): + + dataset = 'qm9' + + +class drugs_confs(geom_confs): + + dataset = 'drugs' + + +def construct_loader(args, modes=('train', 'val')): + + if isinstance(modes, str): + modes = [modes] + + loaders = [] + for mode in modes: + if args.dataset == 'qm9': + dataset = qm9_confs(args.data_dir, args.split_path, mode, max_confs=args.n_true_confs) + elif args.dataset == 'drugs': + dataset = drugs_confs(args.data_dir, args.split_path, mode, max_confs=args.n_true_confs) + loader = DataLoader(dataset=dataset, + batch_size=args.batch_size, + shuffle=False if mode == 'test' else True, + num_workers=args.num_workers, + pin_memory=False) + loaders.append(loader) + + if len(loaders) == 1: + return loaders[0] + else: + return loaders + + +def smiles_to_mol(smiles: str, + check_mol: bool = True): + """ + Convert a SMILES string to a RDKit molecule. + """ + try: + mol = Chem.AddHs(Chem.MolFromSmiles(smiles)) + except Exception: + return None + if check_mol: + return _check_mol(mol, smiles=smiles) + return mol + + +def _check_mol(mol, + smiles=None): + """ + Check if a molecule is valid. + """ + # filter fragments + if smiles is not None: + if '.' in smiles: + return None + else: + frags = Chem.rdmolops.GetMolFrags(mol, + asMols=False) + if len(frags) > 1: + return None + + # filter out mols model can't make predictions for + if mol.GetNumAtoms() < 4: + return None + if mol.GetNumBonds() < 4: + # in Lucky' original implementation + # this criteria is included in geom_confs.featurize_mol + # but not included in featurize_mol_from_smiles + # add it here anyway + return None + if not mol.HasSubstructMatch(dihedral_pattern): + return None + return mol + + +def _mol_to_features(mol, + dataset: str = 'qm9'): + """ + Prepare necessary information for converting a RDKit mol object to a torch_geometry_data object. + """ + types = dataset_types[dataset] + + type_idx = [] + atomic_number = [] + atom_features = [] + chiral_tag = [] + neighbor_dict = {} + ring = mol.GetRingInfo() + + n_atom = mol.GetNumAtoms() + # Atomic features + for i, atom in enumerate(mol.GetAtoms()): + type_idx.append(types[atom.GetSymbol()]) + if len(atom.GetNeighbors()) > 1: + n_ids = [n.GetIdx() for n in atom.GetNeighbors()] + neighbor_dict[i] = torch.tensor(n_ids) + chiral_tag.append(chirality[atom.GetChiralTag()]) + atomic_number.append(atom.GetAtomicNum()) + atom_features.extend([atom.GetAtomicNum(), + 1 if atom.GetIsAromatic() else 0]) + atom_features.extend(one_k_encoding( + atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) + atom_features.extend(one_k_encoding(atom.GetHybridization(), [ + HybridizationType.SP, + HybridizationType.SP2, + HybridizationType.SP3, + HybridizationType.SP3D, + HybridizationType.SP3D2])) + atom_features.extend(one_k_encoding( + atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) + atom_features.extend(one_k_encoding( + atom.GetFormalCharge(), [-1, 0, 1])) + atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), + int(ring.IsAtomInRingOfSize(i, 4)), + int(ring.IsAtomInRingOfSize(i, 5)), + int(ring.IsAtomInRingOfSize(i, 6)), + int(ring.IsAtomInRingOfSize(i, 7)), + int(ring.IsAtomInRingOfSize(i, 8))]) + atom_features.extend(one_k_encoding( + int(ring.NumAtomRings(i)), [0, 1, 2, 3])) + + z = torch.tensor(atomic_number, dtype=torch.long) + chiral_tag = torch.tensor(chiral_tag, dtype=torch.float) + + # Edge features + row, col, edge_type, bond_features = [], [], [], [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + row += [start, end] + col += [end, start] + edge_type += 2 * [bonds[bond.GetBondType()]] + bt = tuple(sorted( + [bond.GetBeginAtom().GetAtomicNum(), + bond.GetEndAtom().GetAtomicNum()] + )), bond.GetBondTypeAsDouble() + bond_features += 2 * [int(bond.IsInRing()), + int(bond.GetIsConjugated()), + int(bond.GetIsAromatic())] + + edge_index = torch.tensor([row, col], dtype=torch.long) + edge_type = torch.tensor(edge_type, dtype=torch.long) + edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float) + + perm = (edge_index[0] * n_atom + edge_index[1]).argsort() + edge_index = edge_index[:, perm] + edge_type = edge_type[perm] + edge_attr = edge_attr[perm] + + row, col = edge_index + hs = (z == 1).to(torch.float) + num_hs = scatter(hs[row], col, dim_size=n_atom).tolist() + + x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types)) + x2 = torch.tensor(atom_features).view(n_atom, -1) + x = torch.cat([x1.to(torch.float), x2], dim=-1) + + return x, z, edge_index, edge_attr, neighbor_dict, chiral_tag + + +def featurize_mol(mol, + dataset: str = 'qm9', + smiles: Optional[str] = None, + name: str = ''): + """ + Featurize a molecule. + """ + mol = _check_mol(mol, smiles=smiles) + name = smiles if (smiles and not name) else name + + if mol: + x, _, edge_index, edge_attr, neighbor_dict, chiral_tag \ + = _mol_to_features(mol, dataset=dataset) + data = Data(x=x, + edge_index=edge_index, + edge_attr=edge_attr, + neighbors=neighbor_dict, + chiral_tag=chiral_tag, + name=name) + data.edge_index_dihedral_pairs \ + = get_dihedral_pairs(data.edge_index, + data=data) + return data + + +def featurize_mol_from_smiles(smiles: str, + dataset='qm9'): + """ + Featurize a molecule from a SMILES string. + """ + mol = smiles_to_mol(smiles, check_mol=True) + if mol: + return featurize_mol(mol, + dataset=dataset, + name=smiles) + + +def from_data_list(data_list: list): + """ + Creates a batch object from a list of data objects. This is useful for inference with an improvisational list of features from different molecules. + This function is a wrapper for the torch_geometric function Batch.from_data_list + with a special treatment for the neighbors attribute. If without the + treatment, neighbors will be collapsed into a single dict and only have keys in the + first elements, causing an error raised in "get_neighbor_ids". + + It has only been tested and applied for torch_geometric over version 2.0.0. + """ + if tg_version_ge_2: + batch_data = Batch.from_data_list(data_list, + exclude_keys=['neighbors']) + batch_data.neighbors = [d.neighbors for d in data_list] + else: + batch_data = Batch.from_data_list(data_list) + return batch_data diff --git a/model/inference.py b/geomol/inference.py similarity index 84% rename from model/inference.py rename to geomol/inference.py index e626f2e..93fb117 100644 --- a/model/inference.py +++ b/geomol/inference.py @@ -2,19 +2,19 @@ import numpy as np import networkx as nx import torch_geometric as tg -from model.utils import batch_dihedrals -from model.cycle_utils import * +from geomol.utils import batch_dihedrals +from geomol.cycle_utils import * -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +def construct_conformers(data, model, device=None): -def construct_conformers(data, model): + device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu') G = nx.to_undirected(tg.utils.to_networkx(data)) cycles = nx.cycle_basis(G) - new_pos = torch.zeros([data.batch.size(0), model.n_model_confs, 3]) - dihedral_pairs = model.dihedral_pairs.t().detach().numpy() + new_pos = torch.zeros([data.batch.size(0), model.n_model_confs, 3], device=device) + dihedral_pairs = model.dihedral_pairs.t().detach().cpu().numpy() Sx = [] Sy = [] @@ -36,7 +36,13 @@ def construct_conformers(data, model): if any(x_cycle_check) and any(y_cycle_check): # both in new cycle cycle_indices = get_current_cycle_indices(cycles, x_cycle_check, x_index) - cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i) # i instead of i+1 + cycle_avg_coords, cycle_avg_indices \ + = smooth_cycle_coords(model, + cycle_indices, + new_pos, + dihedral_pairs, + i, # i instead of i+1 + device) # new graph if x_index not in Sx: @@ -48,9 +54,11 @@ def construct_conformers(data, model): p_mask = [True if a in Sx else False for a in sorted(cycle_avg_indices)] q_mask = [True if a in sorted(cycle_avg_indices) else False for a in Sx] p_reorder = sorted(range(len(cycle_avg_indices)), key=lambda k: cycle_avg_indices[k]) - aligned_cycle_coords = align_coords_Kabsch(cycle_avg_coords[p_reorder].permute(1, 0, 2).unsqueeze(0), new_pos[Sx].permute(1, 0, 2), p_mask, q_mask) + aligned_cycle_coords = align_coords_Kabsch( + cycle_avg_coords[p_reorder].permute(1, 0, 2).unsqueeze(0), + new_pos[Sx].permute(1, 0, 2), p_mask, q_mask) aligned_cycle_coords = aligned_cycle_coords.squeeze(0).permute(1, 0, 2) - cycle_avg_indices_reordered = [cycle_avg_indices[l] for l in p_reorder] + cycle_avg_indices_reordered = [cycle_avg_indices[i] for i in p_reorder] # apply to all new coordinates? new_pos[cycle_avg_indices_reordered] = aligned_cycle_coords @@ -63,10 +71,10 @@ def construct_conformers(data, model): if any(y_cycle_check): cycle_indices = get_current_cycle_indices(cycles, y_cycle_check, y_index) cycle_added = True - in_cycle = len(cycle_indices)+1 + in_cycle = len(cycle_indices) + 1 # new graph - p_coords = torch.zeros([4, model.n_model_confs, 3]) + p_coords = torch.zeros([4, model.n_model_confs, 3], device=device) p_idx = model.neighbors[x_index] if x_index not in Sx: @@ -80,11 +88,11 @@ def construct_conformers(data, model): # update indices Sx.extend([x_index]) - Sx.extend(model.neighbors[x_index].detach().numpy()) + Sx.extend(model.neighbors[x_index].detach().cpu().numpy()) Sx = list(set(Sx)) Sy.extend([y_index]) - Sy.extend(model.neighbors[y_index].detach().numpy()) + Sy.extend(model.neighbors[y_index].detach().cpu().numpy()) # set px p_X = new_pos[x_index] @@ -94,12 +102,12 @@ def construct_conformers(data, model): # set Y if cycle_added: - cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i+1) - cycle_avg_coords = cycle_avg_coords - cycle_avg_coords[cycle_avg_indices == y_index] # move y to origin + cycle_avg_coords, cycle_avg_indices = smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, i + 1, device) + cycle_avg_coords = cycle_avg_coords - cycle_avg_coords[cycle_avg_indices == y_index] # move y to origin q_idx = model.neighbors[y_index] q_coords_mask = [True if a in q_idx else False for a in cycle_avg_indices] - q_coords = torch.zeros([4, model.n_model_confs, 3]) - q_reorder = np.argsort([np.where(a == q_idx)[0][0] for a in torch.tensor(cycle_avg_indices)[q_coords_mask]]) + q_coords = torch.zeros([4, model.n_model_confs, 3], device=device) + q_reorder = torch.argsort(torch.tensor([torch.where(a == q_idx)[0][0] for a in torch.tensor(cycle_avg_indices)[q_coords_mask]])) q_coords[0:sum(q_coords_mask)] = cycle_avg_coords[q_coords_mask][q_reorder] new_pos_Sy = cycle_avg_coords.clone() Sy = cycle_avg_indices @@ -121,12 +129,12 @@ def construct_conformers(data, model): # translate q new_p_Y = new_pos_Sx_2[Sx == y_index] - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.])).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=device)).unsqueeze(0).unsqueeze(0) new_pos_Sy_3 = torch.matmul(transform_matrix, new_pos_Sy_2.unsqueeze(-1)).squeeze(-1) + new_p_Y # rotate by gamma H_gamma = calculate_gamma(model.n_model_confs, model.dihedral_mask[i], model.c_ij[i], model.v_star[i], Sx, Sy, - p_idx, q_idx, x_index, y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y) + p_idx, q_idx, x_index, y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y, device) new_pos_Sx_3 = torch.matmul(H_gamma.unsqueeze(0), new_pos_Sx_2.unsqueeze(-1)).squeeze(-1) # update all coordinates @@ -141,16 +149,16 @@ def construct_conformers(data, model): return new_pos -def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_start_idx): +def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_start_idx, device): # find index of cycle starting position cycle_len = len(cycle_indices) # get dihedral pairs corresponding to current cycle - cycle_pairs = dihedral_pairs[cycle_start_idx:cycle_start_idx+cycle_len] + cycle_pairs = dihedral_pairs[cycle_start_idx:cycle_start_idx + cycle_len] # create indices for cycle - cycle_i = np.arange(cycle_start_idx, cycle_start_idx+cycle_len) + cycle_i = np.arange(cycle_start_idx, cycle_start_idx + cycle_len) # create ordered dihedral pairs and indices which each start at a different point in the cycle cycle_dihedral_pair_orders = np.stack([np.roll(cycle_pairs, -i, axis=0) for i in range(len(cycle_pairs))])[:-1] @@ -168,7 +176,7 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta x_indices, y_indices = pairs.transpose() - p_coords = torch.zeros([cycle_len, 4, model.n_model_confs, 3]) + p_coords = torch.zeros([cycle_len, 4, model.n_model_confs, 3], device=device) p_idx = [model.neighbors[x] for x in x_indices] if ii == 0: @@ -218,13 +226,13 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta # translate q new_p_Y = new_pos_Sx_2[i][Sx_cycle[i] == y_indices[i]].squeeze(-1) - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.])).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=device)).unsqueeze(0).unsqueeze(0) new_pos_Sy_3 = torch.matmul(transform_matrix, new_pos_Sy_2[i].unsqueeze(-1)).squeeze(-1) + new_p_Y # rotate by gamma H_gamma = calculate_gamma(model.n_model_confs, model.dihedral_mask[ids[i]], model.c_ij[ids[i]], model.v_star[ids[i]], Sx_cycle[i], Sy_cycle[i], p_idx[i], q_idx[i], pairs[i][0], - pairs[i][1], new_pos_Sx_2[i], new_pos_Sy_3, new_p_Y) + pairs[i][1], new_pos_Sx_2[i], new_pos_Sy_3, new_p_Y, device) new_pos_Sx_3 = torch.matmul(H_gamma.unsqueeze(0), new_pos_Sx_2[i].unsqueeze(-1)).squeeze(-1) # update all coordinates @@ -239,7 +247,7 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta if not np.all(ids == cycle_i_orders[-1]): Sy_cycle = [[] for i in range(cycle_len)] else: - cycle_mask = torch.ones([cycle_pos.size(0), cycle_pos.size(1)]) + cycle_mask = torch.ones([cycle_pos.size(0), cycle_pos.size(1)], device=device) for i in range(cycle_len): cycle_mask[i, y_indices[i]] = 0 y_neighbor_ids = model.neighbors[y_indices[i]] @@ -258,7 +266,11 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta p_cycle_coords_aligned = align_coords_Kabsch(p_cycle_coords, q_cycle_coords, cycle_rmsd_mask).permute(0, 2, 1, 3) # average aligned coords - cycle_avg_coords_ = torch.vstack([q_cycle_coords_aligned.unsqueeze(0), p_cycle_coords_aligned]) * cycle_mask[:, Sx_cycle[0]].unsqueeze(-1).unsqueeze(-1) + cycle_avg_coords_ \ + = torch.vstack([q_cycle_coords_aligned.unsqueeze(0), + p_cycle_coords_aligned]) \ + * cycle_mask[:, Sx_cycle[0]].unsqueeze(-1).unsqueeze(-1) + cycle_avg_coords = cycle_avg_coords_.sum(dim=0) / cycle_mask[:, Sx_cycle[0]].sum(dim=0).unsqueeze(-1).unsqueeze(-1) return cycle_avg_coords, Sx_cycle[0] @@ -266,10 +278,10 @@ def smooth_cycle_coords(model, cycle_indices, new_pos, dihedral_pairs, cycle_sta def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pairs, neighbors, model_p_coords, model_q_coords, dihedral_x_mask, dihedral_y_mask, x_map_to_neighbor_y, - y_map_to_neighbor_x, dihedral_mask, c_ij, v_star): + y_map_to_neighbor_x, dihedral_mask, c_ij, v_star, device): pos = torch.cat([torch.cat([p[0][i] for p in data.pos]).unsqueeze(1) for i in range(n_true_confs)], dim=1) - new_pos = torch.zeros([pos.size(0), n_model_confs, 3]).to(device) + new_pos = torch.zeros([pos.size(0), n_model_confs, 3], device=device) dihedral_pairs = dihedral_pairs.t().detach().cpu().numpy() Sx = [] @@ -284,7 +296,7 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai continue # new graph - p_coords = torch.zeros([4, n_model_confs, 3]).to(device) + p_coords = torch.zeros([4, n_model_confs, 3], device=device) p_idx = neighbors[x_index] if x_index not in Sx: @@ -298,11 +310,11 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai # update indices Sx.extend([x_index]) - Sx.extend(neighbors[x_index].detach().numpy()) + Sx.extend(neighbors[x_index].detach().cpu().numpy()) Sx = list(set(Sx)) Sy.extend([y_index]) - Sy.extend(neighbors[y_index].detach().numpy()) + Sy.extend(neighbors[y_index].detach().cpu().numpy()) # set px p_X = new_pos[x_index] @@ -327,12 +339,12 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai # translate q new_p_Y = new_pos_Sx_2[Sx == y_index] - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.])).unsqueeze(0).unsqueeze(0) + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=device)).unsqueeze(0).unsqueeze(0) new_pos_Sy_3 = torch.matmul(transform_matrix, new_pos_Sy_2.unsqueeze(-1)).squeeze(-1) + new_p_Y # rotate by gamma H_gamma = calculate_gamma(n_model_confs, dihedral_mask[i], c_ij[i], v_star[i], Sx, Sy, p_idx, q_idx, x_index, - y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y) + y_index, new_pos_Sx_2, new_pos_Sy_3, new_p_Y, device) new_pos_Sx_3 = torch.matmul(H_gamma.unsqueeze(0), new_pos_Sx_2.unsqueeze(-1)).squeeze(-1) # update all coordinates @@ -353,10 +365,10 @@ def construct_conformers_acyclic(data, n_true_confs, n_model_confs, dihedral_pai def calculate_gamma(n_model_confs, dihedral_mask, c_ij, v_star, Sx, Sy, p_idx, q_idx, x_index, y_index, - new_pos_Sx_2, new_pos_Sy_3, new_p_Y): + new_pos_Sx_2, new_pos_Sy_3, new_p_Y, device): # calculate current dihedrals - pT_prime = torch.zeros([3, n_model_confs, 3]).to(device) - qZ_translated = torch.zeros([3, n_model_confs, 3]).to(device) + pT_prime = torch.zeros([3, n_model_confs, 3], device=device) + qZ_translated = torch.zeros([3, n_model_confs, 3], device=device) pY_prime = new_p_Y.repeat(9, 1, 1) qX = torch.zeros_like(pY_prime) @@ -368,19 +380,24 @@ def calculate_gamma(n_model_confs, dihedral_mask, c_ij, v_star, Sx, Sy, p_idx, q qZ_translated[:len(q_ids_in_Sy)] = new_pos_Sy_3[q_ids_in_Sy] XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos = batch_dihedrals(pT_prime[pT_idx], qX, pY_prime, qZ_translated[qZ_idx]) - A_ij = build_A_matrix_inf(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos, n_model_confs) * dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + A_ij = build_A_matrix_inf( + XYTi_XYZj_curr_sin, + XYTi_XYZj_curr_cos, + n_model_confs, + device=device, + ) * dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # build A matrix A_curr = torch.sum(A_ij * c_ij.unsqueeze(-1), dim=0) determinants = torch.det(A_curr) + 1e-10 - A_curr_inv_ = A_curr.view(n_model_confs, 4)[:, [3, 1, 2, 0]] * torch.tensor([[1., -1., -1., 1.]]) + A_curr_inv_ = A_curr.view(n_model_confs, 4)[:, [3, 1, 2, 0]] * torch.tensor([[1., -1., -1., 1.]], device=device) A_curr_inv = (A_curr_inv_ / determinants.unsqueeze(-1)).view(n_model_confs, 2, 2) A_curr_inv_v_star = torch.matmul(A_curr_inv, v_star.unsqueeze(-1)).squeeze(-1) # get gamma matrix v_gamma = A_curr_inv_v_star / (A_curr_inv_v_star.norm(dim=-1, keepdim=True) + 1e-10) gamma_cos, gamma_sin = v_gamma.split(1, dim=-1) - H_gamma = build_gamma_rotation_inf(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1), n_model_confs) + H_gamma = build_gamma_rotation_inf(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1), n_model_confs, device) return H_gamma @@ -417,9 +434,9 @@ def rotation_matrix_inf_v2(neighbor_coords, neighbor_map): return H -def build_A_matrix_inf(curr_sin, curr_cos, n_model_confs): +def build_A_matrix_inf(curr_sin, curr_cos, n_model_confs, device): - A_ij = torch.FloatTensor([[[[0, 0], [0, 0]]]]).repeat(9, n_model_confs, 1, 1) + A_ij = torch.FloatTensor([[[[0, 0], [0, 0]]]]).repeat(9, n_model_confs, 1, 1).to(device) A_ij[:, :, 0, 0] = curr_cos A_ij[:, :, 0, 1] = curr_sin A_ij[:, :, 1, 0] = curr_sin @@ -428,8 +445,8 @@ def build_A_matrix_inf(curr_sin, curr_cos, n_model_confs): return A_ij -def build_gamma_rotation_inf(gamma_sin, gamma_cos, n_model_confs): - H_gamma = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]).repeat(n_model_confs, 1, 1) +def build_gamma_rotation_inf(gamma_sin, gamma_cos, n_model_confs, device): + H_gamma = torch.FloatTensor([[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]).repeat(n_model_confs, 1, 1).to(device) H_gamma[:, 1, 1] = gamma_cos H_gamma[:, 1, 2] = -gamma_sin H_gamma[:, 2, 1] = gamma_sin diff --git a/model/model.py b/geomol/model.py similarity index 87% rename from model/model.py rename to geomol/model.py index 56ba621..399be04 100644 --- a/model/model.py +++ b/geomol/model.py @@ -6,8 +6,8 @@ from torch_geometric.nn import global_add_pool from torch_scatter import scatter -from model.GNN import GNN, MLP -from model.utils import * +from geomol.GNN import GNN, MLP +from geomol.utils import * from itertools import permutations import numpy as np @@ -27,7 +27,6 @@ def __init__(self, hyperparams, num_node_features, num_edge_features): self.loss_type = hyperparams['loss_type'] self.teacher_force = hyperparams['teacher_force'] self.random_alpha = hyperparams['random_alpha'] - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.gnn = GNN(node_dim=num_node_features + self.random_vec_dim, edge_dim=num_edge_features + self.random_vec_dim, @@ -79,8 +78,8 @@ def forward(self, data, ignore_neighbors=False, inference=False, n_model_confs=N self.generate_model_prediction(data.x, data.edge_index, data.edge_attr, data.batch, data.chiral_tag) return - x, edge_index, edge_attr, pos_list, batch, pos_mask, chiral_tag = \ - data.x, data.edge_index, data.edge_attr, data.pos, data.batch, data.pos_mask, data.chiral_tag + x, edge_index, edge_attr, pos_list, batch, pos_mask, chiral_tag \ + = data.x, data.edge_index, data.edge_attr, data.pos, data.batch, data.pos_mask, data.chiral_tag # assign neighborhoods self.assign_neighborhoods(x, edge_index, edge_attr, batch, data) @@ -107,18 +106,18 @@ def forward(self, data, ignore_neighbors=False, inference=False, n_model_confs=N pos_mask_L2 = pos_mask.view(molecule_loss.size(2), self.n_true_confs).t() pos_mask_L1 = pos_mask_L2.unsqueeze(1).repeat(1, self.n_model_confs, 1) - molecule_loss = torch.where(pos_mask_L1 == 1, molecule_loss, torch.FloatTensor([9e99]).to(self.device)) + molecule_loss = torch.where(pos_mask_L1 == 1, molecule_loss, torch.FloatTensor([9e99], device=molecule_loss.device)) if self.loss_type == 'implicit_mle': if DEBUG_NEIGHBORHOOD_PAIRS or self.teacher_force: L1 = torch.where(pos_mask_L2 == 1, torch.min(molecule_loss, dim=0).values, - torch.FloatTensor([0]).to(self.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) + torch.FloatTensor([0], device=pos_mask_L2.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) else: L1 = torch.min(molecule_loss, dim=0).values.sum(dim=0) / self.n_model_confs L2 = torch.where(pos_mask_L2 == 1, torch.min(molecule_loss, dim=1).values, - torch.FloatTensor([0]).to(self.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) + torch.FloatTensor([0], device=pos_mask_L2.device)).sum(dim=0) / pos_mask_L2.sum(dim=0) # logging self.run_writer_mle(True if L1.mean() > L2.mean() else False, molecule_loss, pos_mask_L2) @@ -139,13 +138,13 @@ def forward(self, data, ignore_neighbors=False, inference=False, n_model_confs=N if self.teacher_force: cost_mat_i = cost_mat_detach[i, :n_true_confs_batch[i], :n_true_confs_batch[i]] ot_mat = ot.emd(a=H_1, b=H_1, M=np.max(np.abs(cost_mat_i)) + cost_mat_i, numItermax=10000) - ot_mat_attached = torch.tensor(ot_mat, device=self.device, requires_grad=False).float() + ot_mat_attached = torch.tensor(ot_mat, device=molecule_loss.device, requires_grad=False).float() ot_mat_list.append(ot_mat_attached) loss += torch.sum(ot_mat_attached * molecule_loss[:n_true_confs_batch[i], :n_true_confs_batch[i], i]) else: cost_mat_i = cost_mat_detach[i, :n_true_confs_batch[i]] ot_mat = ot.emd(a=H_1, b=H_2, M=np.max(np.abs(cost_mat_i)) + cost_mat_i, numItermax=10000) - ot_mat_attached = torch.tensor(ot_mat, device=self.device, requires_grad=False).float() + ot_mat_attached = torch.tensor(ot_mat, device=molecule_loss.device, requires_grad=False).float() ot_mat_list.append(ot_mat_attached) loss += torch.sum(ot_mat_attached * molecule_loss[:n_true_confs_batch[i], :, i]) @@ -165,13 +164,13 @@ def assign_neighborhoods(self, x, edge_index, edge_attr, batch, data): self.n_dihedral_pairs = len(self.dihedral_pairs.t()) # mask for neighbors - self.neighbor_mask = torch.zeros([self.n_neighborhoods, 4]).to(self.device) + self.neighbor_mask = torch.zeros([self.n_neighborhoods, 4], device=x.device) # maps node index to hidden index as given by self.neighbors self.x_to_h_map = torch.zeros(x.size(0)) # maps local neighborhood to batch molecule - self.neighborhood_to_mol_map = torch.zeros(self.n_neighborhoods, dtype=torch.int64).to(self.device) + self.neighborhood_to_mol_map = torch.zeros(self.n_neighborhoods, dtype=torch.int64, device=x.device) for i, (a, n) in enumerate(self.neighbors.items()): self.x_to_h_map[a] = i @@ -180,18 +179,18 @@ def assign_neighborhoods(self, x, edge_index, edge_attr, batch, data): self.neighborhood_to_mol_map[i] = batch[a] # maps which atom in (x,y) corresponds to the same atom in (y,x) for each dihedral pair - self.x_map_to_neighbor_y = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) - self.y_map_to_neighbor_x = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) + self.x_map_to_neighbor_y = torch.zeros([self.n_dihedral_pairs, 4], device=x.device) + self.y_map_to_neighbor_x = torch.zeros_like(self.x_map_to_neighbor_y) # neighbor mask but for dihedral pairs - self.dihedral_x_mask = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) - self.dihedral_y_mask = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) + self.dihedral_x_mask = torch.zeros_like(self.x_map_to_neighbor_y) + self.dihedral_y_mask = torch.zeros_like(self.dihedral_x_mask) # maps neighborhood pair to batch molecule - self.neighborhood_pairs_to_mol_map = torch.zeros(self.n_dihedral_pairs, dtype=torch.int64).to(self.device) + self.neighborhood_pairs_to_mol_map = torch.zeros(self.n_dihedral_pairs, dtype=torch.int64, device=x.device) # indicates which type of bond is formed by X-Y - self.xy_bond_type = torch.zeros([self.n_dihedral_pairs, 4]).to(self.device) + self.xy_bond_type = torch.zeros_like(self.x_map_to_neighbor_y) for i, (s, e) in enumerate(self.dihedral_pairs.t()): # this indicates which neighbor is the correct x <--> y map (see overleaf doc) @@ -216,8 +215,8 @@ def embed(self, x, edge_index, edge_attr, batch): # stochasticity rand_dist = torch.distributions.normal.Normal(loc=0, scale=self.random_vec_std) # rand_dist = torch.distributions.uniform.Uniform(torch.tensor([0.0]), torch.tensor([1.0])) - rand_x = rand_dist.sample([x.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze - rand_edge = rand_dist.sample([edge_attr.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) # added squeeze + rand_x = rand_dist.sample([x.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(x.device) # added squeeze + rand_edge = rand_dist.sample([edge_attr.size(0), self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(edge_attr.device) # added squeeze x = torch.cat([x.unsqueeze(1).repeat(1, self.n_model_confs, 1), rand_x], dim=-1) edge_attr = torch.cat([edge_attr.unsqueeze(1).repeat(1, self.n_model_confs, 1), rand_edge], dim=-1) @@ -234,9 +233,18 @@ def embed(self, x, edge_index, edge_attr, batch): x_transformer = x_transformer.permute(1, 0, 2, 3).reshape(n_max, -1, self.model_dim) x_transformer_mask = x_mask.unsqueeze(1).repeat(1, self.n_model_confs, 1).view(-1, n_max) - x_global = self.global_embed(x_transformer, src_key_padding_mask=~x_transformer_mask).view( - n_max, max(batch)+1, self.n_model_confs, -1).permute(1, 0, 2, 3) * \ - x_transformer_mask.view(max(batch)+1, n_max, self.n_model_confs, 1) + x_global \ + = self.global_embed(x_transformer, + src_key_padding_mask=~x_transformer_mask) \ + .view(n_max, + max(batch) + 1, + self.n_model_confs, + -1)\ + .permute(1, 0, 2, 3) \ + * x_transformer_mask.view(max(batch) + 1, + n_max, + self.n_model_confs, + 1) # global reps for torsions h_mol = self.h_mol_mlp(x_global.sum(dim=1)) @@ -245,14 +253,22 @@ def embed(self, x, edge_index, edge_attr, batch): x2 = x_global[x_mask, :] else: - h_mol = self.h_mol_mlp(global_add_pool(x2, batch)) + # global_add_pool changed in PR #4827 to use dim=-2 instead of 0 by default + # Use a more general version to support both new and old versions of PyTorch Geometric + size = int(batch.max().item() + 1) + h_mol = self.h_mol_mlp( + scatter(x2, + batch, + dim=0, + dim_size=size, + reduce='sum')) return x1, x2, h_mol def model_local_stats(self, x, chiral_tag): - n_h = torch.zeros([self.n_neighborhoods, 4, self.n_model_confs, self.model_dim]).to(self.device) - x_h = torch.zeros([self.n_neighborhoods, self.n_model_confs, self.model_dim]).to(self.device) + n_h = torch.zeros([self.n_neighborhoods, 4, self.n_model_confs, self.model_dim], device=x.device) + x_h = torch.zeros([self.n_neighborhoods, self.n_model_confs, self.model_dim], device=x.device) for i, (a, n) in enumerate(self.neighbors.items()): n_h[i, 0:len(n), :] = x[n] @@ -265,8 +281,11 @@ def model_local_stats(self, x, chiral_tag): h_ = h.permute(1, 0, 2, 3).reshape(4, self.n_neighborhoods * self.n_model_confs, self.model_dim * 2) # CHECK RESHAPE OP h_mask = self.neighbor_mask.bool().unsqueeze(1).repeat(1, self.n_model_confs, 1).view(self.n_neighborhoods * self.n_model_confs, 4) - h_new = self.encoder(h_, src_key_padding_mask=~h_mask).view(4, self.n_neighborhoods, self.n_model_confs, self.model_dim * 2).permute(1, 0, 2, 3) \ - * self.neighbor_mask.unsqueeze(-1).unsqueeze(-1) + h_new \ + = self.encoder(h_, src_key_padding_mask=~h_mask) \ + .view(4, self.n_neighborhoods, self.n_model_confs, self.model_dim * 2)\ + .permute(1, 0, 2, 3) \ + * self.neighbor_mask.unsqueeze(-1).unsqueeze(-1) unit_normals = self.coord_pred(h_new) * self.neighbor_mask.unsqueeze(-1).unsqueeze(-1) # tetrahedral chiral corrections @@ -299,7 +318,7 @@ def model_local_stats(self, x, chiral_tag): self.neighbor_mask) if self.teacher_force: - R = random_rotation_matrix([self.n_neighborhoods, 1, self.n_model_confs]).to(self.device) + R = random_rotation_matrix([self.n_neighborhoods, 1, self.n_model_confs]).to(self.true_local_coords.device) self.model_local_coords = torch.matmul(R, self.true_local_coords[:, 0].unsqueeze(-1)).squeeze(-1) return model_one_hop, model_two_hop, model_angles @@ -319,13 +338,13 @@ def ground_truth_local_stats(self, pos): """ n_neighborhoods = len(self.neighbors) - self.true_local_coords = torch.zeros(n_neighborhoods, 6, 4, self.n_true_confs, 3).to(self.device) + self.true_local_coords = torch.zeros(n_neighborhoods, 6, 4, self.n_true_confs, 3, device=pos.device) for i, (a, n) in enumerate(self.neighbors.items()): # permutations for symmetric hydrogens n_perms = n.unsqueeze(0).repeat(6, 1) - perms = torch.tensor(list(permutations(n[self.leaf_hydrogens[a]]))).to(self.device) + perms = torch.tensor(list(permutations(n[self.leaf_hydrogens[a]])), device=n_perms.device) if perms.size(1) != 0: n_perms[0:len(perms), self.leaf_hydrogens[a]] = perms @@ -351,7 +370,8 @@ def local_loss(self, true_one_hop, true_two_hop, true_angles, model_one_hop, mod # bending angles loss model_angles_perms = model_angles.unsqueeze(1).repeat(1, 6, 1) - angle_loss_perm = torch.sum(von_Mises_loss(true_angles, model_angles_perms) * true_angles.bool(), dim=-1) / (true_angles.bool().sum(dim=-1) + 1e-10) + angle_loss_perm = torch.sum(von_Mises_loss(true_angles, model_angles_perms) * true_angles.bool(), + dim=-1) / (true_angles.bool().sum(dim=-1) + 1e-10) angle_loss = scatter(angle_loss_perm.max(dim=-1).values, self.neighborhood_to_mol_map, reduce="mean") return one_hop_loss, two_hop_loss, angle_loss @@ -367,13 +387,13 @@ def model_pair_stats(self, x, batch, h_mol): :return: tuple of true stats (dihedral and three-hop), each with size (n_dihedral_pairs, 9, n_true_confs) """ - dihedral_x_neighbors = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, 3]).to(self.device) - dihedral_x_node_reps = torch.zeros([self.n_dihedral_pairs, self.n_model_confs, self.model_dim]).to(self.device) - dihedral_x_neighbor_reps = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, self.model_dim]).to(self.device) + dihedral_x_neighbors = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, 3], device=x.device) + dihedral_x_node_reps = torch.zeros([self.n_dihedral_pairs, self.n_model_confs, self.model_dim], device=x.device) + dihedral_x_neighbor_reps = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, self.model_dim], device=x.device) - dihedral_y_neighbors = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, 3]).to(self.device) - dihedral_y_node_reps = torch.zeros([self.n_dihedral_pairs, self.n_model_confs, self.model_dim]).to(self.device) - dihedral_y_neighbor_reps = torch.zeros([self.n_dihedral_pairs, 4, self.n_model_confs, self.model_dim]).to(self.device) + dihedral_y_neighbors = torch.zeros_like(dihedral_x_neighbors) + dihedral_y_node_reps = torch.zeros_like(dihedral_x_node_reps) + dihedral_y_neighbor_reps = torch.zeros_like(dihedral_x_neighbor_reps) for i, (s, e) in enumerate(self.dihedral_pairs.t()): @@ -410,7 +430,8 @@ def model_pair_stats(self, x, batch, h_mol): q_Z_translated_combos = q_Z_translated[:, qZ_idx, :] p_Y_alpha_combos = p_Y_alpha.unsqueeze(1).repeat(1, 9, 1, 1) - model_dihedrals_sin, model_dihedrals_cos = batch_dihedrals(p_T_alpha_combos, torch.zeros_like(p_Y_alpha_combos), p_Y_alpha_combos, q_Z_translated_combos) + model_dihedrals_sin, model_dihedrals_cos = batch_dihedrals( + p_T_alpha_combos, torch.zeros_like(p_Y_alpha_combos), p_Y_alpha_combos, q_Z_translated_combos) model_dihedrals_sin = model_dihedrals_sin * self.dihedral_mask.unsqueeze(-1) model_dihedrals_cos = model_dihedrals_cos * self.dihedral_mask.unsqueeze(-1) model_dihedrals = torch.stack([model_dihedrals_sin, model_dihedrals_cos], dim=0) @@ -440,7 +461,7 @@ def ground_truth_pair_stats(self, pos): """ n_dihedral_pairs = len(self.dihedral_pairs.t()) - true_dihedral_coords = torch.zeros([n_dihedral_pairs, 4, 4, 6, self.n_true_confs, 3]).to(self.device) + true_dihedral_coords = torch.zeros([n_dihedral_pairs, 4, 4, 6, self.n_true_confs, 3], device=pos.device) for i, (s, e) in enumerate(self.dihedral_pairs.t()): # construct true coordinates (order is x_n, x, y, y_n) @@ -448,8 +469,8 @@ def ground_truth_pair_stats(self, pos): y_neighbor_map_perms = self.neighbors[e.item()].unsqueeze(1).repeat(1, 6) # permutations for symmetric hydrogens - x_perms = torch.tensor(list(permutations(self.neighbors[s.item()][self.leaf_hydrogens[s.item()]]))).t().to(self.device) - y_perms = torch.tensor(list(permutations(self.neighbors[e.item()][self.leaf_hydrogens[e.item()]]))).t().to(self.device) + x_perms = torch.tensor(list(permutations(self.neighbors[s.item()][self.leaf_hydrogens[s.item()]]))).t().to(pos.device) + y_perms = torch.tensor(list(permutations(self.neighbors[e.item()][self.leaf_hydrogens[e.item()]]))).t().to(pos.device) if x_perms.size(0) != 0: x_neighbor_map_perms[self.leaf_hydrogens[s.item()], 0:x_perms.size(1)] = x_perms @@ -472,12 +493,14 @@ def ground_truth_pair_stats(self, pos): true_dihedral_yn_coords = true_dihedral_coords[:, 3][~self.y_map_to_neighbor_x.bool(), :].view(-1, 3, 6, self.n_true_confs, 3)[:, qZ_idx, :] # calculate true dihedrals - true_dihedrals_sin, true_dihedrals_cos = batch_dihedrals(true_dihedral_xn_coords, true_dihedral_x_coords, true_dihedral_y_coords, true_dihedral_yn_coords) + true_dihedrals_sin, true_dihedrals_cos = batch_dihedrals( + true_dihedral_xn_coords, true_dihedral_x_coords, true_dihedral_y_coords, true_dihedral_yn_coords) true_dihedrals_sin = true_dihedrals_sin * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) true_dihedrals_cos = true_dihedrals_cos * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) true_dihedrals = torch.stack([true_dihedrals_sin, true_dihedrals_cos], dim=0) # true_dihedrals = batch_vector_angles(true_dihedral_xn_coords, true_dihedral_x_coords, true_dihedral_y_coords, - # true_dihedral_yn_coords).view(-1, 9, 6, self.n_true_confs) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) + # true_dihedral_yn_coords).view(-1, 9, 6, self.n_true_confs) * + # self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) # calculate true three-hop distances true_three_hop = torch.linalg.norm(true_dihedral_xn_coords - true_dihedral_yn_coords, dim=-1) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1) @@ -506,7 +529,12 @@ def pair_loss(self, true_dihedrals, model_dihedrals, true_three_hop, model_three # dihedral loss model_dihedrals_perms = model_dihedrals.unsqueeze(-1).repeat(1, 1, 1, 6) - dihedral_loss_perms = torch.sum(von_Mises_loss(true_dihedrals[1], model_dihedrals_perms[1], true_dihedrals[0], model_dihedrals_perms[0]) * self.dihedral_mask.unsqueeze(-1), dim=-2) / (self.dihedral_mask.sum(dim=-1, keepdim=True) + 1e-10) + dihedral_loss_perms = torch.sum(von_Mises_loss(true_dihedrals[1], + model_dihedrals_perms[1], + true_dihedrals[0], + model_dihedrals_perms[0]) * self.dihedral_mask.unsqueeze(-1), + dim=-2) / (self.dihedral_mask.sum(dim=-1, + keepdim=True) + 1e-10) dihedral_loss = scatter(dihedral_loss_perms.max(dim=-1).values, self.neighborhood_pairs_to_mol_map, reduce="mean") # three-hop distance loss @@ -556,8 +584,9 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch p_Y_prime = p_H[self.x_map_to_neighbor_y.bool()] q_X_prime = q_H[self.y_map_to_neighbor_x.bool()] - transform_matrix = torch.diag(torch.tensor([-1., -1., 1.]).to(self.device)).unsqueeze(0).unsqueeze(0).unsqueeze(0) - q_Z_translated = torch.matmul(transform_matrix, q_Z_prime.unsqueeze(-1)).squeeze(-1) + p_Y_prime.unsqueeze(1) # broadcast over not coordinates + transform_matrix = torch.diag(torch.tensor([-1., -1., 1.], device=q_Z_prime.device)).unsqueeze(0).unsqueeze(0).unsqueeze(0) + q_Z_translated = torch.matmul(transform_matrix, + q_Z_prime.unsqueeze(-1)).squeeze(-1) + p_Y_prime.unsqueeze(1) # broadcast over not coordinates # calculate alpha dihedral_h_mol = h_mol[batch[self.dihedral_pairs[0]]] # (n_dihedral_pairs, n_model_confs. model_dim/2) @@ -565,12 +594,14 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch # more stochasticity! if self. random_alpha: rand_dist = torch.distributions.normal.Normal(loc=0, scale=self.random_vec_std) - rand_alpha = rand_dist.sample([self.n_dihedral_pairs, self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(self.device) - alpha = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) + \ - self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) + rand_alpha = rand_dist.sample([self.n_dihedral_pairs, self.n_model_confs, self.random_vec_dim]).squeeze(-1).to(dihedral_x_node_reps.device) + alpha \ + = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) \ + + self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol, rand_alpha], dim=-1)) else: - alpha = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol], dim=-1)) + \ - self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol], dim=-1)) + alpha \ + = self.alpha_mlp(torch.cat([dihedral_x_node_reps, dihedral_y_node_reps, dihedral_h_mol], dim=-1)) \ + + self.alpha_mlp(torch.cat([dihedral_y_node_reps, dihedral_x_node_reps, dihedral_h_mol], dim=-1)) alpha = alpha.view(self.n_dihedral_pairs, self.n_model_confs, 1) self.v_star = torch.cat([torch.cos(alpha), torch.sin(alpha)], dim=-1) @@ -588,14 +619,17 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch q_reps = dihedral_y_neighbor_reps[~self.y_map_to_neighbor_x.bool()].view(-1, 3, self.n_model_confs, self.model_dim) cx_reps = dihedral_x_node_reps.unsqueeze(1).repeat(1, 9, 1, 1) cy_reps = dihedral_y_node_reps.unsqueeze(1).repeat(1, 9, 1, 1) - self.c_ij = self.c_mlp(torch.cat([p_reps[:, pT_idx], cx_reps, q_reps[:, qZ_idx], cy_reps], dim=-1)) + \ - self.c_mlp(torch.cat([q_reps[:, qZ_idx], cy_reps, p_reps[:, pT_idx], cx_reps], dim=-1)) + self.c_ij \ + = self.c_mlp(torch.cat([p_reps[:, pT_idx], cx_reps, q_reps[:, qZ_idx], cy_reps], dim=-1)) \ + + self.c_mlp(torch.cat([q_reps[:, qZ_idx], cy_reps, p_reps[:, pT_idx], cx_reps], dim=-1)) # calculate gamma sin and cos - A_ij = self.build_A_matrix(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + A_ij = self.build_A_matrix(XYTi_XYZj_curr_sin, XYTi_XYZj_curr_cos).to(self.dihedral_mask.device) * self.dihedral_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) A_curr = torch.sum(A_ij * self.c_ij.unsqueeze(-1), dim=1) determinants = torch.det(A_curr) + 1e-10 - A_curr_inv_ = A_curr.view(self.n_dihedral_pairs, self.n_model_confs, 4)[:, :, [3, 1, 2, 0]] * torch.tensor([[[1., -1., -1., 1.]]]).to(self.device) + A_curr_inv_ = A_curr.view(self.n_dihedral_pairs, + self.n_model_confs, + 4)[:, :, [3, 1, 2, 0]] * torch.tensor([[[1., -1., -1., 1.]]], device=A_curr.device) A_curr_inv = (A_curr_inv_ / determinants.unsqueeze(-1)).view(self.n_dihedral_pairs, self.n_model_confs, 2, 2) A_curr_inv_v_star = torch.matmul(A_curr_inv, self.v_star.unsqueeze(-1)).squeeze(-1) @@ -603,7 +637,7 @@ def align_dihedral_neighbors(self, dihedral_node_reps, dihedral_neighbors, batch gamma_cos, gamma_sin = v_gamma.split(1, dim=-1) # rotate p_coords by gamma - H_gamma = self.build_alpha_rotation(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1)) + H_gamma = self.build_alpha_rotation(gamma_sin.squeeze(-1), gamma_cos.squeeze(-1)).to(p_T_prime.device) p_T_alpha = torch.matmul(H_gamma.unsqueeze(1), p_T_prime.unsqueeze(-1)).squeeze(-1) return q_Z_prime, p_T_alpha, p_Y_prime, q_Z_translated @@ -615,7 +649,7 @@ def build_alpha_rotation(self, alpha, alpha_cos=None): :param alpha: predicted values of torsion parameter alpha (n_dihedral_pairs, n_model_confs) :return: alpha rotation matrix (n_dihedral_pairs, n_model_confs, 3, 3) """ - H_alpha = torch.FloatTensor([[[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]]).repeat(self.n_dihedral_pairs, self.n_model_confs, 1, 1).to(self.device) + H_alpha = torch.FloatTensor([[[[1, 0, 0], [0, 0, 0], [0, 0, 0]]]]).repeat(self.n_dihedral_pairs, self.n_model_confs, 1, 1) if torch.is_tensor(alpha_cos): H_alpha[:, :, 1, 1] = alpha_cos @@ -632,7 +666,7 @@ def build_alpha_rotation(self, alpha, alpha_cos=None): def build_A_matrix(self, curr_sin, curr_cos): - A_ij = torch.FloatTensor([[[[[0, 0], [0, 0]]]]]).repeat(self.n_dihedral_pairs, 9, self.n_model_confs, 1, 1).to(self.device) + A_ij = torch.FloatTensor([[[[[0, 0], [0, 0]]]]]).repeat(self.n_dihedral_pairs, 9, self.n_model_confs, 1, 1) A_ij[:, :, :, 0, 0] = curr_cos A_ij[:, :, :, 0, 1] = curr_sin A_ij[:, :, :, 1, 0] = curr_sin diff --git a/model/parsing.py b/geomol/parsing.py similarity index 100% rename from model/parsing.py rename to geomol/parsing.py diff --git a/model/training.py b/geomol/training.py similarity index 97% rename from model/training.py rename to geomol/training.py index 0995b6b..a02393c 100644 --- a/model/training.py +++ b/geomol/training.py @@ -100,6 +100,7 @@ class NoamLR(_LRScheduler): total_epochs * steps_per_epoch). This is roughly based on the learning rate schedule from Attention is All You Need, section 5.3 (https://arxiv.org/abs/1706.03762). """ + def __init__(self, optimizer: Optimizer, warmup_epochs: List[Union[float, int]], @@ -119,8 +120,12 @@ def __init__(self, :param max_lr: The maximum learning rate (achieved after warmup_epochs). :param final_lr: The final learning rate (achieved after total_epochs). """ - assert len(optimizer.param_groups) == len(warmup_epochs) == len(total_epochs) == len(init_lr) == \ - len(max_lr) == len(final_lr) + assert len(optimizer.param_groups) \ + == len(warmup_epochs) \ + == len(total_epochs) \ + == len(init_lr) \ + == len(max_lr) \ + == len(final_lr) self.num_lrs = len(optimizer.param_groups) diff --git a/model/utils.py b/geomol/utils.py similarity index 96% rename from model/utils.py rename to geomol/utils.py index d27c941..8af762f 100644 --- a/model/utils.py +++ b/geomol/utils.py @@ -1,22 +1,25 @@ +from pathlib import Path + import torch import torch_geometric as tg from torch_geometric.utils import degree import networkx as nx -from model.cycle_utils import get_current_cycle_indices +from geomol.cycle_utils import get_current_cycle_indices + +model_path = Path(__file__).parents[1] / "trained_models" -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') angle_mask_ref = torch.LongTensor([[0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0], - [1, 1, 1, 1, 1, 1]]).to(device) + [1, 1, 1, 1, 1, 1]]) angle_combos = torch.LongTensor([[0, 1], [0, 2], [1, 2], [0, 3], [1, 3], - [2, 3]]).to(device) + [2, 3]]) def get_neighbor_ids(data): @@ -119,7 +122,7 @@ def get_dihedral_pairs(edge_index, data): keep.append(pair) - keep = [t.to(device) for t in keep] + keep = [t for t in keep] return torch.stack(keep).t() @@ -161,6 +164,12 @@ def batch_angles_from_coords(coords, mask): """ Given coordinates, compute all local neighborhood angles """ + device = coords.device + + global angle_mask_ref, angle_combos + angle_mask_ref = angle_mask_ref.to(device) + angle_combos = angle_combos.to(device) + if coords.dim() == 4: all_possible_combos = coords[:, angle_combos] v_a, v_b = all_possible_combos.split(1, dim=2) # does one of these need to be negative? @@ -200,7 +209,7 @@ def batch_dihedrals(p0, p1, p2, p3, angle=False): else: den = torch.linalg.norm(torch.cross(s1, s2, dim=-1), dim=-1) * torch.linalg.norm(torch.cross(s2, s3, dim=-1), dim=-1) + 1e-10 - return sin_d_/den, cos_d_/den + return sin_d_ / den, cos_d_ / den def batch_vector_angles(xn, x, y, yn): @@ -227,7 +236,7 @@ def von_Mises_loss(a, b, a_sin=None, b_sin=None): if torch.is_tensor(a_sin): out = a * b + a_sin * b_sin else: - out = a * b + torch.sqrt(1-a**2 + 1e-5) * torch.sqrt(1-b**2 + 1e-5) + out = a * b + torch.sqrt(1 - a**2 + 1e-5) * torch.sqrt(1 - b**2 + 1e-5) return out diff --git a/model/featurization.py b/model/featurization.py deleted file mode 100644 index 786d137..0000000 --- a/model/featurization.py +++ /dev/null @@ -1,353 +0,0 @@ -from rdkit import Chem -from rdkit.Chem.rdchem import HybridizationType -from rdkit.Chem.rdchem import BondType as BT -from rdkit.Chem.rdchem import ChiralType - -import os.path as osp -import numpy as np -import glob -import pickle -import random - -import torch -import torch.nn.functional as F -from torch_scatter import scatter -from torch_geometric.data import Dataset, Data, DataLoader -from model.utils import get_dihedral_pairs - -dihedral_pattern = Chem.MolFromSmarts('[*]~[*]~[*]~[*]') -chirality = {ChiralType.CHI_TETRAHEDRAL_CW: -1., - ChiralType.CHI_TETRAHEDRAL_CCW: 1., - ChiralType.CHI_UNSPECIFIED: 0, - ChiralType.CHI_OTHER: 0} - - -def one_k_encoding(value, choices): - """ - Creates a one-hot encoding with an extra category for uncommon values. - :param value: The value for which the encoding should be one. - :param choices: A list of possible values. - :return: A one-hot encoding of the :code:`value` in a list of length :code:`len(choices) + 1`. - If :code:`value` is not in :code:`choices`, then the final element in the encoding is 1. - """ - encoding = [0] * (len(choices) + 1) - index = choices.index(value) if value in choices else -1 - encoding[index] = 1 - - return encoding - - -class geom_confs(Dataset): - def __init__(self, root, split_path, mode, transform=None, pre_transform=None, max_confs=10): - super(geom_confs, self).__init__(root, transform, pre_transform) - - self.root = root - self.split_idx = 0 if mode == 'train' else 1 if mode == 'val' else 2 - self.split = np.load(split_path, allow_pickle=True)[self.split_idx] - self.bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} - - # try: - # with open(osp.join(self.root, 'all_data.pickle'), 'rb') as f: - # data_dict = pickle.load(f) - # smiles = [list(data_dict)[i] for i in self.split] - # self.pickle_files = [data_dict[smi] for smi in smiles] - # except FileNotFoundError: - self.dihedral_pairs = {} # for memoization - all_files = sorted(glob.glob(osp.join(self.root, '*.pickle'))) - self.pickle_files = [f for i, f in enumerate(all_files) if i in self.split] - self.max_confs = max_confs - - def len(self): - # return len(self.pickle_files) # should we change this to an integer for random sampling? - return 10000 if self.split_idx == 0 else 1000 - - def get(self, idx): - data = None - while not data: - pickle_file = random.choice(self.pickle_files) - mol_dic = self.open_pickle(pickle_file) - data = self.featurize_mol(mol_dic) - - if idx in self.dihedral_pairs: - data.edge_index_dihedral_pairs = self.dihedral_pairs[idx] - else: - data.edge_index_dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data) - - return data - - def open_pickle(self, mol_path): - with open(mol_path, "rb") as f: - dic = pickle.load(f) - return dic - - def featurize_mol(self, mol_dic): - confs = mol_dic['conformers'] - random.shuffle(confs) # shuffle confs - name = mol_dic["smiles"] - - # filter mols rdkit can't intrinsically handle - mol_ = Chem.MolFromSmiles(name) - if mol_: - canonical_smi = Chem.MolToSmiles(mol_) - else: - return None - - # skip conformers with fragments - if '.' in name: - return None - - # skip conformers without dihedrals - N = confs[0]['rd_mol'].GetNumAtoms() - if N < 4: - return None - if confs[0]['rd_mol'].GetNumBonds() < 4: - return None - if not confs[0]['rd_mol'].HasSubstructMatch(dihedral_pattern): - return None - - pos = torch.zeros([self.max_confs, N, 3]) - pos_mask = torch.zeros(self.max_confs, dtype=torch.int64) - k = 0 - for conf in confs: - mol = conf['rd_mol'] - - # skip mols with atoms with more than 4 neighbors for now - n_neighbors = [len(a.GetNeighbors()) for a in mol.GetAtoms()] - if np.max(n_neighbors) > 4: - continue - - # filter for conformers that may have reacted - try: - conf_canonical_smi = Chem.MolToSmiles(Chem.RemoveHs(mol)) - except Exception as e: - continue - - if conf_canonical_smi != canonical_smi: - continue - - pos[k] = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float) - pos_mask[k] = 1 - k += 1 - correct_mol = mol - if k == self.max_confs: - break - - # return None if no non-reactive conformers were found - if k == 0: - return None - - type_idx = [] - atomic_number = [] - atom_features = [] - chiral_tag = [] - neighbor_dict = {} - ring = correct_mol.GetRingInfo() - for i, atom in enumerate(correct_mol.GetAtoms()): - type_idx.append(self.types[atom.GetSymbol()]) - n_ids = [n.GetIdx() for n in atom.GetNeighbors()] - if len(n_ids) > 1: - neighbor_dict[i] = torch.tensor(n_ids) - chiral_tag.append(chirality[atom.GetChiralTag()]) - atomic_number.append(atom.GetAtomicNum()) - atom_features.extend([atom.GetAtomicNum(), - 1 if atom.GetIsAromatic() else 0]) - atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetHybridization(), [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2])) - atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-1, 0, 1])) - atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), - int(ring.IsAtomInRingOfSize(i, 4)), - int(ring.IsAtomInRingOfSize(i, 5)), - int(ring.IsAtomInRingOfSize(i, 6)), - int(ring.IsAtomInRingOfSize(i, 7)), - int(ring.IsAtomInRingOfSize(i, 8))]) - atom_features.extend(one_k_encoding(int(ring.NumAtomRings(i)), [0, 1, 2, 3])) - - z = torch.tensor(atomic_number, dtype=torch.long) - chiral_tag = torch.tensor(chiral_tag, dtype=torch.float) - - row, col, edge_type, bond_features = [], [], [], [] - for bond in correct_mol.GetBonds(): - start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() - row += [start, end] - col += [end, start] - edge_type += 2 * [self.bonds[bond.GetBondType()]] - bt = tuple(sorted([bond.GetBeginAtom().GetAtomicNum(), bond.GetEndAtom().GetAtomicNum()])), bond.GetBondTypeAsDouble() - bond_features += 2 * [int(bond.IsInRing()), - int(bond.GetIsConjugated()), - int(bond.GetIsAromatic())] - - edge_index = torch.tensor([row, col], dtype=torch.long) - edge_type = torch.tensor(edge_type, dtype=torch.long) - edge_attr = F.one_hot(edge_type, num_classes=len(self.bonds)).to(torch.float) - # bond_features = torch.tensor(bond_features, dtype=torch.float).view(len(bond_type), -1) - - perm = (edge_index[0] * N + edge_index[1]).argsort() - edge_index = edge_index[:, perm] - edge_type = edge_type[perm] - # edge_attr = torch.cat([edge_attr[perm], bond_features], dim=-1) - edge_attr = edge_attr[perm] - - row, col = edge_index - hs = (z == 1).to(torch.float) - num_hs = scatter(hs[row], col, dim_size=N).tolist() - - x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(self.types)) - x2 = torch.tensor(atom_features).view(N, -1) - x = torch.cat([x1.to(torch.float), x2], dim=-1) - - data = Data(x=x, z=z, pos=[pos], edge_index=edge_index, edge_attr=edge_attr, neighbors=neighbor_dict, - chiral_tag=chiral_tag, name=name, boltzmann_weight=conf['boltzmannweight'], - degeneracy=conf['degeneracy'], mol=correct_mol, pos_mask=pos_mask) - return data - - - -class qm9_confs(geom_confs): - def __init__(self, root, split_path, mode, transform=None, pre_transform=None, max_confs=10): - super(qm9_confs, self).__init__(root, split_path, mode, transform, pre_transform, max_confs) - self.types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} - - -class drugs_confs(geom_confs): - def __init__(self, root, split_path, mode, transform=None, pre_transform=None, max_confs=10): - super(drugs_confs, self).__init__(root, split_path, mode, transform, pre_transform, max_confs) - self.types = {'H': 0, 'Li': 1, 'B': 2, 'C': 3, 'N': 4, 'O': 5, 'F': 6, 'Na': 7, 'Mg': 8, 'Al': 9, 'Si': 10, - 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, - 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, - 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} - - -def construct_loader(args, modes=('train', 'val')): - - if isinstance(modes, str): - modes = [modes] - - loaders = [] - for mode in modes: - if args.dataset == 'qm9': - dataset = qm9_confs(args.data_dir, args.split_path, mode, max_confs=args.n_true_confs) - elif args.dataset == 'drugs': - dataset = drugs_confs(args.data_dir, args.split_path, mode, max_confs=args.n_true_confs) - loader = DataLoader(dataset=dataset, - batch_size=args.batch_size, - shuffle=False if mode == 'test' else True, - num_workers=args.num_workers, - pin_memory=False) - loaders.append(loader) - - if len(loaders) == 1: - return loaders[0] - else: - return loaders - - -bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} -qm9_types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4} -drugs_types = {'H': 0, 'Li': 1, 'B': 2, 'C': 3, 'N': 4, 'O': 5, 'F': 6, 'Na': 7, 'Mg': 8, 'Al': 9, 'Si': 10, - 'P': 11, 'S': 12, 'Cl': 13, 'K': 14, 'Ca': 15, 'V': 16, 'Cr': 17, 'Mn': 18, 'Cu': 19, 'Zn': 20, - 'Ga': 21, 'Ge': 22, 'As': 23, 'Se': 24, 'Br': 25, 'Ag': 26, 'In': 27, 'Sb': 28, 'I': 29, 'Gd': 30, - 'Pt': 31, 'Au': 32, 'Hg': 33, 'Bi': 34} - - -def featurize_mol_from_smiles(smiles, dataset='qm9'): - - if dataset == 'qm9': - types = qm9_types - elif dataset == 'drugs': - types = drugs_types - - # filter fragments - if '.' in smiles: - return None - - # filter mols rdkit can't intrinsically handle - mol = Chem.MolFromSmiles(smiles) - if mol: - mol = Chem.AddHs(mol) - else: - return None - N = mol.GetNumAtoms() - - # filter out mols model can't make predictions for - if not mol.HasSubstructMatch(dihedral_pattern): - return None - if N < 4: - return None - - type_idx = [] - atomic_number = [] - atom_features = [] - chiral_tag = [] - neighbor_dict = {} - ring = mol.GetRingInfo() - for i, atom in enumerate(mol.GetAtoms()): - type_idx.append(types[atom.GetSymbol()]) - n_ids = [n.GetIdx() for n in atom.GetNeighbors()] - if len(n_ids) > 1: - neighbor_dict[i] = torch.tensor(n_ids) - chiral_tag.append(chirality[atom.GetChiralTag()]) - atomic_number.append(atom.GetAtomicNum()) - atom_features.extend([atom.GetAtomicNum(), - 1 if atom.GetIsAromatic() else 0]) - atom_features.extend(one_k_encoding(atom.GetDegree(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetHybridization(), [ - Chem.rdchem.HybridizationType.SP, - Chem.rdchem.HybridizationType.SP2, - Chem.rdchem.HybridizationType.SP3, - Chem.rdchem.HybridizationType.SP3D, - Chem.rdchem.HybridizationType.SP3D2])) - atom_features.extend(one_k_encoding(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5, 6])) - atom_features.extend(one_k_encoding(atom.GetFormalCharge(), [-1, 0, 1])) - atom_features.extend([int(ring.IsAtomInRingOfSize(i, 3)), - int(ring.IsAtomInRingOfSize(i, 4)), - int(ring.IsAtomInRingOfSize(i, 5)), - int(ring.IsAtomInRingOfSize(i, 6)), - int(ring.IsAtomInRingOfSize(i, 7)), - int(ring.IsAtomInRingOfSize(i, 8))]) - atom_features.extend(one_k_encoding(int(ring.NumAtomRings(i)), [0, 1, 2, 3])) - - z = torch.tensor(atomic_number, dtype=torch.long) - chiral_tag = torch.tensor(chiral_tag, dtype=torch.float) - - row, col, edge_type, bond_features = [], [], [], [] - for bond in mol.GetBonds(): - start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() - row += [start, end] - col += [end, start] - edge_type += 2 * [bonds[bond.GetBondType()]] - bt = tuple( - sorted([bond.GetBeginAtom().GetAtomicNum(), bond.GetEndAtom().GetAtomicNum()])), bond.GetBondTypeAsDouble() - bond_features += 2 * [int(bond.IsInRing()), - int(bond.GetIsConjugated()), - int(bond.GetIsAromatic())] - - edge_index = torch.tensor([row, col], dtype=torch.long) - edge_type = torch.tensor(edge_type, dtype=torch.long) - edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float) - # bond_features = torch.tensor(bond_features, dtype=torch.float).view(len(bond_type), -1) - - perm = (edge_index[0] * N + edge_index[1]).argsort() - edge_index = edge_index[:, perm] - edge_type = edge_type[perm] - # edge_attr = torch.cat([edge_attr[perm], bond_features], dim=-1) - edge_attr = edge_attr[perm] - - row, col = edge_index - hs = (z == 1).to(torch.float) - num_hs = scatter(hs[row], col, dim_size=N).tolist() - - x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types)) - x2 = torch.tensor(atom_features).view(N, -1) - x = torch.cat([x1.to(torch.float), x2], dim=-1) - - data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, neighbors=neighbor_dict, chiral_tag=chiral_tag, - name=smiles) - data.edge_index_dihedral_pairs = get_dihedral_pairs(data.edge_index, data=data) - - return data diff --git a/scripts/compare_confs.py b/scripts/compare_confs.py index 9662830..9b3732c 100644 --- a/scripts/compare_confs.py +++ b/scripts/compare_confs.py @@ -23,7 +23,7 @@ def calc_performance_stats(true_confs, model_confs): - + threshold = np.arange(0, 2.5, .125) rmsd_list = [] for tc in true_confs: @@ -42,7 +42,7 @@ def calc_performance_stats(true_confs, model_confs): coverage_precision = np.sum(rmsd_array.min(axis=0, keepdims=True) < np.expand_dims(threshold, 1), axis=1) / len(model_confs) amr_precision = rmsd_array.min(axis=0).mean() - + return coverage_recall, amr_recall, coverage_precision, amr_precision @@ -63,45 +63,45 @@ def clean_confs(smi, confs): for smi, n_confs, corrected_smi in tqdm(test_data.values): if not Chem.MolFromSmiles(smi): continue - + try: model_confs = model_preds[corrected_smi] except KeyError: print(f'no model prediction available: {corrected_smi}') - coverage_recall.append(threshold_ranges*0) + coverage_recall.append(threshold_ranges * 0) amr_recall.append(np.nan) - coverage_precision.append(threshold_ranges*0) + coverage_precision.append(threshold_ranges * 0) amr_precision.append(np.nan) test_smiles.append(smi) continue - # failure if model can't generate confs + # failure if model can't generate confs if len(model_confs) == 0: print(f'model failed: {smi}') - coverage_recall.append(threshold_ranges*0) + coverage_recall.append(threshold_ranges * 0) amr_recall.append(np.nan) - coverage_precision.append(threshold_ranges*0) + coverage_precision.append(threshold_ranges * 0) amr_precision.append(np.nan) test_smiles.append(smi) continue - + try: true_confs = true_mols[smi] except KeyError: print(f'cannot find ground truth conformer file: {smi}') continue - + # remove reacted conformers true_confs = clean_confs(corrected_smi, true_confs) if len(true_confs) == 0: print(f'poor ground truth conformers: {corrected_smi}') continue - + stats = calc_performance_stats(true_confs, model_confs) if not stats: print(f'failure calculating stats: {smi, corrected_smi}') continue - + cr, mr, cp, mp = stats coverage_recall.append(cr) amr_recall.append(mr) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b5e0172 --- /dev/null +++ b/setup.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +from setuptools import setup, find_packages + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + +setup( + name="GeoMol", + version="1.0.0", + author="Lagnajit Pattanaik", + author_email="lagnajit@mit.com", + description="Machine learning tools for molecule conformer generation", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/PattanaikL/GeoMol", + packages=find_packages(), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Scientific/Engineering :: Chemistry" + ], + license="MIT License", + python_requires='>=3.7', +) diff --git a/train.py b/train.py index 36d87af..5aab533 100644 --- a/train.py +++ b/train.py @@ -6,11 +6,11 @@ import numpy as np import random -from model.model import GeoMol -from model.training import train, test, NoamLR +from geomol.model import GeoMol +from geomol.training import train, test, NoamLR from utils import create_logger, dict_to_str, plot_train_val_loss, save_yaml_file, get_optimizer_and_scheduler -from model.featurization import construct_loader -from model.parsing import parse_train_args, set_hyperparams +from geomol.featurization import construct_loader +from geomol.parsing import parse_train_args, set_hyperparams from torch.utils.tensorboard import SummaryWriter import resource diff --git a/utils.py b/utils.py index 7e5c894..c23c144 100644 --- a/utils.py +++ b/utils.py @@ -7,7 +7,7 @@ import yaml import torch -from model.training import build_lr_scheduler +from geomol.training import build_lr_scheduler sns.set_style('whitegrid', {'axes.edgecolor': '.2'}) @@ -19,8 +19,10 @@ sns.color_palette('husl') local_modules = ['gnn', 'encoder', 'coord_pred', 'd_mlp'] + class Standardizer: """Z-score standardization""" + def __init__(self, mean, std): self.mean = mean self.std = std @@ -173,12 +175,19 @@ def get_optimizer_and_scheduler(args, model, train_data_size): if args.scheduler == 'plateau': if args.separate_opts: - scheduling_fn = lambda opt: torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.7, - patience=5, min_lr=args.lr / 100) + def scheduling_fn(opt): + return torch.optim.lr_scheduler.ReduceLROnPlateau(opt, + mode='min', + factor=0.7, + patience=5, + min_lr=args.lr / 100) scheduler = MultipleScheduler(optimizer, scheduling_fn) else: - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.7, - patience=5, min_lr=args.lr/100) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=0.7, + patience=5, + min_lr=args.lr / 100) elif args.scheduler == 'noam': scheduler = build_lr_scheduler(optimizer=optimizer, args=args, train_data_size=train_data_size) else: