-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMAML_Example.py
More file actions
54 lines (46 loc) · 1.46 KB
/
MAML_Example.py
File metadata and controls
54 lines (46 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import copy

import torch
import torch.nn as nn
import torch.optim as optim
# Define the model architecture
class Model(nn.Module):
    """Affine regressor mapping one input feature to one output: ``w * x + b``."""

    def __init__(self):
        super().__init__()
        # A single learnable weight and bias; the attribute name ``linear``
        # determines the state_dict keys ("linear.weight", "linear.bias").
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the affine layer to ``x``, whose last dimension must be 1."""
        prediction = self.linear(x)
        return prediction
# Define the MAML algorithm
class MAML:
    """Model-Agnostic Meta-Learning (Finn et al., 2017) over a PyTorch model.

    The meta-parameters are the parameters of ``model``. For each task the
    model is adapted with ``inner_steps`` plain gradient-descent steps on the
    task data. The adapted parameters are then scored on that same task data,
    and that loss is back-propagated *through the adaptation steps* into the
    meta-parameters, which Adam then updates.

    NOTE(review): the same ``(x, y)`` is used for adaptation and for the meta
    loss. Canonical MAML splits each task into support/query sets; callers
    wanting that should pass distinct data per phase.

    Args:
        model: The ``nn.Module`` whose initialization is meta-learned; it is
            updated in place by :meth:`outer_loop`.
        lr_inner: Step size of the per-task (inner) gradient-descent updates.
        lr_outer: Adam learning rate for the meta (outer) update.
        inner_steps: Number of inner gradient steps per task.
        loss_fn: Callable ``loss_fn(prediction, target)``; ``None`` means MSE.
    """

    def __init__(self, model, lr_inner=0.01, lr_outer=0.001,
                 inner_steps=10, loss_fn=None):
        self.model = model
        self.lr_inner = lr_inner
        self.lr_outer = lr_outer
        self.inner_steps = inner_steps
        self.loss_fn = nn.MSELoss() if loss_fn is None else loss_fn
        # Created once so Adam's moment estimates persist across outer_loop
        # calls; re-creating it per call would reset them every epoch.
        self.meta_optimizer = optim.Adam(self.model.parameters(), lr=self.lr_outer)

    def inner_loop(self, task, x, y):
        """Return a copy of the meta-model fine-tuned on a single task.

        ``self.model`` itself is not modified. The returned module is a
        detached copy (no gradient flows back to the meta-parameters), so it
        is meant for evaluating/using the adapted model; meta-training goes
        through :meth:`outer_loop`.

        Args:
            task: Unused; kept so existing callers passing the task tuple
                alongside its unpacked parts keep working.
            x: Task inputs.
            y: Task targets.

        Returns:
            A new module deep-copied from ``self.model`` and trained with
            ``inner_steps`` SGD steps at ``lr_inner``.
        """
        # deepcopy rather than a hard-coded constructor so any architecture
        # works, not only this file's ``Model``.
        task_model = copy.deepcopy(self.model)
        optimizer = optim.SGD(task_model.parameters(), lr=self.lr_inner)
        for _ in range(self.inner_steps):
            loss = self.loss_fn(task_model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return task_model

    def _forward(self, params, x):
        """Run ``self.model`` on ``x`` with its parameters replaced by ``params``."""
        # torch.func requires PyTorch >= 2.0; imported lazily so constructing
        # the class and calling inner_loop do not depend on it.
        from torch.func import functional_call

        return functional_call(self.model, params, (x,))

    def _adapted_params(self, x, y):
        """Return task-adapted parameters as a differentiable function of the meta-parameters."""
        # Frozen parameters are left out: autograd.grad rejects inputs that
        # do not require grad, and functional_call falls back to the module's
        # own tensors for any name missing from the dict.
        params = {
            name: p for name, p in self.model.named_parameters() if p.requires_grad
        }
        for _ in range(self.inner_steps):
            loss = self.loss_fn(self._forward(params, x), y)
            # create_graph=True keeps the inner update in the autograd graph so
            # the meta-gradient includes the second-order MAML terms.
            grads = torch.autograd.grad(loss, tuple(params.values()), create_graph=True)
            params = {
                name: p - self.lr_inner * g
                for (name, p), g in zip(params.items(), grads)
            }
        return params

    def outer_loop(self, tasks):
        """Run one meta-training pass, taking one Adam step per task.

        Args:
            tasks: Iterable of ``(x, y)`` pairs.

        Returns:
            The mean post-adaptation (meta) loss over ``tasks`` as a float,
            or NaN when ``tasks`` is empty.
        """
        losses = []
        for x, y in tasks:
            adapted = self._adapted_params(x, y)
            # Loss of the adapted parameters; unlike a loss computed on a
            # separate model copy, its gradient reaches self.model's
            # parameters by flowing back through the inner updates.
            meta_loss = self.loss_fn(self._forward(adapted, x), y)
            self.meta_optimizer.zero_grad()
            meta_loss.backward()
            self.meta_optimizer.step()
            losses.append(meta_loss.item())
        return sum(losses) / len(losses) if losses else float("nan")
if __name__ == "__main__":
    # Guarded so importing this module does not kick off training.
    # Fixed seed so repeated runs of the example are reproducible.
    torch.manual_seed(0)

    # Define the tasks: each is a noisy 1-D linear regression problem. The
    # slope varies per task so there is an actual task distribution for MAML
    # to learn an initialization over; with one shared slope every task is
    # identical and meta-learning reduces to ordinary training.
    tasks = []
    for _ in range(10):
        slope = float(torch.empty(1).uniform_(0.5, 3.0))
        x = torch.randn(10, 1)
        y = x * slope + torch.randn(10, 1) * 0.1
        tasks.append((x, y))

    # Initialize the model and the MAML algorithm
    model = Model()
    maml = MAML(model)

    # Train the model using MAML: several passes over the task set, since a
    # single pass of small outer steps barely moves the meta-parameters.
    num_epochs = 100
    for _ in range(num_epochs):
        maml.outer_loop(tasks)

    # Report how well the learned initialization adapts to each task.
    criterion = nn.MSELoss()
    adapted_losses = []
    for task in tasks:
        x, y = task
        adapted = maml.inner_loop(task, x, y)
        with torch.no_grad():
            adapted_losses.append(criterion(adapted(x), y).item())
    print(f"mean post-adaptation loss: {sum(adapted_losses) / len(adapted_losses):.4f}")