-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathreplaymemory.lua
More file actions
126 lines (106 loc) · 3.23 KB
/
replaymemory.lua
File metadata and controls
126 lines (106 loc) · 3.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
-- Prototype table for the replay-memory "class"; declared as a global so the
-- rest of this project (script-style code, no module system) can reach it.
ReplayMemory = {}
-- Instances created via setmetatable(obj, ReplayMemory) resolve missing keys
-- (i.e. the methods below) on this prototype table.
ReplayMemory.__index = ReplayMemory
--- Convert a truthy/falsy value into a numeric flag.
-- Any truthy value maps to 1; nil and false map to 0.
-- @param val value to convert
-- @return number 1 or 0
function boolToNum(val)
  -- `and/or` idiom is safe here: the middle operand (1) is never falsy.
  return val and 1 or 0
end
--- Allocate a new ReplayMemory instance.
-- Note: Lua never calls __init automatically, and the original version left
-- the object completely uninitialized. For backward compatibility, calling
-- create() with no argument behaves exactly as before (the caller must invoke
-- :__init(args) itself); passing an args table performs both steps at once.
-- @param args optional table forwarded to ReplayMemory:__init
-- @return the new ReplayMemory object
function ReplayMemory.create(args)
  local rpmem = setmetatable({}, ReplayMemory)
  if args ~= nil then
    rpmem:__init(args)
  end
  return rpmem
end
--- Initialize replay-memory state and preallocate the minibatch buffers.
-- Lua does not invoke __init automatically; ReplayMemory.create (or the
-- caller) must call it explicitly.
-- @param args table with fields:
--   maxSize   ring-buffer capacity (max number of stored transitions)
--   inputDim  stored for reference -- the buffers below use explicit dims
--   batchSize number of transitions per sampled minibatch
--   channels, height, width (optional) screen tensor dimensions;
--             default to the previously hard-coded 3 x 112 x 112
function ReplayMemory:__init(args)
  self.maxSize = args.maxSize
  self.numEntries = 0      -- how many slots currently hold valid data
  self.entries = {}        -- ring buffer of transition records
  self.insertIndex = 0     -- slot last written to (1-based once used)
  self.inputDim = args.inputDim
  self.batchSize = args.batchSize
  -- Optional dims generalize the hard-coded 3x112x112 screen shape; the
  -- defaults keep existing callers' behavior byte-for-byte identical.
  local c = args.channels or 3
  local h = args.height or 112
  local w = args.width or 112
  -- Reusable buffers, filled in place by sample()/getMinibatch() so no
  -- tensors are allocated per minibatch.
  self.buf_input = torch.Tensor(self.batchSize, c, h, w):fill(0)
  self.buf_input2 = torch.Tensor(self.batchSize, c, h, w):fill(0)
  self.buf_action = torch.Tensor(self.batchSize):fill(0)
  self.buf_reward = torch.Tensor(self.batchSize):fill(0)
  self.buf_term = torch.Tensor(self.batchSize):fill(0)
end
--- Store a transition in the ring buffer, overwriting the oldest entry once
-- capacity (maxSize) has been reached.
-- @param elem transition record; sample()/getMinibatch() read its
--   action, reward, start_state and next_state fields
function ReplayMemory:add(elem)
  -- Grow the logical size until the buffer is full.
  if self.numEntries < self.maxSize then
    self.numEntries = self.numEntries + 1
  end
  -- Advance the write cursor, wrapping back to slot 1 past maxSize.
  local cursor = self.insertIndex + 1
  if cursor > self.maxSize then
    cursor = 1
  end
  self.insertIndex = cursor
  self.entries[cursor] = elem
end
--- Clear the memory.
-- Besides resetting the counters (as before), this now also drops the stored
-- transitions themselves: the original kept every entry reachable through
-- self.entries, so screen tensors could never be garbage-collected. All read
-- paths (sample, getMinibatch, lastTransition) gate on the counters, so
-- clearing the table is behavior-safe.
function ReplayMemory:reset()
  self.numEntries = 0
  self.insertIndex = 0
  self.entries = {}   -- release references so old transitions can be GC'd
end
--- Sample a random minibatch (uniform, with replacement) into the reusable
-- buffers allocated by __init.
-- Cleanup vs. the original: the unused `index` local is gone, the manual
-- while-counter loop is a numeric for, and the dead commented-out block that
-- force-included the newest transition has been removed.
-- @return buf_action, buf_reward, buf_input, buf_input2, buf_term
--         (batchSize rows each), or nil while the memory is still empty
function ReplayMemory:sample()
  if self.insertIndex <= 0 then
    return nil
  end
  for slot = 1, self.batchSize do
    -- Uniform draw over the currently valid entries (with replacement).
    local pick = torch.random(1, self.numEntries)
    local e = self.entries[pick]
    self.buf_action[slot] = e.action
    self.buf_reward[slot] = e.reward
    self.buf_term[slot] = boolToNum(e.next_state.terminal)
    -- Clone so later overwrites of the stored tensors cannot alias the batch.
    self.buf_input[slot] = e.start_state.screenTensor:clone()
    self.buf_input2[slot] = e.next_state.screenTensor:clone()
  end
  return self.buf_action, self.buf_reward, self.buf_input,
      self.buf_input2, self.buf_term
end
--- Copy the contiguous run of transitions [start, start + batchSize - 1]
-- into the reusable minibatch buffers.
-- @param start 1-based index of the first transition to fetch
-- @return buf_action, buf_reward, buf_input, buf_input2, buf_term, or nil
--         when the requested window extends past the stored entries
function ReplayMemory:getMinibatch(start)
  local last = start + self.batchSize - 1
  if last > self.numEntries then
    return nil
  end
  for slot = 1, self.batchSize do
    local e = self.entries[start + slot - 1]
    self.buf_action[slot] = e.action
    self.buf_reward[slot] = e.reward
    self.buf_term[slot] = boolToNum(e.next_state.terminal)
    -- Clone so the batch does not alias tensors still held by the memory.
    self.buf_input[slot] = e.start_state.screenTensor:clone()
    self.buf_input2[slot] = e.next_state.screenTensor:clone()
  end
  return self.buf_action, self.buf_reward, self.buf_input,
      self.buf_input2, self.buf_term
end
--- Fetch the most recently added transition.
-- @return the transition at the write cursor, or nil when nothing has been
--         added yet (or after reset)
function ReplayMemory:lastTransition()
  if self.insertIndex < 1 then
    return nil
  end
  return self.entries[self.insertIndex]
end
--- Number of transitions currently stored (grows with add() up to maxSize,
-- drops to 0 on reset()).
-- @return number of valid entries
function ReplayMemory:size()
return self.numEntries
end