forked from ujjwal-9/Knowledge-Distillation
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmobilenet.py
More file actions
23 lines (19 loc) · 783 Bytes
/
mobilenet.py
File metadata and controls
23 lines (19 loc) · 783 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import keras
from keras.applications.mobilenet import MobileNet
from keras.models import Model
from keras.layers import Activation, GlobalAveragePooling2D, Dropout, Dense, Input
def get_mobilenet(input_size, alpha, weight_decay, dropout,
                  num_classes=256, weights='imagenet'):
    """Build a MobileNet classifier with a frozen backbone and a new softmax head.

    The backbone is ``keras.applications.MobileNet`` without its top. The
    following layers are stacked on it:

    * global average pooling;
    * dropout;
    * an L2-regularised ``Dense`` layer producing logits;
    * a softmax ``Activation``.

    Every layer except the final ``Dense`` and ``Activation`` is marked
    non-trainable, so only the new classification head is trained.

    Args:
        input_size: Height and width of the square RGB input images.
        alpha: MobileNet width multiplier, passed straight to ``MobileNet``.
        weight_decay: L2 regularisation factor applied to the head's kernel.
        dropout: Dropout rate applied to the pooled backbone features.
        num_classes: Number of output classes (units in the ``Dense`` head).
            Defaults to 256.
        weights: Backbone weights forwarded to ``MobileNet``. Defaults to
            ``'imagenet'``; Keras also accepts ``None`` (random init) or a
            path to a weights file.

    Returns:
        A ``keras.models.Model`` mapping ``(input_size, input_size, 3)``
        images to softmax probabilities over ``num_classes`` classes.

    NOTE(review): Keras only ships pretrained ImageNet weights for certain
    ``input_size``/``alpha`` combinations -- confirm against the MobileNet
    application docs before changing either value.
    """
    input_shape = (input_size, input_size, 3)
    base_model = MobileNet(
        include_top=False, weights=weights,
        input_shape=input_shape, alpha=alpha
    )

    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dropout(dropout)(x)
    # Keep logits and softmax as separate layers so the pre-softmax
    # output remains reachable as its own layer in the graph.
    logits = Dense(
        num_classes,
        kernel_regularizer=keras.regularizers.l2(weight_decay),
    )(x)
    probabilities = Activation('softmax')(logits)
    model = Model(base_model.input, probabilities)

    # The last two layers are the Dense head and the softmax Activation;
    # everything before them (backbone, pooling, dropout) is frozen.
    # Pooling and dropout hold no weights, so freezing them is a no-op.
    for layer in model.layers[:-2]:
        layer.trainable = False
    return model