Skip to content

GATConv doesn't work on hetero graphs with empty edge arrays or during backpropagation #637

@lenianiva

Description

@lenianiva

Minimal reproducing example is below. (Note: `dropout` defaults to a 64-bit float, which causes additional problems, but that is easily fixed by passing a `Float32` explicitly, as done here.)

# Minimal reproduction script for the reported DimensionMismatch in GATConv
# when used through HeteroGraphConv on a heterogeneous graph.
using GNNGraphs, GraphNeuralNetworks, NNlib, Flux

# Heterogeneous graph with eight edge types. Only (:A, :a, :B) and
# (:B, :a, :A) carry edges; the remaining six are given empty `Int[]` arrays.
# Node types :D and :E are declared with zero nodes.
# NOTE(review): each tuple is presumably (source indices, target indices) —
# confirm against the GNNHeteroGraph documentation.
graph = GNNHeteroGraph(
    Dict(
        (:A, :a, :B) => ([1, 2], [3, 4]),
        (:B, :a, :A) => ([1], [2]),
        (:C, :a, :A) => (Int[], Int[]),
        (:A, :a, :C) => (Int[], Int[]),
        (:D, :a, :A) => (Int[], Int[]),
        (:E, :a, :A) => (Int[], Int[]),
        (:E, :a, :D) => (Int[], Int[]),
        (:D, :a, :E) => (Int[], Int[]),
    );
    num_nodes = Dict(:A => 3, :B => 5, :C => 7, :D => 0, :E => 0)
)
# First hetero layer: one GATConv(4 => 4) per edge type found in `graph.edata`.
# `dropout` is passed explicitly as a Float32 (the issue text notes that the
# default is a 64-bit float, which causes separate problems).
layer = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for
        (src, edge, dst) in keys(graph.edata)
    ];
)
# Second hetero layer, constructed identically to the first; applied to the
# first layer's output below.
layer2 = HeteroGraphConv(
    [
        (src, edge, dst) => GATConv(4 => 4, NNlib.elu; dropout = Float32(0.25)) for
        (src, edge, dst) in keys(graph.edata)
    ];
)

# Per-node-type Float32 feature matrices: 4 rows, and a column count equal to
# the `num_nodes` entry for that type (so :D and :E are 4×0 matrices).
x = (
    A = rand(Float32, 4, 3),
    B = rand(Float32, 4, 5),
    C = rand(Float32, 4, 7),
    D = rand(Float32, 4, 0),
    E = rand(Float32, 4, 0),
)

# Forward pass through both layers, then log the final output.
# NOTE(review): the stack trace in the report goes through the HeteroGraphConv
# call with the NamedTuple input, but does not show which top-level statement
# triggered it — confirm whether the failure is here or in the gradient below.
x1 = layer(graph, x)
x2 = layer2(graph, x1)
@info "$x2"

# Backward pass: gradient, with respect to the input features `x`, of the sum
# of the first layer's output for node type :A.
g = Flux.gradient(x) do x
    y = layer(graph, x)
    sum(y[:A])
end

Running the example above fails with the following error:

ERROR: LoadError: DimensionMismatch: arrays could not be broadcast to a common size: a has axes Base.OneTo(0) and b has axes Base.OneTo(4)
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:535 [inlined]
  [2] _bcs
    @ ./broadcast.jl:529 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:523 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:504 [inlined]
  [5] _axes
    @ ./broadcast.jl:240 [inlined]
  [6] axes
    @ ./broadcast.jl:238 [inlined]
  [7] combine_axes
    @ ./broadcast.jl:505 [inlined]
  [8] instantiate
    @ ./broadcast.jl:313 [inlined]
  [9] materialize
    @ ./broadcast.jl:894 [inlined]
 [10] gat_conv(l::GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}, g::GNNHeteroGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, x::Tuple{Matrix{Float32}, Matrix{Float32}}, e::Nothing)
    @ GNNlib ~/.julia/packages/GNNlib/wxiDz/src/layers/conv.jl:147
 [11] GATConv
    @ ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/conv.jl:346 [inlined]
 [12] (::GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}})(g::GNNHeteroGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, x::Tuple{Matrix{Float32}, Matrix{Float32}})
    @ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/conv.jl:346
 [13] (::GraphNeuralNetworks.var"#forw#forw##0"{GNNHeteroGraph{Tuple{T, T, Union{Nothing, AbstractVector}} where T<:(AbstractVector{<:Integer})}, @NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}, C::Matrix{Float32}, D::Matrix{Float32}, E::Matrix{Float32}}})(l::GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}, et::Tuple{Symbol, Symbol, Symbol})
    @ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:63
 [14] #60
    @ ./none:-1 [inlined]
 [15] iterate
    @ ./generator.jl:48 [inlined]
 [16] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Vector{GATConv{Flux.Dense{typeof(identity), Matrix{Float32}, Bool}, Nothing, Float32, Float32, Matrix{Float32}, typeof(elu), Vector{Float32}}}, Vector{Tuple{Symbol, Symbol, Symbol}}}}, GraphNeuralNetworks.var"#60#61"{GraphNeuralNetworks.var"#forw#forw##0"{GNNHeteroGraph{Tuple{T, T, Union{Nothing, AbstractVector}} where T<:(AbstractVector{<:Integer})}, @NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}, C::Matrix{Float32}, D::Matrix{Float32}, E::Matrix{Float32}}}}})
    @ Base ./array.jl:790
 [17] (::HeteroGraphConv)(g::GNNHeteroGraph{Tuple{T, T, Union{Nothing, AbstractVector}} where T<:(AbstractVector{<:Integer})}, x::@NamedTuple{A::Matrix{Float32}, B::Matrix{Float32}, C::Matrix{Float32}, D::Matrix{Float32}, E::Matrix{Float32}})
    @ GraphNeuralNetworks ~/.julia/packages/GraphNeuralNetworks/XGIXF/src/layers/heteroconv.jl:65

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions