|
8 | 8 | ********************************************************/ |
9 | 9 |
|
10 | 10 | #include <af/dim4.hpp> |
| 11 | +#include <af/data.h> |
11 | 12 | #include <af/statistics.h> |
12 | 13 | #include <af/defines.h> |
13 | 14 | #include <err_common.hpp> |
@@ -91,23 +92,39 @@ af_err af_mean_weighted(af_array *out, const af_array in, const af_array weights |
91 | 92 | af_dtype wType = wInfo.getType(); |
92 | 93 |
|
93 | 94 | ARG_ASSERT(2, (wType==f32 || wType==f64)); /* verify that weights are non-complex real numbers */ |
94 | | - ARG_ASSERT(2, iInfo.dims() == wInfo.dims()); |
| 95 | + |
| 96 | + //FIXME: We should avoid additional copies |
| 97 | + af_array w = weights; |
| 98 | + if (iInfo.dims() != wInfo.dims()) { |
| 99 | + dim4 iDims = iInfo.dims(); |
| 100 | + dim4 wDims = wInfo.dims(); |
| 101 | + dim4 tDims(1,1,1,1); |
| 102 | + for (int i = 0; i < 4; i++) { |
| 103 | + ARG_ASSERT(2, wDims[i] == 1 || wDims[i] == iDims[i]); |
| 104 | + tDims[i] = iDims[i] / wDims[i]; |
| 105 | + } |
| 106 | + AF_CHECK(af_tile(&w, weights, tDims[0], tDims[1], tDims[2], tDims[3])); |
| 107 | + } |
95 | 108 |
|
96 | 109 | switch(iType) { |
97 | | - case f64: output = mean< double>(in, weights, dim); break; |
98 | | - case f32: output = mean< float >(in, weights, dim); break; |
99 | | - case s32: output = mean< float >(in, weights, dim); break; |
100 | | - case u32: output = mean< float >(in, weights, dim); break; |
101 | | - case s64: output = mean< double>(in, weights, dim); break; |
102 | | - case u64: output = mean< double>(in, weights, dim); break; |
103 | | - case s16: output = mean< float >(in, weights, dim); break; |
104 | | - case u16: output = mean< float >(in, weights, dim); break; |
105 | | - case u8: output = mean< float >(in, weights, dim); break; |
106 | | - case b8: output = mean< float >(in, weights, dim); break; |
107 | | - case c32: output = mean< cfloat>(in, weights, dim); break; |
108 | | - case c64: output = mean<cdouble>(in, weights, dim); break; |
| 110 | + case f64: output = mean< double>(in, w, dim); break; |
| 111 | + case f32: output = mean< float >(in, w, dim); break; |
| 112 | + case s32: output = mean< float >(in, w, dim); break; |
| 113 | + case u32: output = mean< float >(in, w, dim); break; |
| 114 | + case s64: output = mean< double>(in, w, dim); break; |
| 115 | + case u64: output = mean< double>(in, w, dim); break; |
| 116 | + case s16: output = mean< float >(in, w, dim); break; |
| 117 | + case u16: output = mean< float >(in, w, dim); break; |
| 118 | + case u8: output = mean< float >(in, w, dim); break; |
| 119 | + case b8: output = mean< float >(in, w, dim); break; |
| 120 | + case c32: output = mean< cfloat>(in, w, dim); break; |
| 121 | + case c64: output = mean<cdouble>(in, w, dim); break; |
109 | 122 | default : TYPE_ERROR(1, iType); |
110 | 123 | } |
| 124 | + |
| 125 | + if (w != weights) { |
| 126 | + AF_CHECK(af_release_array(w)); |
| 127 | + } |
111 | 128 | std::swap(*out, output); |
112 | 129 | } |
113 | 130 | CATCHALL; |
|
0 commit comments