We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4b2585e commit 348c062Copy full SHA for 348c062
1 file changed
src/sizes.jl
@@ -309,8 +309,17 @@ function _infer_sizes(
309
return !iszero(sizes.ndims[children_arr[i]])
310
end
311
if !isnothing(first_matrix)
312
- if sizes.ndims[children_arr[first(children_indices)]] == 0
313
- _add_size!(sizes, k, (1, 1))
+ first_is_scalar =
+ sizes.ndims[children_arr[first(children_indices)]] == 0
314
+ last_is_scalar =
315
+ sizes.ndims[children_arr[last(children_indices)]] == 0
316
+ if first_is_scalar || last_is_scalar
317
+ # `scalar * matrix` (or `matrix * scalar`) is
318
+ # element-wise scaling, not matmul: result inherits
319
+ # the matrix's shape.
320
+ ix_mat =
321
+ children_arr[children_indices[first_matrix]]
322
+ _copy_size!(sizes, k, ix_mat)
323
continue
324
else
325
_add_size!(
0 commit comments