Skip to content

Commit 348c062

Browse files
committed
Fix size inference for *
1 parent 4b2585e commit 348c062

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

src/sizes.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,17 @@ function _infer_sizes(
309309
return !iszero(sizes.ndims[children_arr[i]])
310310
end
311311
if !isnothing(first_matrix)
312-
if sizes.ndims[children_arr[first(children_indices)]] == 0
313-
_add_size!(sizes, k, (1, 1))
312+
first_is_scalar =
313+
sizes.ndims[children_arr[first(children_indices)]] == 0
314+
last_is_scalar =
315+
sizes.ndims[children_arr[last(children_indices)]] == 0
316+
if first_is_scalar || last_is_scalar
317+
# `scalar * matrix` (or `matrix * scalar`) is
318+
# element-wise scaling, not matmul: result inherits
319+
# the matrix's shape.
320+
ix_mat =
321+
children_arr[children_indices[first_matrix]]
322+
_copy_size!(sizes, k, ix_mat)
314323
continue
315324
else
316325
_add_size!(

0 commit comments

Comments
 (0)