<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Knowledge Distillation Explained with PyTorch</title>
<style>
body {
font-family: sans-serif;
max-width: 900px;
margin: 0 auto;
padding: 2rem;
line-height: 1.6;
color: #333;
}
h1, h2, h3, h4 {
color: #2c3e50; /* Darker shade for better contrast */
margin-top: 1.5em;
margin-bottom: 0.5em;
}
h1 { font-size: 2.5em; border-bottom: 2px solid #3498db; padding-bottom: 0.3em;}
h2 { font-size: 2em; border-bottom: 1px solid #bdc3c7; padding-bottom: 0.2em;}
h3 { font-size: 1.5em; }
h4 { font-size: 1.2em; color: #555;}
nav { margin-bottom: 30px; padding: 10px; background: #ecf0f1; border: 1px solid #bdc3c7; border-radius: 4px;}
nav ul { list-style: none; padding: 0; }
nav li { display: inline-block; margin-right: 15px; }
nav a { text-decoration: none; color: #3498db; font-weight: bold;}
nav a:hover { text-decoration: underline; color: #2980b9;}
pre {
background: #f8f9f9; /* Lighter background for code blocks */
padding: 1rem;
overflow-x: auto;
border: 1px solid #e1e4e8; /* Softer border */
border-left: 4px solid #3498db; /* Accent border */
border-radius: 4px;
font-size: 0.9em;
}
code {
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace;
}
/* For inline code */
p > code, li > code, table td > code {
background: #e8eaed;
padding: 0.2em 0.4em;
border-radius: 3px;
font-size: 0.85em;
}
pre code { /* Reset for code inside pre, already handled by pre styling */
background: none;
padding: 0;
font-size: 1em; /* Ensure pre's font size is inherited */
}
ul {
padding-left: 20px;
}
li {
margin-bottom: 0.5em;
}
strong {
color: #2980b9;
}
</style>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
tex2jax: {
inlineMath: [['$','$'], ['\\(','\\)']],
displayMath: [['$$','$$'], ['\\[','\\]']],
processEscapes: true
}
});
</script>
<script type="text/javascript" async
src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-MML-AM_CHTML">
</script>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/monokai.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<script>hljs.highlightAll();</script>
</head>
<body>
<nav>
<ul>
<li><a href="index.html">Home</a></li>
<li><a href="https://github.com/raminaeye/model-optimization/blob/main/knowledge_distillation/knowledge_distillation.ipynb">Knowledge Distillation-Notebook</a></li>
</ul>
</nav>
<h1>Knowledge Distillation</h1>
<p>
How do you shrink a giant model into a small one without a significant loss in performance? <strong>Knowledge Distillation (KD)</strong> provides a powerful solution. The core idea is to use a large, accurate pre-trained model (the "teacher") to train a smaller, faster model (the "student") that mimics the teacher’s behavior.
</p>
<p>
This technique has proven highly effective in practice. For example, <code>DistilBERT</code> was distilled from the <code>BERT-base</code> model, yielding a model that is 40% smaller and 60% faster while retaining about 97% of BERT's language-understanding performance. Similarly, the original <code>MobileNet</code> paper reports distilling a large face-attribute classifier into a compact <code>MobileNet</code> student.
</p>
<p>
Today, Knowledge Distillation is widely deployed for a variety of tasks, including:
</p>
<ul>
<li>Getting Automatic Speech Recognition (ASR) models to run efficiently on-device.</li>
<li>Performing real-time segmentation with smaller, faster models.</li>
<li>Creating tiny yet powerful transformer models for resource-constrained environments.</li>
<li>Developing lightweight reinforcement learning policies distilled from large, complex RL agents.</li>
</ul>
<h2>The Core Idea: Learning from "Soft" Predictions</h2>
<p>
Instead of training the student model solely on the hard ground-truth labels (e.g., "this image is a cat"), the student also learns from the "soft" predictions produced by the teacher. These soft predictions are the probability distribution over all classes that the teacher model outputs. This distribution contains richer information, revealing how the teacher "thinks" (e.g., "this image is 90% likely a cat, 8% a dog, and 2% a car").
</p>
<p>
To make this information more useful, the teacher's output logits are typically softened using a <strong>temperature scaling</strong> parameter, $T$.
</p>
$$ \text{soft_targets} = \text{softmax}\left(\frac{\text{logits}}{T}\right) $$
<p>
The temperature $T$ controls the smoothness of the output distribution.
</p>
<ul>
<li>When $T = 1$, you get the standard <code>softmax</code> output.</li>
<li>When $T > 1$, the distribution becomes softer, shrinking the differences between class probabilities. This better reveals the teacher's knowledge about class similarities (e.g., that a truck is more similar to a car than to a cat).</li>
<li>When $T &lt; 1$, the distribution becomes sharper, approaching a one-hot encoding as $T \to 0$.</li>
</ul>
<p>
A higher temperature exposes more of the teacher's view of how classes relate to one another, which can give the student a richer training signal than one-hot labels alone, particularly when labeled data is scarce or noisy.
</p>
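<p>
The effect of $T$ is easy to check numerically. The following is a minimal sketch in plain Python; the logits for (cat, dog, car) are made-up values chosen for illustration, and <code>softmax_with_temperature</code> is a hypothetical helper name, not a library function.
</p>

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical teacher logits for the classes (cat, dog, car).
teacher_logits = [5.0, 2.5, 1.0]

for t in (0.5, 1.0, 4.0):
    probs = softmax_with_temperature(teacher_logits, t)
    print(f"T={t}: " + ", ".join(f"{p:.3f}" for p in probs))
```

<p>
Running it shows the probability of the top class falling as $T$ grows, with the remaining mass spread over the other classes.
</p>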
<h2>The Distillation Loss Function</h2>
<p>
The student is trained by optimizing a composite loss function that combines two objectives. This encourages the student to match both the ground truth and the teacher's soft predictions.
</p>
<ol>
<li><strong>Classification Loss:</strong> A standard cross-entropy loss between the student's predictions and the hard ground-truth labels.</li>
<li><strong>Distillation Loss:</strong> A loss that measures the difference between the student's and teacher's soft predictions. The Kullback-Leibler (KL) divergence is typically used for this.</li>
</ol>
<p>The final loss is a weighted average of these two components, controlled by a hyperparameter $\alpha \in [0, 1]$. The subscript $T$ marks logits that are passed through the temperature-scaled softmax:</p>
$$ \mathcal{L} = \alpha \cdot \mathcal{L}_{CE}(\text{student_logits, true_labels}) + (1 - \alpha) \cdot \mathcal{L}_{KL}(\text{student_logits}_T, \text{teacher_logits}_T) $$
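<p>
As a rough PyTorch sketch of this composite loss: the function name <code>distillation_loss</code>, the default values of <code>T</code> and <code>alpha</code>, and the random toy tensors are all assumptions made for illustration. The $T^2$ scaling on the KL term is also an assumption here; it is a common convention that keeps the soft-target gradients at a comparable magnitude when $T$ changes.
</p>

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """alpha * CE(student, labels) + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    # Hard-label term: ordinary cross-entropy on the unscaled student logits.
    ce = F.cross_entropy(student_logits, labels)
    # Soft-label term: F.kl_div takes log-probabilities as input and
    # probabilities as target, and computes KL(target || input).
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    kl = F.kl_div(log_p_student, p_teacher, reduction="batchmean")
    # T^2 scaling is an assumed convention, not part of the formula as written.
    return alpha * ce + (1.0 - alpha) * (T * T) * kl

# Toy batch: 8 examples, 10 classes, random logits standing in for real models.
torch.manual_seed(0)
student_logits = torch.randn(8, 10, requires_grad=True)
teacher_logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))

loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()  # gradients reach the student logits only
print(loss.item())
```

<p>
Setting <code>alpha=1.0</code> recovers plain supervised training, while <code>alpha=0.0</code> trains on the teacher's soft targets alone.
</p>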
<h2>Types of Knowledge Distillation</h2>
<h3>Response-Based Distillation (Logit Distillation)</h3>
<p>
This is the most common and straightforward form of KD, where only the final output logits from the teacher are used to train the student. While this method doesn't capture the teacher's internal reasoning process, its simplicity and effectiveness have made it very popular. <code>DistilBERT</code>'s training objective, for instance, includes a soft-target loss of this kind alongside a masked-language-modeling loss and a cosine loss on hidden states. The distillation loss is the KL divergence from the teacher's temperature-scaled output distribution (computed from its logits $z_t$) to the student's (computed from $z_s$). KL divergence is not symmetric, and the teacher's distribution conventionally comes first as the reference.
</p>
$$ \mathcal{L}_{KL} = KL\left(\text{softmax}\left(\frac{z_t}{T}\right) \bigg\| \text{softmax}\left(\frac{z_s}{T}\right)\right) $$
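<p>
A single response-based training step might look like the sketch below. The two small <code>nn.Sequential</code> models, the random input batch, the learning rate, and the helper name <code>distill_step</code> are hypothetical stand-ins for a real pre-trained teacher, student, and dataset. Note that <code>F.kl_div(input, target)</code> computes $KL(\text{target} \,\|\, \text{input})$, so the teacher's distribution acts as the reference here.
</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in models: any pair with matching output size would do.
teacher = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 5))
student = nn.Sequential(nn.Linear(20, 8), nn.ReLU(), nn.Linear(8, 5))
teacher.eval()  # the teacher is treated as pre-trained and stays fixed
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
T = 2.0

def distill_step(x):
    """One update in which the student matches the teacher's softened outputs."""
    with torch.no_grad():  # no gradients flow into the teacher
        teacher_probs = F.softmax(teacher(x) / T, dim=-1)
    student_log_probs = F.log_softmax(student(x) / T, dim=-1)
    loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

x = torch.randn(32, 20)  # random inputs in place of a real dataset
losses = [distill_step(x) for _ in range(50)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

<p>
Only the student's parameters are passed to the optimizer, so the teacher is never updated.
</p>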
<h3>Feature-Based Distillation</h3>
<p>This approach goes deeper by forcing the student to mimic the teacher's intermediate feature representations from one or more layers, not just the final output.</p>
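<p>
One simple way to realize this, sketched here under assumptions, is to regress the student's intermediate features onto the teacher's with a mean-squared-error loss. The feature widths, the linear <code>projector</code>, and the function name are illustrative choices; the random tensors stand in for real layer activations.
</p>

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical feature widths: the teacher layer is wider than the student layer.
teacher_dim, student_dim = 256, 64

# A learned projection maps student features into the teacher's feature space
# so the two can be compared even though their widths differ.
projector = nn.Linear(student_dim, teacher_dim)

def feature_distillation_loss(student_feat, teacher_feat):
    """Mean squared error between projected student features and teacher features."""
    # detach() treats the teacher features as fixed targets.
    return F.mse_loss(projector(student_feat), teacher_feat.detach())

# Random activations standing in for real intermediate-layer outputs.
student_feat = torch.randn(16, student_dim, requires_grad=True)
teacher_feat = torch.randn(16, teacher_dim)

loss = feature_distillation_loss(student_feat, teacher_feat)
loss.backward()  # updates would go to the student and the projector
print(loss.item())
```

<p>
In practice a term like this is usually added to the classification or logit-distillation loss rather than used on its own.
</p>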
<h3>Relation-Based Distillation</h3>
<p>Here, the student learns from the relationships between different data points or layers as seen by the teacher, rather than the absolute values of the outputs.</p>
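<p>
As one illustrative instance, the student can be trained to reproduce the teacher's pairwise distances between the examples in a batch. The sketch below assumes Euclidean distances normalized by their mean and a smooth L1 penalty; the function names, embedding sizes, and random embeddings are hypothetical.
</p>

```python
import torch
import torch.nn.functional as F

def normalized_pairwise_distances(embeddings):
    """Distances between every pair of rows, divided by the mean off-diagonal distance."""
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)  # shape (N, N, D)
    dist = (diff.pow(2).sum(dim=-1) + 1e-12).sqrt()  # epsilon keeps sqrt differentiable at 0
    off_diagonal = ~torch.eye(len(embeddings), dtype=torch.bool)
    return dist / dist[off_diagonal].mean()

def relation_distillation_loss(student_emb, teacher_emb):
    """Penalize differences between the student's and teacher's distance structure."""
    d_student = normalized_pairwise_distances(student_emb)
    with torch.no_grad():  # the teacher's relations are fixed targets
        d_teacher = normalized_pairwise_distances(teacher_emb)
    return F.smooth_l1_loss(d_student, d_teacher)

# Random embeddings for a batch of 16 examples; the two widths may differ
# because only the 16 x 16 distance matrices are compared.
torch.manual_seed(0)
student_emb = torch.randn(16, 32, requires_grad=True)
teacher_emb = torch.randn(16, 128)

loss = relation_distillation_loss(student_emb, teacher_emb)
loss.backward()
print(loss.item())
```

<p>
Because the distances are normalized, the loss ignores the overall scale of each embedding space and depends only on how the examples are arranged relative to one another.
</p>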
<h3>Self-Distillation</h3>
<p>An interesting variant where a model teaches itself. Typically, an already-trained model serves as the teacher for a student of the exact same architecture, which often leads to improved performance and robustness even though the student is no smaller than the teacher.</p>
</body>
</html>