Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions ovi/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math

import torch
import torch.amp as amp
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -34,7 +34,7 @@ def sinusoidal_embedding_1d(dim, position):
return x


@amp.autocast('cuda', enabled=False)
@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000, freqs_scaling=1.0):
assert dim % 2 == 0
pos = torch.arange(max_seq_len)
Expand All @@ -44,7 +44,7 @@ def rope_params(max_seq_len, dim, theta=10000, freqs_scaling=1.0):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs

@amp.autocast('cuda', enabled=False)
@amp.autocast(enabled=False)
def rope_apply_1d(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2 ## b l h d
c_rope = freqs.shape[1] # number of complex dims to rotate
Expand All @@ -69,7 +69,7 @@ def rope_apply_1d(x, grid_sizes, freqs):
output.append(x_i)
return torch.stack(output).bfloat16()

@amp.autocast('cuda', enabled=False)
@amp.autocast(enabled=False)
def rope_apply_3d(x, grid_sizes, freqs):
n, c = x.size(2), x.size(3) // 2

Expand Down Expand Up @@ -99,7 +99,7 @@ def rope_apply_3d(x, grid_sizes, freqs):
output.append(x_i)
return torch.stack(output).bfloat16()

@amp.autocast('cuda', enabled=False)
@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs):
x_ndim = grid_sizes.shape[-1]
if x_ndim == 3:
Expand Down Expand Up @@ -447,23 +447,23 @@ def forward(
"""
assert e.dtype == torch.bfloat16
assert len(e.shape) == 4 and e.size(2) == 6 and e.shape[1] == x.shape[1], f"{e.shape}, {x.shape}"
with amp.autocast('cuda', dtype=torch.bfloat16):
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
e = self.modulation(e).chunk(6, dim=2)
assert e[0].dtype == torch.bfloat16

# self-attention
y = self.self_attn(
self.norm1(x).bfloat16() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
seq_lens, grid_sizes, freqs)
with amp.autocast('cuda', dtype=torch.bfloat16):
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
x = x + y * e[2].squeeze(2)

# cross-attention & ffn function
def cross_attn_ffn(x, context, context_lens, e):
x = x + self.cross_attn(self.norm3(x), context, context_lens)
y = self.ffn(
self.norm2(x).bfloat16() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
with amp.autocast('cuda', dtype=torch.bfloat16):
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
x = x + y * e[5].squeeze(2)
return x

Expand Down Expand Up @@ -495,7 +495,7 @@ def forward(self, x, e):
e(Tensor): Shape [B, L, C]
"""
assert e.dtype == torch.bfloat16
with amp.autocast('cuda', dtype=torch.bfloat16):
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
e = (self.modulation.bfloat16().unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) # 1 1 2 D, B L 1 D -> B L 2 D -> 2 * (B L 1 D)
x = (self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
return x
Expand Down Expand Up @@ -740,7 +740,7 @@ def prepare_transformer_block_kwargs(
# print(f"zeroing out first {_first_images_seq_len} from t: {t.shape}, {t}")
else:
t = t.unsqueeze(1).expand(t.size(0), seq_len)
with amp.autocast('cuda', dtype=torch.bfloat16):
with torch.amp.autocast('cuda', dtype=torch.bfloat16):
bt = t.size(0)
t = t.flatten()
e = self.time_embedding(
Expand Down
12 changes: 6 additions & 6 deletions ovi/modules/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

import torch
import torch.amp as amp
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
Expand Down Expand Up @@ -672,14 +672,14 @@ def encode(self, videos):
"""
videos: A list of videos each with shape [C, T, H, W].
"""
with amp.autocast('cuda', dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
for u in videos
]

def decode(self, zs):
with amp.autocast('cuda', dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
Expand All @@ -688,16 +688,16 @@ def decode(self, zs):

@torch.no_grad()
def wrapped_decode(self, z):
with amp.autocast('cuda', dtype=self.dtype):
with torch.amp.autocast('cuda', dtype=self.dtype):
return self.model.decode(z, self.scale).float().clamp_(-1, 1)

@torch.no_grad()
def wrapped_decode_stream(self, z):
with amp.autocast('cuda', dtype=self.dtype):
with torch.amp.autocast('cuda', dtype=self.dtype):
return self.model.decode_stream(z, self.scale).float().clamp_(-1, 1)

@torch.no_grad()
def wrapped_encode(self, video):
with amp.autocast('cuda', dtype=self.dtype):
with torch.amp.autocast('cuda', dtype=self.dtype):
return self.model.encode(video, self.scale).float()

10 changes: 5 additions & 5 deletions ovi/modules/vae2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

import torch
import torch.amp as amp
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
Expand Down Expand Up @@ -1025,7 +1025,7 @@ def encode(self, videos):
try:
if not isinstance(videos, list):
raise TypeError("videos should be a list")
with amp.autocast('cuda', dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
return [
self.model.encode(u.unsqueeze(0),
self.scale).float().squeeze(0)
Expand All @@ -1039,7 +1039,7 @@ def decode(self, zs):
try:
if not isinstance(zs, list):
raise TypeError("zs should be a list")
with amp.autocast('cuda', dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1,
Expand All @@ -1054,7 +1054,7 @@ def wrapped_decode(self, zs):
try:
if not isinstance(zs, torch.Tensor):
raise TypeError("zs should be a torch.Tensor")
with amp.autocast('cuda', dtype=self.dtype):
with amp.autocast(dtype=self.dtype):
return self.model.decode(zs, self.scale).float().clamp_(-1,
1)

Expand All @@ -1066,7 +1066,7 @@ def wrapped_encode(self, video):
try:
if not isinstance(video, torch.Tensor):
raise TypeError("video should be a torch.Tensor")
with amp.autocast('cuda', dtype=self.dtype):
with amp.autocast(dtype=self.dtype):

return self.model.encode(video, self.scale).float()

Expand Down
2 changes: 1 addition & 1 deletion ovi/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def save_video(

# Add audio if provided
if audio_numpy is not None:
with tempfile.NamedTemporaryFile(suffix=".wav", mode='wb', delete=False) as temp_audio_file:
with tempfile.NamedTemporaryFile(suffix=".wav") as temp_audio_file:
wavfile.write(
temp_audio_file.name,
sample_rate,
Expand Down