Skip to content

Commit 1fa78a7

Browse files
committed
Fixes
1 parent 1d1f2e6 commit 1fa78a7

1 file changed

Lines changed: 43 additions & 5 deletions

File tree

src/coloring_compat.jl

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,53 @@ function _hessian_color_preprocess(
9999
algo = SparseMatrixColorings.GreedyColoringAlgorithm(; decompression=:substitution)
100100
tree_result = SparseMatrixColorings.coloring(S, problem, algo)
101101

102-
# Convert back to global indices
103-
for k in eachindex(I)
104-
I[k] = local_indices[I[k]]
105-
J[k] = local_indices[J[k]]
102+
# Reconstruct I and J from the tree structure (matching original _indirect_recover_structure)
103+
# First add all diagonal elements
104+
N = length(local_indices)
105+
106+
# Count off-diagonal elements from tree structure
107+
(; reverse_bfs_orders, tree_edge_indices, nt) = tree_result
108+
nnz_offdiag = 0
109+
for tree_idx in 1:nt
110+
first = tree_edge_indices[tree_idx]
111+
last = tree_edge_indices[tree_idx + 1] - 1
112+
nnz_offdiag += (last - first + 1)
113+
end
114+
115+
I_new = Vector{Int}(undef, N + nnz_offdiag)
116+
J_new = Vector{Int}(undef, N + nnz_offdiag)
117+
k = 0
118+
119+
# Add all diagonal elements
120+
for i in 1:N
121+
k += 1
122+
I_new[k] = local_indices[i]
123+
J_new[k] = local_indices[i]
106124
end
107125

126+
# Then add off-diagonal elements from the tree structure
127+
for tree_idx in 1:nt
128+
first = tree_edge_indices[tree_idx]
129+
last = tree_edge_indices[tree_idx + 1] - 1
130+
for pos in first:last
131+
(i_local, j_local) = reverse_bfs_orders[pos]
132+
# Convert from local to global indices and normalize (lower triangle)
133+
i_global = local_indices[i_local]
134+
j_global = local_indices[j_local]
135+
if j_global > i_global
136+
i_global, j_global = j_global, i_global
137+
end
138+
k += 1
139+
I_new[k] = i_global
140+
J_new[k] = j_global
141+
end
142+
end
143+
144+
@assert k == length(I_new)
145+
108146
# Wrap result with local_indices
109147
result = ColoringResult(tree_result, local_indices)
110-
return I, J, result
148+
return I_new, J_new, result
111149
end
112150

113151
"""

0 commit comments

Comments
 (0)