GNN in 16 lines
As mentioned in [4], multiple instance learning is an essential building block for implementing message passing over graphs, the main concept behind spatial Graph Neural Networks (GNNs). Achieving this with Mill.jl is straightforward and quick. We begin with some dependencies:
using Mill, Flux, Graphs, Statistics

Let's assume a graph g, represented as a SimpleGraph from Graphs.jl:
julia> g = SimpleGraph(10)
{10, 0} undirected simple Int64 graph

julia> for e in [(1, 2), (1, 3), (1, 4), (2, 4), (2, 5), (3, 4), (3, 5), (3, 6),
                 (3, 8), (3, 10), (4, 5), (4, 6), (4, 9), (5, 7), (5, 8), (6, 5),
                 (6, 7), (6, 8), (7, 8), (7, 10), (8, 9)]
           add_edge!(g, e...)
       end
Furthermore, let's assume that each vertex is described by three features stored in a matrix X:
julia> X = ArrayNode(randn(Float32, 3, 10))
3×10 ArrayNode{Matrix{Float32}, Nothing}:
  0.2493349    -1.3772725   1.667343   …  -1.23439     -0.21806878
 -0.060704295  -2.5928798   1.6768072      0.6548603   -1.5255698
  0.93775165   -0.49202558  0.5589707      0.66485745   0.22424777
We use ScatteredBags from Mill.jl to encode the neighbors of each vertex. In other words, each vertex is described by a bag of its neighbors. This information is conveniently stored in the fadjlist field of g, so the bags can be constructed as:
julia> b = ScatteredBags(g.fadjlist)
ScatteredBags{Int64}([[2, 3, 4], [1, 4, 5], [1, 4, 5, 6, 8, 10], [1, 2, 3, 5, 6, 9], [2, 3, 4, 6, 7, 8], [3, 4, 5, 7, 8], [5, 6, 8, 10], [3, 5, 6, 7, 9], [4, 8], [3, 7]])
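To sanity-check the construction, every bag should coincide with the corresponding neighbor list. A quick check using the neighbors accessor from Graphs.jl, reading the bags field of ScatteredBags directly:

all(b.bags[i] == neighbors(g, i) for i in vertices(g))   # true by construction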
Finally, we create two models. The first, called lift, pre-processes the vertex descriptions into a latent space suitable for message passing; the second, called mp, realizes the message passing itself:
julia> lift = reflectinmodel(X, d -> Dense(d, 4), SegmentedMean)
ArrayModel(Dense(3 => 4))  2 arrays, 16 params, 152 bytes

julia> U = lift(X)
4×10 Matrix{Float32}:
 -0.790085    -0.399773  -0.405256  …   0.924469   0.807205  -0.948249
 -0.00503209  -2.89086    2.48879      -0.950163  -0.72512   -1.29063
 -0.760845     1.05121   -0.888108      1.35215   -0.690386   0.187755
  0.306107    -0.241221  -0.071433     -0.478906   0.995806  -0.164386

julia> mp = reflectinmodel(BagNode(U, b), d -> Dense(d, 3), SegmentedMean)
BagModel ↦ SegmentedMean(3) ↦ Dense(3 => 3)  3 arrays, 15 params, 188 bytes
  ╰── ArrayModel(Dense(4 => 3))  2 arrays, 15 params, 148 bytes
Notice that BagNode(U, b) now essentially encodes both the vertex features and the adjacency matrix. This also means that one step of the message passing algorithm can be realized as:
julia> Y = mp(BagNode(U, b))
3×10 Matrix{Float32}:
 0.118788  0.0476876  0.203684  0.0165302  …  -0.064877  0.436567  -0.328518
 0.261469  0.114964   0.490808  0.137215      -0.241936  0.887545  -1.14205
 0.382536  0.170967   0.732393  0.224735      -0.379021  1.27937   -1.77105
and it is differentiable, which can be verified by executing:
julia> gradient(m -> sum(sin.(m(BagNode(U, b)))), mp)
((im = (m = (weight = Float32[1.24085 -0.33822134 1.2871495 -0.46675086; 1.9613303 -1.7043076 2.596578 -0.88312876; -3.2512262 1.2915746 -3.5684185 1.2741774], bias = Float32[4.9351535, 9.197704, -13.402116], σ = nothing),), a = (ψ = Float32[0.0, 0.0, 0.0],), bm = (weight = Float32[0.62676126 0.1574063 -2.2814016; 0.7607007 0.4133077 -2.293382; 0.88013124 0.6802884 -2.2223806], bias = Float32[9.753667, 8.551793, 7.023405], σ = nothing)),)
If we put everything together, the whole GNN is implemented in the following 16 lines:
struct GNN{L, M, R}
    lift::L   # pre-processes vertex features into the latent space
    mp::M     # realizes one step of message passing
    m::R      # maps the aggregated graph representation to the output
end
Flux.@layer :ignore GNN
function mpstep(m::GNN, U, bags, n)
    n == 0 && return U                               # all n steps performed
    mpstep(m, m.mp(BagNode(U, bags)), bags, n - 1)   # recurse for one more step
end
function (m::GNN)(g, X, n)
    U = m.lift(X)                           # lift vertex features to the latent space
    bags = Mill.ScatteredBags(g.fadjlist)   # one bag of neighbors per vertex
    o = mpstep(m, U, bags, n)               # n steps of message passing
    m.m(vcat(mean(o, dims = 2), maximum(o, dims = 2)))   # global mean/max readout
end

As is the case with the rest of Mill.jl, even this graph neural network is properly integrated with the Flux.jl ecosystem and supports automatic differentiation. Note that repeated message passing requires mp to map the latent space back to itself, which the construction below guarantees by using the same dimension zd on both sides:
zd = 4
f(d) = Chain(Dense(d, zd, relu), Dense(zd, zd))
agg = SegmentedMeanMax
gnn = GNN(reflectinmodel(X, f, agg),
          BagModel(f(zd), agg(zd), f(2zd)),
          f(2zd))

julia> gnn(g, X, 5)
4×1 Matrix{Float32}:
  0.00016222768
  0.00021448887
  0.00030102703
 -0.00046899694

julia> gradient(m -> m(g, X, 5) |> sum, gnn)
((lift = (m = (layers = ((weight = Float32[0.023729375 0.0073043876 0.00507586; 0.0034009127 -0.00055816496 -0.001996559; 0.0074122837 -0.012157058 -0.0032464762; 0.019973766 -0.002466879 -0.008808151], bias = Float32[0.031564355, 0.007343511, -0.0030228072, -0.009473307], σ = nothing), (weight = Float32[0.0010020004 -0.004688316 0.025474805 0.040968303; 0.004139893 0.00063540856 0.037446007 0.060782257; -0.011153192 -0.01265391 -0.017427681 -0.028889604; -0.0030410632 -0.0043758284 0.013323379 0.021474218], bias = Float32[0.033730146, 0.064496905, -0.051303748, 0.01573436], σ = nothing)),),), mp = (im = (m = (layers = ((weight = Float32[-0.0343868 0.028065441 -0.022546805 -0.030715173; 0.044193078 -0.04215087 -0.008396544 -0.03416454; -0.017730065 0.028343473 -0.004356968 0.025194738; 0.0025157235 0.029422589 0.0075494736 -0.020511081], bias = Float32[-0.01040332, 0.11693167, 0.027748574, 0.06606675], σ = nothing), (weight = Float32[0.011344679 0.0293103 -0.00657038 0.0055761575; -0.0030732683 -0.02212625 -0.01653417 -0.0044274977; 0.008140642 -0.0019123852 0.016163364 -0.0073407902; 0.013540533 -0.016273316 0.015952319 -0.00794171], bias = Float32[0.02243815, 0.020000346, -0.29838583, -0.09716275], σ = nothing)),),), a = (fs = ((ψ = Float32[0.0, 0.0, 0.0, 0.0],), (ψ = Float32[0.0, 0.0, 0.0, 0.0],)),), bm = (layers = ((weight = Float32[-0.020415353 0.023660988 … -0.0028440235 0.007423374; -0.015537388 0.009754301 … -0.0022147452 -0.004587447; 0.01613906 -0.013099574 … -0.00021842607 -0.02126051; 0.0125365555 -0.0090761045 … 0.0013786387 -0.007008895], bias = Float32[-0.025294017, 0.16506173, -0.08604198, 0.16027108], σ = nothing), (weight = Float32[0.0032370002 0.007290289 -0.0015835441 -0.0907538; 0.010025288 0.024912367 0.037466925 -0.075727746; -0.010327265 -0.02448437 -0.013396049 0.045430623; 0.0025176767 0.0047642025 0.011162643 -0.0017723107], bias = Float32[-0.20535398, -0.11515407, -0.2677527, -0.15884492], σ = nothing)),)), m = (layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; -0.012068051 0.011812808 … 0.0032108114 -0.0027346285; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, 0.0, 0.30443656, 0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0], bias = Float32[1.0, 1.0, 1.0, 1.0], σ = nothing)),)),)
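Consequently, gnn can be trained like any other Flux.jl model. A minimal sketch of a single optimisation step, assuming a recent Flux with the explicit-style optimiser interface (Flux.setup and Flux.update!); the squared-norm loss and the Adam step size are placeholders, not part of the example above:

# hypothetical loss: squared norm of the graph-level output
loss(m) = sum(abs2, m(g, X, 5))

opt_state = Flux.setup(Adam(1f-3), gnn)   # optimiser state for all trainable parameters
grads = gradient(loss, gnn)               # same gradient call as above
Flux.update!(opt_state, gnn, grads[1])    # one Adam step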
The above implementation is surprisingly general, as it supports arbitrarily rich descriptions of vertices. For simplicity, we used only vectors in X; however, any Mill.jl hierarchy is applicable.
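For illustration, a sketch of what richer vertices might look like; the particular ProductNode below, pairing a feature vector with a bag of three 2-dimensional observations per vertex, is made up for this example:

# hypothetical richer vertex description: a feature vector together with
# a bag of 2-dimensional observations (three per vertex, for illustration)
Xrich = ProductNode((features = ArrayNode(randn(Float32, 3, 10)),
                     observations = BagNode(ArrayNode(randn(Float32, 2, 30)),
                                            AlignedBags([3i - 2:3i for i in 1:10]))))

# reflectinmodel builds a matching lift for any hierarchy; the rest is unchanged
gnn2 = GNN(reflectinmodel(Xrich, f, agg),
           BagModel(f(zd), agg(zd), f(2zd)),
           f(2zd))
gnn2(g, Xrich, 5)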
To put different weights on edges, one can use Weighted aggregation.
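A minimal sketch of the idea: flatten the neighbor lists so that every (vertex, neighbor) pair becomes its own instance carrying its own weight, and wrap the result in a WeightedBagNode. The random weights below are placeholders, and Mill.length2bags is assumed to convert a vector of bag lengths into AlignedBags:

# one instance per (vertex, neighbor) pair, so each edge can carry a weight
nbrs = reduce(vcat, g.fadjlist)
bags = Mill.length2bags(length.(g.fadjlist))   # bag i = neighbors of vertex i
w = rand(Float32, length(nbrs))                # hypothetical per-edge weights

wb = WeightedBagNode(ArrayNode(U[:, nbrs]), bags, w)
wmp = reflectinmodel(wb, d -> Dense(d, 3), SegmentedMean)
wmp(wb)                                        # weighted mean over neighbors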