GNN in 16 lines

As mentioned in [4], multiple instance learning is an essential building block for implementing message-passing inference over graphs, the main concept behind spatial Graph Neural Networks (GNNs). Achieving this with Mill.jl is straightforward and quick. We begin with some dependencies:

using Mill, Flux, Graphs, Statistics

Let's assume a graph g, represented as a SimpleGraph from Graphs.jl:

julia> g = SimpleGraph(10)
{10, 0} undirected simple Int64 graph

julia> for e in [(1, 2), (1, 3), (1, 4), (2, 4), (2, 5), (3, 4), (3, 5), (3, 6),
                 (3, 8), (3, 10), (4, 5), (4, 6), (4, 9), (5, 7), (5, 8), (6, 5),
                 (6, 7), (6, 8), (7, 8), (7, 10), (8, 9)]
           add_edge!(g, e...)
       end
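
As a quick sanity check (added here for illustration), the graph now contains all 10 vertices and 21 edges:

julia> nv(g), ne(g)
(10, 21)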

Furthermore, let's assume that each vertex is described by three features stored in a matrix X:

julia> X = ArrayNode(randn(Float32, 3, 10))
3×10 ArrayNode{Matrix{Float32}, Nothing}:
  0.2493349    -1.3772725   1.667343   …  -1.23439     -0.21806878
 -0.060704295  -2.5928798   1.6768072      0.6548603   -1.5255698
  0.93775165   -0.49202558  0.5589707      0.66485745   0.22424777

We use ScatteredBags from Mill.jl to encode the neighbors of each vertex. In other words, each vertex is described by a bag of its neighbors. This information is conveniently stored in the fadjlist field of g, therefore the bags can be constructed as:

julia> b = ScatteredBags(g.fadjlist)
ScatteredBags{Int64}([[2, 3, 4], [1, 4, 5], [1, 4, 5, 6, 8, 10], [1, 2, 3, 5, 6, 9], [2, 3, 4, 6, 7, 8], [3, 4, 5, 7, 8], [5, 6, 8, 10], [3, 5, 6, 7, 9], [4, 8], [3, 7]])
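
To convince ourselves that the bags really mirror the graph structure, we can compare them against the neighbor lists reported by Graphs.jl (a small check added for illustration; bags is an internal field of ScatteredBags):

julia> all(b.bags[v] == neighbors(g, v) for v in vertices(g))
true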

Finally, we create two models. The first one, called lift, will pre-process the description of vertices into some latent space for message passing, and the second one, called mp, will realize the message passing itself:

julia> lift = reflectinmodel(X, d -> Dense(d, 4), SegmentedMean)
ArrayModel(Dense(3 => 4))  2 arrays, 16 params, 152 bytes
julia> U = lift(X)
4×10 Matrix{Float32}:
 -0.790085    -0.399773  -0.405256  …   0.924469   0.807205  -0.948249
 -0.00503209  -2.89086    2.48879      -0.950163  -0.72512   -1.29063
 -0.760845     1.05121   -0.888108      1.35215   -0.690386   0.187755
  0.306107    -0.241221  -0.071433     -0.478906   0.995806  -0.164386
julia> mp = reflectinmodel(BagNode(U, b), d -> Dense(d, 3), SegmentedMean)
BagModel ↦ SegmentedMean(3) ↦ Dense(3 => 3)  3 arrays, 15 params, 188 bytes
  ╰── ArrayModel(Dense(4 => 3))  2 arrays, 15 params, 148 bytes

Notice that BagNode(U, b) now essentially encodes the vertex features as well as the adjacency matrix. This also means that one step of the message-passing algorithm can be realized simply as:

julia> Y = mp(BagNode(U, b))
3×10 Matrix{Float32}:
 0.118788  0.0476876  0.203684  0.0165302  …  -0.064877  0.436567  -0.328518
 0.261469  0.114964   0.490808  0.137215      -0.241936  0.887545  -1.14205
 0.382536  0.170967   0.732393  0.224735      -0.379021  1.27937   -1.77105
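
To see what this one step computes, we can unfold it for a single vertex. The following sketch peeks into the BagModel internals (its im, a, and bm fields) and assumes that SegmentedMean reduces to a plain column mean on non-empty bags; it is meant purely as an illustration:

Z = mp.im.m(U)                             # per-vertex transformation, Dense(4 => 3)
msg = mean(Z[:, neighbors(g, 1)], dims=2)  # aggregate messages from the neighbors of vertex 1
Y[:, 1] ≈ vec(mp.bm(msg))                  # holds up to floating-point error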

Moreover, the step is differentiable, which can be verified by executing:

julia> gradient(m -> sum(sin.(m(BagNode(U, b)))), mp)
((im = (m = (weight = Float32[1.24085 -0.33822134 1.2871495 -0.46675086; 1.9613303 -1.7043076 2.596578 -0.88312876; -3.2512262 1.2915746 -3.5684185 1.2741774], bias = Float32[4.9351535, 9.197704, -13.402116], σ = nothing),), a = (ψ = Float32[0.0, 0.0, 0.0],), bm = (weight = Float32[0.62676126 0.1574063 -2.2814016; 0.7607007 0.4133077 -2.293382; 0.88013124 0.6802884 -2.2223806], bias = Float32[9.753667, 8.551793, 7.023405], σ = nothing)),)
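
Such gradients can be passed directly to a Flux.jl optimiser. A minimal sketch of one optimisation step (the sin-based loss is just the toy objective from above, not anything meaningful):

opt_state = Flux.setup(Adam(), mp)                      # initialize the optimiser state
grads = gradient(m -> sum(sin.(m(BagNode(U, b)))), mp)  # same toy loss as above
Flux.update!(opt_state, mp, grads[1])                   # one step of Adam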

If we put everything together, the whole GNN fits into the following 16 lines:

struct GNN{L, M, R}
    lift::L  # pre-processes vertex descriptions into the latent space
    mp::M    # realizes one step of message passing
    m::R     # maps the graph-level readout to the output
end

Flux.@layer :ignore GNN

function mpstep(m::GNN, U, bags, n)
    n == 0 && return U  # all steps done
    mpstep(m, m.mp(BagNode(U, bags)), bags, n - 1)
end

function (m::GNN)(g, X, n)
    U = m.lift(X)
    bags = Mill.ScatteredBags(g.fadjlist)
    o = mpstep(m, U, bags, n)
    m.m(vcat(mean(o, dims = 2), maximum(o, dims = 2)))
end
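
The last line performs a graph-level readout: after n message-passing steps, the vertex states in o are pooled over vertices with both mean and maximum, concatenated, and passed through the final model m, producing a single output per graph.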

As is the case with the whole of Mill.jl, even this graph neural network is properly integrated with the Flux.jl ecosystem and supports automatic differentiation:

zd = 4                                           # dimension of the latent space
f(d) = Chain(Dense(d, zd, relu), Dense(zd, zd))  # feed-forward building block
agg = SegmentedMeanMax                           # mean & max aggregation (doubles the dimension)
gnn = GNN(reflectinmodel(X, f, agg),
          BagModel(f(zd), agg(zd), f(2zd)),
          f(2zd))
julia> gnn(g, X, 5)
4×1 Matrix{Float32}:
  0.00016222768
  0.00021448887
  0.00030102703
 -0.00046899694
julia> gradient(m -> m(g, X, 5) |> sum, gnn)
((lift = (m = (layers = ((weight = Float32[0.023729375 0.0073043876 0.00507586; 0.0034009127 -0.00055816496 -0.001996559; 0.0074122837 -0.012157058 -0.0032464762; 0.019973766 -0.002466879 -0.008808151], bias = Float32[0.031564355, 0.007343511, -0.0030228072, -0.009473307], σ = nothing), (weight = Float32[0.0010020004 -0.004688316 0.025474805 0.040968303; 0.004139893 0.00063540856 0.037446007 0.060782257; -0.011153192 -0.01265391 -0.017427681 -0.028889604; -0.0030410632 -0.0043758284 0.013323379 0.021474218], bias = Float32[0.033730146, 0.064496905, -0.051303748, 0.01573436], σ = nothing)),),), mp = (im = (m = (layers = ((weight = Float32[-0.0343868 0.028065441 -0.022546805 -0.030715173; 0.044193078 -0.04215087 -0.008396544 -0.03416454; -0.017730065 0.028343473 -0.004356968 0.025194738; 0.0025157235 0.029422589 0.0075494736 -0.020511081], bias = Float32[-0.01040332, 0.11693167, 0.027748574, 0.06606675], σ = nothing), (weight = Float32[0.011344679 0.0293103 -0.00657038 0.0055761575; -0.0030732683 -0.02212625 -0.01653417 -0.0044274977; 0.008140642 -0.0019123852 0.016163364 -0.0073407902; 0.013540533 -0.016273316 0.015952319 -0.00794171], bias = Float32[0.02243815, 0.020000346, -0.29838583, -0.09716275], σ = nothing)),),), a = (fs = ((ψ = Float32[0.0, 0.0, 0.0, 0.0],), (ψ = Float32[0.0, 0.0, 0.0, 0.0],)),), bm = (layers = ((weight = Float32[-0.020415353 0.023660988 … -0.0028440235 0.007423374; -0.015537388 0.009754301 … -0.0022147452 -0.004587447; 0.01613906 -0.013099574 … -0.00021842607 -0.02126051; 0.0125365555 -0.0090761045 … 0.0013786387 -0.007008895], bias = Float32[-0.025294017, 0.16506173, -0.08604198, 0.16027108], σ = nothing), (weight = Float32[0.0032370002 0.007290289 -0.0015835441 -0.0907538; 0.010025288 0.024912367 0.037466925 -0.075727746; -0.010327265 -0.02448437 -0.013396049 0.045430623; 0.0025176767 0.0047642025 0.011162643 -0.0017723107], bias = Float32[-0.20535398, -0.11515407, -0.2677527, -0.15884492], σ = nothing)),)), m = (layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; -0.012068051 0.011812808 … 0.0032108114 -0.0027346285; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, -0.0, 0.30443656, -0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0], bias = Fill(1.0f0, 4), σ = nothing)),)),)
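
For completeness, here is a hypothetical end-to-end training sketch which regresses the graph embedding onto a made-up target y (both the target and the loss are invented purely for illustration):

y = Float32[1, 0, 0, 0]                  # made-up 4-dimensional target
opt_state = Flux.setup(Adam(1f-3), gnn)
for _ in 1:100
    grads = gradient(m -> Flux.mse(vec(m(g, X, 5)), y), gnn)
    Flux.update!(opt_state, gnn, grads[1])
end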

The above implementation is surprisingly general, as it supports an arbitrarily rich description of vertices. For simplicity, we used only vectors in X; however, any Mill.jl hierarchy is applicable.
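
For instance, each vertex could carry a feature vector together with a bag of discrete observations. The following hypothetical sketch (reusing f, agg, and zd from above; the node names, the number of observations per vertex, and their alphabet of 7 symbols are all made up) builds such a hierarchy and feeds it through the same GNN:

X2 = ProductNode((features = ArrayNode(randn(Float32, 3, 10)),
                  observations = BagNode(ArrayNode(Flux.onehotbatch(rand(1:7, 30), 1:7)),
                                         Mill.length2bags(fill(3, 10)))))
gnn2 = GNN(reflectinmodel(X2, f, agg),
           BagModel(f(zd), agg(zd), f(2zd)),
           f(2zd))
gnn2(g, X2, 5)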

To put different weights on edges, one can use weighted aggregation, as sketched below.
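
Mill.jl provides WeightedBagNode for this purpose; its per-instance weights are consumed by the aggregation. A rough sketch under that assumption, with one random (made-up) weight per directed edge:

w = rand(Float32, sum(length, g.fadjlist))               # one weight per directed edge
wb = WeightedBagNode(ArrayNode(U), b, w)
wmp = reflectinmodel(wb, d -> Dense(d, 3), SegmentedMean)
wmp(wb)                                                  # weighted mean inside each bag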