diff --git a/GNNGraphs/src/GNNGraphs.jl b/GNNGraphs/src/GNNGraphs.jl index bb35dcfcd..5a3aa5206 100644 --- a/GNNGraphs/src/GNNGraphs.jl +++ b/GNNGraphs/src/GNNGraphs.jl @@ -2,7 +2,7 @@ module GNNGraphs using SparseArrays import Graphs -using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, +using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree, has_self_loops, is_directed, induced_subgraph, has_edge import NearestNeighbors import NNlib @@ -11,7 +11,7 @@ import KrylovKit import ChainRulesCore as CRC using LinearAlgebra, Random, Statistics import MLUtils -using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like +using MLUtils: getobs, numobs, ones_like, zeros_like, fill_like, chunk, batch, rand_like using MLDataDevices: get_device, cpu_device, CPUDevice using Functors: @functor @@ -38,9 +38,9 @@ export GNNHeteroGraph, include("temporalsnapshotsgnngraph.jl") export TemporalSnapshotsGNNGraph, add_snapshot, - # add_snapshot!, +# add_snapshot!, remove_snapshot - # remove_snapshot! +# remove_snapshot! include("query.jl") include("gnnheterograph/query.jl") @@ -59,7 +59,7 @@ export adjacency_list, # from Graphs.jl adjacency_matrix, degree, - has_edge, + has_edge, has_isolated_nodes, has_self_loops, inneighbors, @@ -75,7 +75,7 @@ export add_nodes, negative_sample, rand_edge_split, remove_self_loops, - remove_edges, + remove_edges, coalesce, set_edge_weight, to_bidirected, @@ -117,7 +117,7 @@ export mldataset2gnngraph include("deprecations.jl") include("sampling.jl") -export NeighborLoader, sample_neighbors, - induced_subgraph # from Graphs.jl +export NeighborLoader, sample_neighbors, + induced_subgraph # from Graphs.jl end #module diff --git a/GNNGraphs/src/transform.jl b/GNNGraphs/src/transform.jl index c20579d20..14fe5b215 100644 --- a/GNNGraphs/src/transform.jl +++ b/GNNGraphs/src/transform.jl @@ -123,7 +123,8 @@ function remove_edges(g::GNNGraph{<:COO_T}, edges_to_remove::AbstractVector{<:In w = get_edge_weight(g) edata = g.edata - mask_to_keep = trues(length(s)) + # mask_to_keep = trues(length(s)) + mask_to_keep = MLUtils.fill_like(edges_to_remove, true) mask_to_keep[edges_to_remove] .= false