-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrange_optimization.py
More file actions
117 lines (100 loc) · 3.59 KB
/
range_optimization.py
File metadata and controls
117 lines (100 loc) · 3.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from enum import Enum
from typing import Callable, List
from abc import ABC, abstractmethod
import numpy as np
import torch
from ml_dtypes import float8_e4m3fn, float8_e5m2
from quantization import Statistics, QuantizationGranularity, dtype_boundaries
class ClippingStrategy(ABC):
    """Abstract base for strategies that optimize boundaries held in statistics.

    Every strategy carries a ``name`` and must implement
    :meth:`optimize_boundaries`.
    """

    def __init__(self, name: str) -> None:
        """Store the name given to this strategy."""
        self.name = name

    @abstractmethod
    def optimize_boundaries(
        self,
        values_type: str,
        weights: torch.Tensor,
        activations: torch.Tensor,
        stat: Statistics,
    ) -> Statistics:
        """Return the statistics after optimizing their boundaries.

        Args:
            values_type: Selects which part of ``stat`` to work on
                (``PercentileClipping`` uses it as an attribute name on
                ``stat``).
            weights: Weight tensor made available to the strategy.
            activations: Activation tensor made available to the strategy.
            stat: Statistics object to optimize.

        Returns:
            The optimized statistics. This base implementation has an empty
            body and therefore yields ``None`` if invoked directly.
        """
class KLDivergenceClipping(ClippingStrategy):
    """Placeholder for a KL-divergence clipping strategy (not implemented yet).

    NOTE(review): the intended algorithm is inferred from the class name only;
    the original body contained no implementation.
    """

    def __init__(
        self, name: str,
        weights_only: bool,
        granularity: QuantizationGranularity
    ) -> None:
        """Store the strategy configuration.

        Args:
            name: Name of the strategy.
            weights_only: Flag stored on the instance; presumably restricts
                the optimization to weights -- confirm once implemented.
            granularity: Quantization granularity stored on the instance.
        """
        # Let the base class own the ``name`` attribute.
        super().__init__(name)
        self.weights_only = weights_only
        self.granularity = granularity

    def optimize_boundaries(
        self,
        values_type: str,
        weights: torch.Tensor,
        activations: torch.Tensor,
        stat: Statistics
    ) -> Statistics:
        """Optimize the boundaries in ``stat`` (not implemented).

        Raises:
            NotImplementedError: Always. The previous stub silently returned
                ``None``, which violates the declared ``Statistics`` return
                type and would only surface as a confusing error downstream.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.optimize_boundaries is not implemented yet"
        )
class PercentileClipping(ClippingStrategy):
    """Replace the tracked min/max values with their ``p``-th percentile."""

    def __init__(self, name: str, p: float) -> None:
        """Store the strategy name and the percentile to apply.

        Args:
            name: Name of the strategy.
            p: Percentile handed to ``np.percentile`` (valid range 0-100).
        """
        super().__init__(name)
        self.p = p

    def optimize_boundaries(
        self,
        values_type: str,
        weights: torch.Tensor,
        activations: torch.Tensor,
        stat: Statistics
    ) -> Statistics:
        """Clip the boundaries stored under ``stat.<values_type>``.

        The selected statistics object is modified in place: its
        ``min_values`` and ``max_values`` are each replaced by the ``p``-th
        percentile computed over all of their entries. Empty statistics are
        left untouched.

        Args:
            values_type: Attribute name on ``stat`` selecting the statistics
                to clip.
            weights: Unused by this strategy.
            activations: Unused by this strategy.
            stat: Statistics container; mutated in place.

        Returns:
            The same ``stat`` object that was passed in.
        """
        chosen = getattr(stat, values_type)
        # The previous per-channel loop (which also misspelled ``min_values``
        # and therefore always raised AttributeError) overwrote the whole
        # arrays on every pass, so it was equivalent to applying the
        # percentile once whenever there is at least one entry.
        if len(chosen.min_values) > 0:
            # NOTE(review): the same percentile is used for both bounds; a
            # lower bound would typically use ``100 - p`` -- confirm intent.
            # NOTE(review): this collapses the values to a single scalar
            # rather than one value per channel -- confirm intent.
            chosen.min_values = np.percentile(chosen.min_values, self.p)
            chosen.max_values = np.percentile(chosen.max_values, self.p)
        setattr(stat, values_type, chosen)
        return stat
class LossOptimizationClipping(ClippingStrategy):
    """Placeholder for a loss-driven clipping strategy (not implemented yet).

    NOTE(review): the intended algorithm is inferred from the class name only;
    the original body contained no implementation.
    """

    def __init__(self, name: str, loss_func: Callable) -> None:
        """Store the strategy name and the loss function.

        Args:
            name: Name of the strategy.
            loss_func: Callable kept on the instance as ``loss_func``; its
                expected signature is not defined yet.
        """
        # The previous constructor body was ``pass``, which dropped both
        # arguments and left instances without the ``name`` attribute that
        # the other strategies provide.
        super().__init__(name)
        self.loss_func = loss_func

    def optimize_boundaries(
        self,
        values_type: str,
        weights: torch.Tensor,
        activations: torch.Tensor,
        stat: Statistics
    ) -> Statistics:
        """Optimize the boundaries in ``stat`` (not implemented).

        Raises:
            NotImplementedError: Always. The previous stub silently returned
                ``None``, which violates the declared ``Statistics`` return
                type.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.optimize_boundaries is not implemented yet"
        )
class ScaleOptimizationStrategy:
    """Base class for strategies that choose a quantization scale for a tensor."""

    def __init__(self, name: str, quantization_scheme: str) -> None:
        """Store the strategy name and the quantization scheme identifier.

        Args:
            name: Name of the strategy.
            quantization_scheme: Identifier of the quantization scheme.
        """
        self.name = name
        # This used to be a bare ``self.quantization_scheme`` expression,
        # which raised AttributeError on every construction instead of
        # storing the argument.
        self.quantization_scheme = quantization_scheme

    def optimize_scale(
        self,
        tensor: torch.Tensor,
    ) -> List[np.float16]:
        """Compute the scale for ``tensor``.

        The base implementation has an empty body and returns ``None``;
        subclasses are expected to override it.

        Args:
            tensor: Tensor for which a scale should be chosen.
        """
class GridScaleOptimizaion(ScaleOptimizationStrategy):
    """Choose a quantization scale by grid search over the round-trip error."""

    def __init__(self, name: str, quantization_scheme: str) -> None:
        """Store the configuration and resolve the target storage dtype.

        Args:
            name: Name of the strategy.
            quantization_scheme: Scheme identifier. A scheme containing
                ``"e4m3"`` selects ``float8_e4m3fn``, one containing
                ``"e5m2"`` selects ``float8_e5m2``; anything else falls back
                to ``np.int8``.
        """
        self.name = name
        self.quantization_scheme = quantization_scheme
        if "e4m3" in quantization_scheme:
            self.quantization_dtype = float8_e4m3fn
        elif "e5m2" in quantization_scheme:
            self.quantization_dtype = float8_e5m2
        else:
            self.quantization_dtype = np.int8

    def optimize_scale(
        self,
        tensor: torch.Tensor,
    ) -> torch.Tensor:
        """Return the grid scale with the smallest quantization round-trip error.

        Each candidate scale is used to quantize ``tensor / scale`` to the
        configured dtype and dequantize it again; the candidate with the
        smallest summed absolute error wins.

        Args:
            tensor: Values to quantize. The tensor is detached, moved to the
                CPU and converted to float32 for the search; the caller's
                tensor is not modified.

        Returns:
            A zero-dimensional tensor holding the selected scale, placed on
            the same device as ``tensor`` (previously hard-coded to
            ``"cuda"``, which failed on hosts without CUDA).
        """
        target_device = tensor.device
        # float32 conversion also covers dtypes NumPy cannot represent
        # directly (e.g. bfloat16).
        values = tensor.detach().cpu().to(torch.float32).numpy()

        # Integer targets need explicit rounding and saturation: a raw
        # float -> int cast truncates and is undefined for out-of-range input.
        int_limits = (
            np.iinfo(self.quantization_dtype)
            if np.issubdtype(self.quantization_dtype, np.integer)
            else None
        )

        def _calculate_loss(values_fp32: np.ndarray, scale: np.float64) -> np.float64:
            """Summed absolute error of a quantize/dequantize round trip."""
            scaled = values_fp32 / scale
            if int_limits is not None:
                scaled = np.clip(np.rint(scaled), int_limits.min, int_limits.max)
            quantized = scaled.astype(dtype=self.quantization_dtype)
            restored = quantized.astype(dtype=np.float32) * scale
            return np.sum(np.abs(values_fp32 - restored))

        # TODO: extend to all quantization schemes
        scale_grid = np.linspace(0.01, 10, 15)
        losses = np.array(
            [_calculate_loss(values, scale) for scale in scale_grid],
            dtype=np.float64,
        )
        # np.argmin returns the position of a NaN if one is present, so a
        # candidate whose round trip produced NaN (presumably possible when
        # a cast overflows a float8 format without infinities -- confirm)
        # must never be selected as the "best" scale.
        losses = np.where(np.isnan(losses), np.inf, losses)
        best_scale = scale_grid[np.argmin(losses)]
        return torch.tensor(best_scale, device=target_device)