Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion datasets/copernicus/get_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@
# limitations under the License.

import cdsapi
import numpy as np
import os
import argparse
import warnings


def main(args):

warnings.warn(
"get_data() is deprecated and will be removed in the future. Use data_process/convert_wb2_to_makani_input.py instead.",
category=DeprecationWarning,
stacklevel=2, # so the warning points to the caller, not this line
)

# get base path
base_path = os.path.join(args.output_dir, "raw")
os.makedirs(base_path, exist_ok=True)
Expand Down
5 changes: 1 addition & 4 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ RUN pip install more_itertools zarr xarray pandas gcsfs boto3
RUN pip install moviepy imageio

# other python stuff
RUN pip install --upgrade wandb ruamel.yaml tqdm progressbar2 jsbeautifier
RUN pip install --upgrade wandb ruamel.yaml tqdm progressbar2

# numba
RUN pip install numba
Expand All @@ -66,9 +66,6 @@ ENV NUMBA_DISABLE_CUDA=1
# scoring tools
RUN pip install xskillscore properscoring

# benchy
RUN pip install git+https://github.com/romerojosh/benchy.git

# some useful scripts from mlperf
RUN pip install --ignore-installed "git+https://github.com/NVIDIA/mlperf-common.git"

Expand Down
1 change: 0 additions & 1 deletion makani/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@
params["n_future"] = args.multistep_count - 1 # note that n_future counts only the additional samples

# debug:
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["enable_grad_anomaly_detection"] = args.enable_grad_anomaly_detection

Expand Down
1 change: 0 additions & 1 deletion makani/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@
params["jit_mode"] = args.jit_mode
params["enable_odirect"] = args.enable_odirect
params["enable_s3"] = args.enable_s3
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["checkpointing_level"] = args.checkpointing_level
params["enable_synthetic_data"] = args.enable_synthetic_data
Expand Down
7 changes: 2 additions & 5 deletions makani/models/model_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import os
import shutil
import json
import jsbeautifier
import numpy as np
import torch
from makani.utils.YParams import ParamsBase
Expand Down Expand Up @@ -180,11 +179,9 @@ def save_model_package(params):
"""
# save out the current state of the parameters, make it human readable
config_path = os.path.join(params.experiment_dir, "config.json")
jsopts = jsbeautifier.default_options()
jsopts.indent_size = 2

with open(config_path, "w") as f:
msg = jsbeautifier.beautify(json.dumps(params.to_dict()), jsopts)
msg = json.dumps(params.to_dict(), indent=4, sort_keys=True)
f.write(msg)

if params.get("add_orography", False):
Expand All @@ -211,7 +208,7 @@ def save_model_package(params):
"entrypoint": {"name": f"{LocalPackage.THIS_MODULE}:load_time_loop"},
}
with open(os.path.join(params.experiment_dir, "metadata.json"), "w") as f:
msg = jsbeautifier.beautify(json.dumps(fcn_mip_data), jsopts)
msg = json.dumps(fcn_mip_data, indent=4, sort_keys=True)
f.write(msg)


Expand Down
9 changes: 4 additions & 5 deletions makani/models/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
import os
import importlib.util

# we need this here for the code to work
import importlib_metadata
from importlib.metadata import EntryPoint, entry_points
from importlib.metadata import entry_points

import logging

Expand Down Expand Up @@ -163,12 +161,13 @@ def get_model(params: ParamsBase, use_stochastic_interpolation: bool = False, mu

model_handle = _model_registry.get(params.nettype)
if model_handle is not None:
if isinstance(model_handle, (EntryPoint, importlib_metadata.EntryPoint)):
# EntryPoint-like (stdlib or backport importlib_metadata): call .load() to get the callable
if hasattr(model_handle, "load") and callable(model_handle.load):
model_handle = model_handle.load()

model_handle = partial(model_handle, inp_shape=inp_shape, out_shape=out_shape, inp_chans=inp_chans, out_chans=out_chans, **params.to_dict())
else:
raise KeyError(f"No model is registered under the name {name}")
raise KeyError(f"No model is registered under the name {params.nettype}")

# use the constraint wrapper
if hasattr(params, "constraints"):
Expand Down
1 change: 0 additions & 1 deletion makani/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@
params["n_future"] = args.multistep_count - 1 # note that n_future counts only the additional samples

# debug:
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["enable_grad_anomaly_detection"] = args.enable_grad_anomaly_detection

Expand Down
1 change: 0 additions & 1 deletion makani/train_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@
params["n_future"] = args.multistep_count - 1 # note that n_future counts only the additional samples

# debug:
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["enable_grad_anomaly_detection"] = args.enable_grad_anomaly_detection

Expand Down
1 change: 0 additions & 1 deletion makani/train_stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@
params["n_future"] = 0

# debug:
params["enable_benchy"] = args.enable_benchy
params["disable_ddp"] = args.disable_ddp
params["enable_grad_anomaly_detection"] = args.enable_grad_anomaly_detection

Expand Down
1 change: 0 additions & 1 deletion makani/utils/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def get_default_argument_parser(training=True):
parser.add_argument("--multistep_count", default=1, type=int, help="Number of autoregressive training steps. A value of 1 denotes conventional training")

# debug parameters
parser.add_argument("--enable_benchy", action="store_true")
if training:
parser.add_argument("--disable_ddp", action="store_true")
parser.add_argument("--enable_grad_anomaly_detection", action="store_true")
Expand Down
5 changes: 0 additions & 5 deletions makani/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,6 @@ def get_dataloader(params, files_pattern, device, mode="train"):
img_local_offset_y=dataloader.img_local_offset_y,
)

if params.enable_benchy and (mode == "train"):
from benchy.torch import BenchmarkGenericIteratorWrapper

dataloader = BenchmarkGenericIteratorWrapper(dataloader, params.batch_size)

# not needed for the no multifiles case
sampler = None

Expand Down
25 changes: 22 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ dependencies = [
"wandb>=0.13.7",
"numba",
"tqdm>=4.60.0",
"jsbeautifier",
"more-itertools",
"importlib-metadata",
"Pillow",
"ruamel.yaml",
]

[tool.setuptools.dynamic]
Expand All @@ -74,12 +74,31 @@ version = {attr = "makani.__version__"}

[project.optional-dependencies]
dev = [
"pytest>=6.0.0",
"black>=22.10.0",
"coverage>=6.5.0",
"nvidia_dali_cuda110>=1.16.0",
]

test = [
"pytest>=6.0.0",
"parameterized",
"properscoring",
"xarray",
"xskillscore",
]

data_process = [
"mpi4py",
"xarray",
"gcsfs",
"dask",
"progressbar2",
]

legacy_data_process = [
"cdsapi>=0.7.2"
]

vis = [
"matplotlib>=3.8.1",
"imageio>=2.28.1",
Expand Down
1 change: 0 additions & 1 deletion tests/distributed/distributed_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def get_default_parameters():
params.resuming = False
params.amp_mode = "none"
params.jit_mode = "none"
params.enable_benchy = False
params.disable_ddp = False
params.checkpointing_level = 0
params.enable_synthetic_data = False
Expand Down
1 change: 0 additions & 1 deletion tests/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def get_default_parameters():
params.resuming = False
params.amp_mode = "none"
params.jit_mode = "none"
params.enable_benchy = False
params.disable_ddp = False
params.checkpointing_level = 0
params.enable_synthetic_data = False
Expand Down