12 changes: 3 additions & 9 deletions makani/utils/inference/inferencer.py
@@ -19,7 +19,6 @@
 
 import numpy as np
 from tqdm import tqdm
-import pynvml
 import datetime as dt
 import h5py as h5
 
@@ -47,6 +46,7 @@
 
 # checkpoint helpers
 from makani.utils.checkpoint_helpers import get_latest_checkpoint_version
+from makani.utils.training.training_helpers import get_memory_usage
 
 class Inferencer(Driver):
     """
@@ -70,11 +70,6 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
         if self.log_to_wandb:
             self._init_wandb(self.params, job_type="inference")
 
-        # nvml stuff
-        if self.log_to_screen:
-            pynvml.nvmlInit()
-            self.nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
-
         # set amp_parameters
         if hasattr(self.params, "amp_mode") and (self.params.amp_mode != "none"):
             self.amp_enabled = True
@@ -754,9 +749,8 @@ def score_model(
         # log parameters
         if self.log_to_screen:
             # log memory usage so far
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
-            max_mem_gb = torch.cuda.max_memory_allocated(device=self.device) / (1024.0 * 1024.0 * 1024.0)
-            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb} GB ({max_mem_gb} GB for pytorch)")
+            all_mem_gb, max_mem_gb = get_memory_usage(self.device)
+            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb:.2f} GB ({max_mem_gb:.2f} GB for pytorch)")
             # announce training start
             self.logger.info("Starting Scoring...")
 
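The same substitution repeats in every trainer below: the per-object NVML handle plus two hand-rolled byte-to-GB conversions collapse into a single call to the new get_memory_usage helper (defined later in makani/utils/training/training_helpers.py). For reference, a minimal standalone sketch of the deleted pattern, reconstructed from the removed lines; the device index 0 is an illustrative assumption, not part of the PR:

# old pynvml-based pattern removed by this PR (hypothetical standalone sketch)
import pynvml
import torch

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # assumes GPU 0
used_gb = pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024.0 ** 3)  # device-wide usage, all processes
peak_gb = torch.cuda.max_memory_allocated(device=0) / (1024.0 ** 3)  # PyTorch allocator high watermark
print(f"used: {used_gb:.2f} GB ({peak_gb:.2f} GB for pytorch)")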
17 changes: 4 additions & 13 deletions makani/utils/training/autoencoder_trainer.py
@@ -21,9 +21,6 @@
 import numpy as np
 from tqdm import tqdm
 
-# gpu info
-import pynvml
-
 # torch
 import torch
 import torch.amp as amp
@@ -56,7 +53,7 @@
 from makani.utils.checkpoint_helpers import get_latest_checkpoint_version
 
 # weight normalizing helper
-from makani.utils.training.training_helpers import clip_grads
+from makani.utils.training.training_helpers import get_memory_usage, clip_grads
 
 class AutoencoderTrainer(Driver):
     """
@@ -79,11 +76,6 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
             tens = torch.ones(1, device=self.device)
             dist.all_reduce(tens, group=comm.get_group("data"))
 
-        # nvml stuff
-        if self.log_to_screen:
-            pynvml.nvmlInit()
-            self.nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
-
         # set amp_parameters
         if hasattr(self.params, "amp_mode") and (self.params.amp_mode != "none"):
             self.amp_enabled = True
@@ -327,9 +319,8 @@ def train(self, training_profiler=None, validation_profiler=None):
         # log parameters
         if self.log_to_screen:
             # log memory usage so far
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
-            max_mem_gb = torch.cuda.max_memory_allocated(device=self.device) / (1024.0 * 1024.0 * 1024.0)
-            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb} GB ({max_mem_gb} GB for pytorch)")
+            all_mem_gb, max_mem_gb = get_memory_usage(self.device)
+            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb:.2f} GB ({max_mem_gb:.2f} GB for pytorch)")
             # announce training start
             self.logger.info("Starting Training Loop...")
 
@@ -721,7 +712,7 @@ def get_pad(nchar):
             self.logger.info(f"Performance Parameters:")
             self.logger.info(print_prefix + "training steps: {}".format(train_logs["train_steps"]))
             self.logger.info(print_prefix + "validation steps: {}".format(valid_logs["base"]["validation steps"]))
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
+            all_mem_gb, _ = get_memory_usage(self.device)
             self.logger.info(print_prefix + f"memory footprint [GB]: {all_mem_gb:.2f}")
             for key in timing_logs.keys():
                 self.logger.info(print_prefix + key + ": {:.2f}".format(timing_logs[key]))
17 changes: 4 additions & 13 deletions makani/utils/training/deterministic_trainer.py
@@ -21,9 +21,6 @@
 import numpy as np
 from tqdm import tqdm
 
-# gpu info
-import pynvml
-
 # torch
 import torch
 from torch import amp
@@ -59,7 +56,7 @@
 from makani.utils.checkpoint_helpers import get_latest_checkpoint_version
 
 # weight normalizing helper
-from makani.utils.training.training_helpers import clip_grads
+from makani.utils.training.training_helpers import get_memory_usage, clip_grads
 
 class Trainer(Driver):
     """
@@ -86,11 +83,6 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
             dist.all_reduce(tens, group=comm.get_group("data"))
         self.timers["nccl init"] = timer.time
 
-        # nvml stuff
-        if self.log_to_screen:
-            pynvml.nvmlInit()
-            self.nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
-
         # set amp_parameters
         if hasattr(self.params, "amp_mode") and (self.params.amp_mode != "none"):
             self.amp_enabled = True
@@ -371,9 +363,8 @@ def train(self, training_profiler=None, validation_profiler=None):
         # log parameters
         if self.log_to_screen:
             # log memory usage so far
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
-            max_mem_gb = torch.cuda.max_memory_allocated(device=self.device) / (1024.0 * 1024.0 * 1024.0)
-            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb} GB ({max_mem_gb} GB for pytorch)")
+            all_mem_gb, max_mem_gb = get_memory_usage(self.device)
+            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb:.2f} GB ({max_mem_gb:.2f} GB for pytorch)")
             # announce training start
             self.logger.info("Starting Training Loop...")
 
@@ -727,7 +718,7 @@ def get_pad(nchar):
             self.logger.info(f"Performance Parameters:")
             self.logger.info(print_prefix + "training steps: {}".format(train_logs["train_steps"]))
             self.logger.info(print_prefix + "validation steps: {}".format(valid_logs["base"]["validation steps"]))
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
+            all_mem_gb, _ = get_memory_usage(self.device)
             self.logger.info(print_prefix + f"memory footprint [GB]: {all_mem_gb:.2f}")
             for key in timing_logs.keys():
                 self.logger.info(print_prefix + key + ": {:.2f}".format(timing_logs[key]))
20 changes: 4 additions & 16 deletions makani/utils/training/ensemble_trainer.py
@@ -21,9 +21,6 @@
 import numpy as np
 from tqdm import tqdm
 
-# gpu info
-import pynvml
-
 # torch
 import torch
 from torch import amp
@@ -35,9 +32,6 @@
 # timers
 from makani.utils.profiling import Timer
 
-# for the manipulation of state dict
-from collections import OrderedDict
-
 # makani depenedencies
 from makani.utils import LossHandler, MetricsHandler
 from makani.utils.driver import Driver
@@ -64,7 +58,7 @@
 from makani.utils.checkpoint_helpers import get_latest_checkpoint_version
 
 # weight normalizing helper
-from makani.utils.training.training_helpers import clip_grads
+from makani.utils.training.training_helpers import get_memory_usage, clip_grads
 
 class EnsembleTrainer(Trainer):
     """
@@ -91,11 +85,6 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
             dist.all_reduce(tens, group=comm.get_group("data"))
         self.timers["nccl init"] = timer.time
 
-        # nvml stuff
-        if self.log_to_screen:
-            pynvml.nvmlInit()
-            self.nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
-
         # set amp_parameters
         if hasattr(self.params, "amp_mode") and (self.params.amp_mode != "none"):
             self.amp_enabled = True
@@ -367,9 +356,8 @@ def train(self, training_profiler=None, validation_profiler=None):
         # log parameters
         if self.log_to_screen:
             # log memory usage so far
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
-            max_mem_gb = torch.cuda.max_memory_allocated(device=self.device) / (1024.0 * 1024.0 * 1024.0)
-            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb} GB ({max_mem_gb} GB for pytorch)")
+            all_mem_gb, max_mem_gb = get_memory_usage(self.device)
+            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb:.2f} GB ({max_mem_gb:.2f} GB for pytorch)")
             # announce training start
             self.logger.info("Starting Ensemble Training Loop...")
 
@@ -776,7 +764,7 @@ def get_pad(nchar):
             self.logger.info(f"Performance Parameters:")
             self.logger.info(print_prefix + "training steps: {}".format(train_logs["train_steps"]))
             self.logger.info(print_prefix + "validation steps: {}".format(valid_logs["base"]["validation steps"]))
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
+            all_mem_gb, _ = get_memory_usage(self.device)
             self.logger.info(print_prefix + f"memory footprint [GB]: {all_mem_gb:.2f}")
             for key in timing_logs.keys():
                 self.logger.info(print_prefix + key + ": {:.2f}".format(timing_logs[key]))
17 changes: 4 additions & 13 deletions makani/utils/training/stochastic_trainer.py
@@ -21,9 +21,6 @@
 import numpy as np
 from tqdm import tqdm
 
-# gpu info
-import pynvml
-
 # torch
 import torch
 import torch.optim as optim
@@ -60,7 +57,7 @@
 from makani.utils.checkpoint_helpers import get_latest_checkpoint_version
 
 # weight normalizing helper
-from makani.utils.training.training_helpers import normalize_weights, clip_grads
+from makani.utils.training.training_helpers import get_memory_usage, normalize_weights, clip_grads
 
 
 class StochasticTrainer(Driver):
@@ -87,11 +84,6 @@ def __init__(self, params: Optional[YParams] = None, world_rank: Optional[int] =
             tens = torch.ones(1, device=self.device)
             dist.all_reduce(tens, group=comm.get_group("data"))
 
-        # nvml stuff
-        if self.log_to_screen:
-            pynvml.nvmlInit()
-            self.nvml_handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
-
         # set amp_parameters
         if hasattr(self.params, "amp_mode") and (self.params.amp_mode != "none"):
             self.amp_enabled = True
@@ -346,9 +338,8 @@ def train(self):
         # log parameters
         if self.log_to_screen:
             # log memory usage so far
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
-            max_mem_gb = torch.cuda.max_memory_allocated(device=self.device) / (1024.0 * 1024.0 * 1024.0)
-            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb} GB ({max_mem_gb} GB for pytorch)")
+            all_mem_gb, max_mem_gb = get_memory_usage(self.device)
+            self.logger.info(f"Scaffolding memory high watermark: {all_mem_gb:.2f} GB ({max_mem_gb:.2f} GB for pytorch)")
             # announce training start
             self.logger.info("Starting Training Loop...")
 
@@ -712,7 +703,7 @@ def get_pad(nchar):
             self.logger.info(f"Performance Parameters:")
             self.logger.info(print_prefix + "training steps: {}".format(train_logs["train_steps"]))
             self.logger.info(print_prefix + "validation steps: {}".format(valid_logs["base"]["validation steps"]))
-            all_mem_gb = pynvml.nvmlDeviceGetMemoryInfo(self.nvml_handle).used / (1024.0 * 1024.0 * 1024.0)
+            all_mem_gb, _ = get_memory_usage(self.device)
             self.logger.info(print_prefix + f"memory footprint [GB]: {all_mem_gb:.2f}")
             for key in timing_logs.keys():
                 self.logger.info(print_prefix + key + ": {:.2f}".format(timing_logs[key]))
8 changes: 8 additions & 0 deletions makani/utils/training/training_helpers.py
@@ -19,6 +19,14 @@
 from makani.utils import comm
 
 
+def get_memory_usage(device):
+    free_mem, total_mem = torch.cuda.mem_get_info(device=device)
+    allocated_mem_gb = (total_mem - free_mem) / (1024.0 * 1024.0 * 1024.0)
+    torch_mem_gb = torch.cuda.max_memory_allocated(device=device) / (1024.0 * 1024.0 * 1024.0)
+
+    return allocated_mem_gb, torch_mem_gb
+
+
 def normalize_weights(model, eps=1e-5):
     for param in model.parameters():
         # numel = torch.tensor(param.numel(), dtype=torch.long, device=param.device)
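The new helper folds both readings into one call: torch.cuda.mem_get_info wraps cudaMemGetInfo, so total_mem - free_mem reproduces the device-wide usage that pynvml.nvmlDeviceGetMemoryInfo(...).used reported (including allocations outside PyTorch), while max_memory_allocated remains the PyTorch allocator's high watermark. A minimal usage sketch, assuming a visible CUDA device and the import path added by this PR:

import torch
from makani.utils.training.training_helpers import get_memory_usage

device = torch.device("cuda:0")
x = torch.randn(4096, 4096, device=device)  # stand-in workload to allocate some memory
all_mem_gb, torch_mem_gb = get_memory_usage(device)
print(f"Scaffolding memory high watermark: {all_mem_gb:.2f} GB ({torch_mem_gb:.2f} GB for pytorch)")

One practical difference, not stated in the PR: torch.cuda.mem_get_info initializes a CUDA context on first use, which NVML queries did not require; since the logging call sites above run well after tensors are placed on self.device, this should not matter in practice.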
1 change: 0 additions & 1 deletion pyproject.toml
@@ -58,7 +58,6 @@ dependencies = [
     "wandb>=0.13.7",
     "numba",
     "tqdm>=4.60.0",
-    "pynvml>=10.0.0",
     "jsbeautifier",
     "more-itertools",
     "importlib-metadata",