-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSparseBlockToDenseSum.lua
More file actions
117 lines (90 loc) · 3.95 KB
/
SparseBlockToDenseSum.lua
File metadata and controls
117 lines (90 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
-- nn.SparseBlockToDenseSum: collapses a sparse-block input (a table with
-- .nBatchSize and .taData columns, each column holding .teValue, .teRowIdx and
-- optionally .teDefault) into a dense (nBatchSize x 1) output by summing each
-- row's values across all columns; sparse blocks override the column default
-- for the rows they cover (see pri_updateOutput_column).
local SparseBlockToDenseSum, parent = torch.class('nn.SparseBlockToDenseSum', 'nn.Module')
-- NOTE: parent.__init is intentionally NOT called here. nn.Module's constructor
-- would pre-populate self.output / self.gradInput with tensors, which would
-- defeat the nil checks that drive lazy allocation in pri_ensureOutput and
-- pri_ensureGradInput below.
function SparseBlockToDenseSum:__init()
end
-- Lazily allocate the output tensor and the two scratch buffers, exactly once.
-- A non-nil self.output marks allocation as already done (see __init, which
-- deliberately leaves it nil).
function SparseBlockToDenseSum:pri_ensureOutput(input)
  if self.output == nil then
    local nBatch = input.nBatchSize
    self.output = torch.zeros(nBatch, 1)
    -- The scratch buffers exist so per-column results can be accumulated with
    -- matrix operations instead of an explicit row-by-row loop; they could be
    -- avoided at the cost of iteration.
    self.outputBufferA = torch.zeros(nBatch, 1) -- initial per-column result (sized for the maximum possible row count)
    self.outputBufferB = torch.zeros(nBatch, 1) -- "scatter" destination, later added into self.output
  end
end
-- Accumulate one column's contribution into self.output.
-- Each row's contribution is the sum of its values (row vector * ones weight).
-- Rows listed in taInput.teRowIdx use the sparse block values (taInput.teValue);
-- when taInput.teDefault is present, every other row uses the default row.
function SparseBlockToDenseSum:pri_updateOutput_column(taInput)
  local nInputWidth = taInput.teValue:size(2)
  local teWeight = torch.ones(nInputWidth, 1) -- ones column: mat-mul sums each row

  -- FIX: always start from a clean buffer. Previously the buffer was zeroed
  -- only inside the teDefault branch, while the end-of-call cleanup scattered
  -- zeros onto the sparse rows only — so a column WITHOUT teDefault that
  -- followed a column WITH teDefault inherited stale default sums in its
  -- non-sparse rows and added them into self.output.
  self.outputBufferB:zero()

  -- calculate output for teDefault input (applies to every batch row;
  -- sparse rows are overwritten by the scatter below)
  if taInput.teDefault then
    local teDefaultInputExpanded = taInput.teDefault:view(1, nInputWidth)
      :expand(self.output:size(1), nInputWidth) -- expand for multiplication
    self.outputBufferB:addmm(teDefaultInputExpanded, teWeight)
  end

  -- calculate the output for the sparse blocks
  local teInput = taInput.teValue
  local nRows = teInput:size(1)
  local teOutput = self.outputBufferA:narrow(1, 1, nRows)
  teOutput:zero()
  teOutput:addmm(teInput, teWeight)

  -- place sparse-row results at their destination rows, overwriting defaults
  local teDstIdx = torch.expand(taInput.teRowIdx, nRows, 1)
  self.outputBufferB:scatter(1, teDstIdx, teOutput)

  -- add buffer to output (no per-column cleanup needed: the buffer is fully
  -- zeroed at the top of the next call)
  self.output:add(self.outputBufferB)
end
-- Forward pass: zero the (lazily allocated) output, then accumulate every
-- column of input.taData into it.
-- Returns the dense (nBatchSize x 1) output tensor.
function SparseBlockToDenseSum:updateOutput(input)
  self:pri_ensureOutput(input)
  self.output:zero()
  -- table.getn is deprecated since Lua 5.1; the length operator is the
  -- supported equivalent for a sequence.
  for i = 1, #input.taData do
    self:pri_updateOutput_column(input.taData[i])
  end
  return self.output
end
-- Lazily build self.gradInput mirroring the structure of input: one entry per
-- column with a zeroed teValue gradient, the shared teRowIdx, and (for columns
-- with a default row) a teGradOutputSum gradient buffer.
function SparseBlockToDenseSum:pri_ensureGradInput(input)
  if self.gradInput ~= nil then
    return
  end
  self.gradInput = { nBatchSize = input.nBatchSize, taData = {} }
  -- table.getn is deprecated since Lua 5.1; use the length operator.
  for i = 1, #input.taData do
    local taInputCurr = input.taData[i]
    -- FIX: 'local' was missing here, leaking taGradInputCurr as a global.
    local taGradInputCurr = {
      teValue = torch.zeros(taInputCurr.teValue:size()),
      teRowIdx = taInputCurr.teRowIdx,
    }
    if taInputCurr.teDefault then
      -- gradient w.r.t. the (1 x width) default row
      taGradInputCurr.teGradOutputSum = torch.zeros(1, taInputCurr.teValue:size(2))
    end
    table.insert(self.gradInput.taData, taGradInputCurr)
  end
end
-- Backward pass for one column.
-- Since forward computed each row's output as row-sum (mat-mul with a ones
-- column), the gradient w.r.t. each input element is simply the row's
-- gradOutput broadcast across the width.
function SparseBlockToDenseSum:pri_updateGradInput_column(taInput, teGradOutput, taGradInput)
  local nWidth = taInput.teValue:size(2)
  local teWeight = torch.ones(nWidth, 1)

  -- gather the gradOutput rows belonging to this column's sparse blocks
  -- (teRowIdx selects them), reusing outputBufferA as scratch
  local nRows = taInput.teValue:size(1)
  local teGradOutputSelected = self.outputBufferA:narrow(1, 1, nRows)
  local teDstIdx = torch.expand(taInput.teRowIdx, nRows, 1)
  teGradOutputSelected:gather(teGradOutput, 1, teDstIdx)

  -- gradient w.r.t. the sparse block values: broadcast each selected
  -- gradOutput row across the block width
  local gradInput = taGradInput.teValue
  gradInput:zero()
  gradInput:addmm(teGradOutputSelected, teWeight:t())

  -- gradient w.r.t. the default row (sum first, multiply by weights after —
  -- reordering optimization that saves memory)
  if taGradInput.teGradOutputSum then
    -- FIX: forward OVERWRITES sparse rows with the block values (scatter), so
    -- the default row only influenced non-sparse rows. Previously the sum ran
    -- over ALL rows, overcounting the sparse rows' gradients; subtract the
    -- gathered sparse-row gradients to restrict the sum to default rows.
    local teGradDefaultSum = teGradOutput:sum(1) - teGradOutputSelected:sum(1)
    taGradInput.teGradOutputSum:mm(teGradDefaultSum, teWeight:t())
  end

  -- cleanup the scratch buffer (after it has been used for the sum above)
  teGradOutputSelected:zero()
end
-- Backward pass: ensure the gradInput structure exists, then fill in each
-- column's gradients from the shared teGradOutput.
-- Returns the sparse-block-structured gradInput table.
function SparseBlockToDenseSum:updateGradInput(input, gradOutput)
  self:pri_ensureGradInput(input)
  -- table.getn is deprecated since Lua 5.1; use the length operator.
  for i = 1, #self.gradInput.taData do
    self:pri_updateGradInput_column(input.taData[i],
                                    gradOutput,
                                    self.gradInput.taData[i])
  end
  return self.gradInput
end