GNN in 16 lines
As mentioned in [4], multiple instance learning is an essential building block for implementing message-passing inference over graphs, the main concept behind spatial Graph Neural Networks (GNNs). Achieving this with Mill.jl is straightforward and quick. We begin with some dependencies:
using Mill, Flux, Graphs, Statistics
Let's assume a graph g, represented as a SimpleGraph from Graphs.jl:
julia> g = SimpleGraph(10)
{10, 0} undirected simple Int64 graph
julia> for e in [(1, 2), (1, 3), (1, 4), (2, 4), (2, 5), (3, 4), (3, 5), (3, 6), (3, 8), (3, 10),
                 (4, 5), (4, 6), (4, 9), (5, 7), (5, 8), (6, 5), (6, 7), (6, 8), (7, 8), (7, 10), (8, 9)]
           add_edge!(g, e...)
       end
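As a quick sanity check, all vertices and edges were added (nv and ne come from Graphs.jl):
julia> nv(g), ne(g)
(10, 21)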
Furthermore, let's assume that each vertex is described by three features, stored in a matrix X:
julia> X = ArrayNode(randn(Float32, 3, 10))
3×10 ArrayNode{Matrix{Float32}, Nothing}:
  0.2493349    -1.3772725    1.667343   …  -1.23439     -0.21806878
 -0.060704295  -2.5928798    1.6768072      0.6548603   -1.5255698
  0.93775165   -0.49202558   0.5589707      0.66485745   0.22424777
We use ScatteredBags from Mill.jl to encode the neighbors of each vertex. In other words, each vertex is described by a bag of its neighbors. This information is conveniently stored in the fadjlist field of g, therefore the bags can be constructed as:
julia> b = ScatteredBags(g.fadjlist)
ScatteredBags{Int64}([[2, 3, 4], [1, 4, 5], [1, 4, 5, 6, 8, 10], [1, 2, 3, 5, 6, 9], [2, 3, 4, 6, 7, 8], [3, 4, 5, 7, 8], [5, 6, 8, 10], [3, 5, 6, 7, 9], [4, 8], [3, 7]])
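Each bag thus contains exactly the neighbors of the corresponding vertex, which can be cross-checked against Graphs.jl (a small sanity check, assuming ScatteredBags exposes its bag list via the bags field):
julia> all(b.bags[v] == neighbors(g, v) for v in vertices(g))
true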
Finally, we create two models. The first one, called lift, pre-processes the description of vertices into a latent space for message passing; the second one, called mp, realizes the message passing itself:
julia> lift = reflectinmodel(X, d -> Dense(d, 4), SegmentedMean)
ArrayModel(Dense(3 => 4)) 2 arrays, 16 params, 152 bytes
julia> U = lift(X)
4×10 Matrix{Float32}:
 -0.790085    -0.399773  -0.405256  …   0.924469   0.807205  -0.948249
 -0.00503209  -2.89086    2.48879      -0.950163  -0.72512   -1.29063
 -0.760845     1.05121   -0.888108      1.35215   -0.690386   0.187755
  0.306107    -0.241221  -0.071433     -0.478906   0.995806  -0.164386
julia> mp = reflectinmodel(BagNode(U, b), d -> Dense(d, 3), SegmentedMean)
BagModel ↦ SegmentedMean(3) ↦ Dense(3 => 3)  3 arrays, 15 params, 188 bytes
╰── ArrayModel(Dense(4 => 3))  2 arrays, 15 params, 148 bytes
Notice that BagNode(U, b) now essentially encodes both the vertex features and the adjacency structure of the graph. This also means that one step of the message-passing algorithm can be realized as:
julia> Y = mp(BagNode(U, b))
3×10 Matrix{Float32}:
 0.118788  0.0476876  0.203684  0.0165302  …  -0.064877  0.436567  -0.328518
 0.261469  0.114964   0.490808  0.137215      -0.241936  0.887545  -1.14205
 0.382536  0.170967   0.732393  0.224735      -0.379021  1.27937   -1.77105
This step is differentiable with respect to the model parameters, which can be verified by executing:
julia> gradient(m -> sum(sin.(m(BagNode(U, b)))), mp)
((im = (m = (weight = Float32[1.24085 -0.33822134 1.2871495 -0.46675086; 1.9613303 -1.7043076 2.596578 -0.88312876; -3.2512262 1.2915746 -3.5684185 1.2741774], bias = Float32[4.9351535, 9.197704, -13.402116], σ = nothing),), a = (ψ = Float32[0.0, 0.0, 0.0],), bm = (weight = Float32[0.62676126 0.1574063 -2.2814016; 0.7607007 0.4133077 -2.293382; 0.88013124 0.6802884 -2.2223806], bias = Float32[9.753667, 8.551793, 7.023405], σ = nothing)),)
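Such gradients can directly drive a Flux.jl optimiser. Below is a minimal sketch of a single update step on the toy objective above, assuming Flux's Optimisers-based setup/update! API:
opt_state = Flux.setup(Adam(1e-3), mp)                  # optimiser state for mp
grads = gradient(m -> sum(sin.(m(BagNode(U, b)))), mp)  # same toy objective as above
Flux.update!(opt_state, mp, grads[1])                   # one gradient step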
If we put everything together, the whole GNN is implemented in the following 16 lines:
struct GNN{L, M, R}
lift::L
mp::M
m::R
end
Flux.@layer :ignore GNN
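# recursively apply one step of message passing, n times in total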
function mpstep(m::GNN, U, bags, n)
n == 0 && return U
mpstep(m, m.mp(BagNode(U, bags)), bags, n - 1)
end
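# lift vertex features, run n message-passing steps over neighborhood bags,
# then pool all vertices (mean and max) and apply the final model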
function (m::GNN)(g, X, n)
U = m.lift(X)
bags = Mill.ScatteredBags(g.fadjlist)
o = mpstep(m, U, bags, n)
m.m(vcat(mean(o, dims = 2), maximum(o, dims = 2)))
end
As is the case with the whole of Mill.jl, even this graph neural network is properly integrated with the Flux.jl ecosystem and supports automatic differentiation. Note that for repeated message passing, mp must map its latent space to itself; below, its input and output dimensions both equal zd = 4:
zd = 4
f(d) = Chain(Dense(d, zd, relu), Dense(zd, zd))
agg = SegmentedMeanMax
gnn = GNN(reflectinmodel(X, f, agg),
BagModel(f(zd), agg(zd), f(2zd)),
f(2zd))
julia> gnn(g, X, 5)
4×1 Matrix{Float32}:
  0.00016222768
  0.00021448887
  0.00030102703
 -0.00046899694
julia> gradient(m -> m(g, X, 5) |> sum, gnn)
((lift = (m = (layers = ((weight = Float32[0.023729375 0.0073043876 0.00507586; 0.0034009127 -0.00055816496 -0.001996559; 0.0074122837 -0.012157058 -0.0032464762; 0.019973766 -0.002466879 -0.008808151], bias = Float32[0.031564355, 0.007343511, -0.0030228072, -0.009473307], σ = nothing), (weight = Float32[0.0010020004 -0.004688316 0.025474805 0.040968303; 0.004139893 0.00063540856 0.037446007 0.060782257; -0.011153192 -0.01265391 -0.017427681 -0.028889604; -0.0030410632 -0.0043758284 0.013323379 0.021474218], bias = Float32[0.033730146, 0.064496905, -0.051303748, 0.01573436], σ = nothing)),),), mp = (im = (m = (layers = ((weight = Float32[-0.0343868 0.028065441 -0.022546805 -0.030715173; 0.044193078 -0.04215087 -0.008396544 -0.03416454; -0.017730065 0.028343473 -0.004356968 0.025194738; 0.0025157235 0.029422589 0.0075494736 -0.020511081], bias = Float32[-0.01040332, 0.11693167, 0.027748574, 0.06606675], σ = nothing), (weight = Float32[0.011344679 0.0293103 -0.00657038 0.0055761575; -0.0030732683 -0.02212625 -0.01653417 -0.0044274977; 0.008140642 -0.0019123852 0.016163364 -0.0073407902; 0.013540533 -0.016273316 0.015952319 -0.00794171], bias = Float32[0.02243815, 0.020000346, -0.29838583, -0.09716275], σ = nothing)),),), a = (fs = ((ψ = Float32[0.0, 0.0, 0.0, 0.0],), (ψ = Float32[0.0, 0.0, 0.0, 0.0],)),), bm = (layers = ((weight = Float32[-0.020415353 0.023660988 … -0.0028440235 0.007423374; -0.015537388 0.009754301 … -0.0022147452 -0.004587447; 0.01613906 -0.013099574 … -0.00021842607 -0.02126051; 0.0125365555 -0.0090761045 … 0.0013786387 -0.007008895], bias = Float32[-0.025294017, 0.16506173, -0.08604198, 0.16027108], σ = nothing), (weight = Float32[0.0032370002 0.007290289 -0.0015835441 -0.0907538; 0.010025288 0.024912367 0.037466925 -0.075727746; -0.010327265 -0.02448437 -0.013396049 0.045430623; 0.0025176767 0.0047642025 0.011162643 -0.0017723107], bias = Float32[-0.20535398, -0.11515407, -0.2677527, -0.15884492], σ = nothing)),)), m = (layers = ((weight = Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; -0.012068051 0.011812808 … 0.0032108114 -0.0027346285; 0.0 0.0 … 0.0 0.0], bias = Float32[0.0, -0.0, 0.30443656, -0.0], σ = nothing), (weight = Float32[0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0; 0.0 0.0 0.0006856819 0.0], bias = Fill(1.0f0, 4), σ = nothing)),)),)
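Consequently, training the GNN looks the same as training any other Flux.jl model. A minimal sketch of a training loop with a dummy regression target follows; the target y, the learning rate, and the number of epochs are arbitrary choices for illustration only:
y = randn(Float32, 4, 1)                  # dummy target, for illustration only
opt_state = Flux.setup(Adam(1e-3), gnn)
for epoch in 1:100
    grads = gradient(m -> Flux.mse(m(g, X, 5), y), gnn)
    Flux.update!(opt_state, gnn, grads[1])
end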
The above implementation is surprisingly general, as it supports arbitrarily rich descriptions of vertices. For simplicity, we used only a matrix of features in X; however, any Mill.jl hierarchy is applicable, as sketched below.
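For example, each vertex could be described by both continuous features and a categorical tag combined in a ProductNode. A minimal sketch reusing f, agg, and zd from above (the names X2 and gnn2 and the particular features are hypothetical, chosen for illustration):
X2 = ProductNode((feats = ArrayNode(randn(Float32, 3, 10)),
                  tags = ArrayNode(Flux.onehotbatch(rand(1:4, 10), 1:4))))
gnn2 = GNN(reflectinmodel(X2, f, agg),        # lift now reflects the richer structure
           BagModel(f(zd), agg(zd), f(2zd)),
           f(2zd))
gnn2(g, X2, 5)                                # forward pass works unchanged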
To put different weights on edges, one can use Weighted aggregation.