15 changes: 14 additions & 1 deletion src/train.py
@@ -84,6 +84,19 @@ def main(args, resume_preempt=False):
     device = torch.device('cuda:0')
     torch.cuda.set_device(device)

+    # Check whether bfloat16 is supported; fall back to float16 if it was requested but is unavailable
+    autocast_dtype = torch.bfloat16 if use_bfloat16 else torch.float32
+
+    bfloat16_supported = False
+    try:
+        bfloat16_supported = torch.cuda.is_bf16_supported()
+    except RuntimeError:
+        bfloat16_supported = False
+
+    if not bfloat16_supported and use_bfloat16:
+        logger.info('Device does not support bfloat16, falling back to float16')
+        autocast_dtype = torch.float16
+
     # -- DATA
     use_gaussian_blur = args['data']['use_gaussian_blur']
     use_horizontal_flip = args['data']['use_horizontal_flip']
@@ -313,7 +326,7 @@ def loss_fn(z, h):
             return loss
 
         # Step 1. Forward
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=use_bfloat16):
+        with torch.cuda.amp.autocast(dtype=autocast_dtype, enabled=use_bfloat16):
             h = forward_target()
             z = forward_context()
             loss = loss_fn(z, h)
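For reference, a minimal standalone sketch of the dtype-selection logic this patch introduces, assuming a CUDA build of PyTorch; the helper name `resolve_autocast_dtype` is illustrative and not part of the diff:

```python
import torch

def resolve_autocast_dtype(use_bfloat16: bool) -> torch.dtype:
    # Mirrors the patch: prefer bfloat16 only when it was requested.
    if not use_bfloat16:
        return torch.float32  # autocast is disabled in this case anyway
    try:
        bf16_supported = torch.cuda.is_bf16_supported()
    except RuntimeError:
        # e.g. no CUDA device is visible to this process
        bf16_supported = False
    return torch.bfloat16 if bf16_supported else torch.float16

# Usage, matching the forward pass in train.py:
# with torch.cuda.amp.autocast(dtype=resolve_autocast_dtype(use_bfloat16), enabled=use_bfloat16):
#     ...
```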