Skip to content

Commit 0a675e8

Browse files
committed
BUGFIX: Add support for broadcasting weights when using mean.
1 parent 9d99010 commit 0a675e8

File tree

2 files changed

+50
-13
lines changed

2 files changed

+50
-13
lines changed

src/api/c/mean.cpp

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
********************************************************/
99

1010
#include <af/dim4.hpp>
11+
#include <af/data.h>
1112
#include <af/statistics.h>
1213
#include <af/defines.h>
1314
#include <err_common.hpp>
@@ -91,23 +92,39 @@ af_err af_mean_weighted(af_array *out, const af_array in, const af_array weights
9192
af_dtype wType = wInfo.getType();
9293

9394
ARG_ASSERT(2, (wType==f32 || wType==f64)); /* verify that weights are non-complex real numbers */
94-
ARG_ASSERT(2, iInfo.dims() == wInfo.dims());
95+
96+
//FIXME: We should avoid additional copies
97+
af_array w = weights;
98+
if (iInfo.dims() != wInfo.dims()) {
99+
dim4 iDims = iInfo.dims();
100+
dim4 wDims = wInfo.dims();
101+
dim4 tDims(1,1,1,1);
102+
for (int i = 0; i < 4; i++) {
103+
ARG_ASSERT(2, wDims[i] == 1 || wDims[i] == iDims[i]);
104+
tDims[i] = iDims[i] / wDims[i];
105+
}
106+
AF_CHECK(af_tile(&w, weights, tDims[0], tDims[1], tDims[2], tDims[3]));
107+
}
95108

96109
switch(iType) {
97-
case f64: output = mean< double>(in, weights, dim); break;
98-
case f32: output = mean< float >(in, weights, dim); break;
99-
case s32: output = mean< float >(in, weights, dim); break;
100-
case u32: output = mean< float >(in, weights, dim); break;
101-
case s64: output = mean< double>(in, weights, dim); break;
102-
case u64: output = mean< double>(in, weights, dim); break;
103-
case s16: output = mean< float >(in, weights, dim); break;
104-
case u16: output = mean< float >(in, weights, dim); break;
105-
case u8: output = mean< float >(in, weights, dim); break;
106-
case b8: output = mean< float >(in, weights, dim); break;
107-
case c32: output = mean< cfloat>(in, weights, dim); break;
108-
case c64: output = mean<cdouble>(in, weights, dim); break;
110+
case f64: output = mean< double>(in, w, dim); break;
111+
case f32: output = mean< float >(in, w, dim); break;
112+
case s32: output = mean< float >(in, w, dim); break;
113+
case u32: output = mean< float >(in, w, dim); break;
114+
case s64: output = mean< double>(in, w, dim); break;
115+
case u64: output = mean< double>(in, w, dim); break;
116+
case s16: output = mean< float >(in, w, dim); break;
117+
case u16: output = mean< float >(in, w, dim); break;
118+
case u8: output = mean< float >(in, w, dim); break;
119+
case b8: output = mean< float >(in, w, dim); break;
120+
case c32: output = mean< cfloat>(in, w, dim); break;
121+
case c64: output = mean<cdouble>(in, w, dim); break;
109122
default : TYPE_ERROR(1, iType);
110123
}
124+
125+
if (w != weights) {
126+
AF_CHECK(af_release_array(w));
127+
}
111128
std::swap(*out, output);
112129
}
113130
CATCHALL;

test/mean.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,3 +301,23 @@ TYPED_TEST(WeightedMean, Basic)
301301
{
302302
weightedMeanAllTest<TypeParam, float>(af::dim4(32, 30, 33, 17));
303303
}
304+
305+
TEST(WeightedMean, Broadacst)
306+
{
307+
float val = 0.5f;
308+
af::array a = af::randu(4096, 32);
309+
af::array w = af::constant(val, a.dims());
310+
af::array c = af::mean(a);
311+
af::array d = af::mean(a, w);
312+
313+
std::vector<float> hc(c.elements());
314+
std::vector<float> hd(d.elements());
315+
316+
c.host(hc.data());
317+
d.host(hd.data());
318+
319+
for(size_t i = 0; i < hc.size(); i++) {
320+
//C and D are the same because they are normalized by the sum of the weights.
321+
ASSERT_NEAR(hc[i], hd[i], 1E-5);
322+
}
323+
}

0 commit comments

Comments
 (0)