# Model reflection

Since the construction of large models can be a tedious and error-prone process, Mill.jl provides the `reflectinmodel` function to help automate it. In its simplest form, it accepts a single argument, a sample `ds`, and returns a compatible model:
```julia-repl
julia> ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10), [1:2, 3:4, 5:5, 6:7, 8:10]), randn(Float32, 3, 5), BagNode(BagNode(randn(Float32, 2, 30), [i:i+1 for i in 1:2:30]), [1:3, 4:6, 7:9, 10:12, 13:15]), randn(Float32, 2, 5))), [1:1, 2:3, 4:5]);

julia> printtree(ds)
BagNode  3 obs, 160 bytes
╰── ProductNode  5 obs, 56 bytes
      ├── BagNode  5 obs, 144 bytes
      │     ╰── ArrayNode(4×10 Array with Float32 elements)  10 obs, 208 bytes
      ├── ArrayNode(3×5 Array with Float32 elements)  5 obs, 108 bytes
      ├── BagNode  5 obs, 152 bytes
      │     ╰── BagNode  15 obs, 304 bytes
      │           ╰── ArrayNode(2×30 Array with Float32 elements)  30 obs, 288 bytes
      ╰── ArrayNode(2×5 Array with Float32 elements)  5 obs, 88 bytes

julia> m = reflectinmodel(ds, d -> Dense(d, 2));

julia> printtree(m)
BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 224 bytes
╰── ProductModel ↦ Dense(8 => 2)  2 arrays, 18 params, 152 bytes
      ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 224 bytes
      │     ╰── ArrayModel(Dense(4 => 2))  2 arrays, 10 params, 120 bytes
      ├── ArrayModel(Dense(3 => 2))  2 arrays, 8 params, 112 bytes
      ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 224 bytes
      │     ╰── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 224 bytes
      │           ╰── ArrayModel(Dense(2 => 2))  2 arrays, 6 params, 104 bytes
      ╰── ArrayModel(Dense(2 => 2))  2 arrays, 6 params, 104 bytes

julia> m(ds)
2×3 Matrix{Float32}:
  0.0231536   0.383077   -1.84346
 -1.30441    -0.0164543  -2.67783
```
The sample `ds` serves here as a specimen needed to specify the structure of the problem and to calculate dimensions.
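To see this dimension inference in isolation, here is a minimal sketch on a fresh, hypothetical sample (names `x` and `m` are illustrative only): the input dimension of the reflected `Dense` layer is read off the specimen.

```julia
using Flux, Mill

# a hypothetical specimen with 3-dimensional features and 4 observations
x = ArrayNode(randn(Float32, 3, 4))

# reflect a model; the inner Dense layer receives input dimension 3,
# inferred from the sample, and output dimension 2, given by our function
m = reflectinmodel(x, d -> Dense(d, 2))

size(m.m.weight)  # (2, 3)
```

The same inference is applied recursively at every node of a nested sample, which is what makes reflection convenient for large hierarchies.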
## Optional arguments

To gain better control over the topology, `reflectinmodel` accepts up to two more optional arguments and four keyword arguments:

- The first optional argument expects a function that returns a layer (or a set of layers) given input dimension `d` (defaults to `d -> Flux.Dense(d, 10)`).
- The second optional argument is a function returning an aggregation operator for `BagModel` nodes (defaults to `BagCount ∘ SegmentedMeanMax`).
Compare the following example to the previous one:
```julia-repl
julia> m = reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMax);

julia> printtree(m)
BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 260 bytes
╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 500 bytes
      ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 260 bytes
      │     ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 180 bytes
      ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 160 bytes
      ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 260 bytes
      │     ╰── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 260 bytes
      │           ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 140 bytes
      ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 140 bytes

julia> m(ds)
5×3 Matrix{Float32}:
 0.0558981  0.258097  0.0
 0.220776   0.321497  0.48561
 0.0        0.0       0.412755
 0.569792   1.77116   0.0457244
 0.0        0.0       0.0
```
## Keyword arguments

The `reflectinmodel` function allows even further customization. To index into the sample (or the model), we can use `printtree(ds; trav=true)` from HierarchicalUtils.jl, which prints the sample together with identifiers of individual nodes:
```julia-repl
julia> using HierarchicalUtils

julia> printtree(ds; trav=true)
BagNode [""]  3 obs, 160 bytes
╰── ProductNode ["U"]  5 obs, 56 bytes
      ├── BagNode ["Y"]  5 obs, 144 bytes
      │     ╰── ArrayNode(4×10 Array with Float32 elements) ["a"]  10 obs, 208 bytes
      ├── ArrayNode(3×5 Array with Float32 elements) ["c"]  5 obs, 108 bytes
      ├── BagNode ["g"]  5 obs, 152 bytes
      │     ╰── BagNode ["i"]  15 obs, 304 bytes
      │           ╰── ArrayNode(2×30 Array with Float32 elements) ["j"]  30 obs, 288 bytes
      ╰── ArrayNode(2×5 Array with Float32 elements) ["k"]  5 obs, 88 bytes
```
These identifiers can be used to override the default construction functions. Note that the output, i.e. the last feed-forward network of the whole model, is always tagged with the empty string `""`, which simplifies appending a linear layer with an appropriate output dimension at the end. Dictionaries with these overrides can be passed in as keyword arguments:

- `fsm` overrides the construction of feed-forward models
- `fsa` overrides the construction of aggregation operators

For example, to specify only the last feed-forward network:
```julia-repl
julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax; fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12)))) |> printtree
BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  6 arrays, 482 params, 2.117 KiB
╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 500 bytes
      ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 420 bytes
      │     ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 180 bytes
      ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 160 bytes
      ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 420 bytes
      │     ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 420 bytes
      │           ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 140 bytes
      ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 140 bytes
```
Both keyword arguments in action:
```julia-repl
julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax; fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12))), fsa = Dict("Y" => SegmentedMean, "g" => SegmentedPNorm)) |> printtree
BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  6 arrays, 482 params, 2.117 KiB
╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 500 bytes
      ├── BagModel ↦ SegmentedMean(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 260 bytes
      │     ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 180 bytes
      ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 160 bytes
      ├── BagModel ↦ SegmentedPNorm(5) ↦ Dense(5 => 5, relu)  5 arrays, 45 params, 380 bytes
      │     ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 420 bytes
      │           ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 140 bytes
      ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 140 bytes
```
There are even more ways to modify the reflection behavior; see the `reflectinmodel` API reference.
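As one hedged illustration of such further customization (assuming the `single_key_identity` keyword argument is available in your version of Mill.jl; consult the `reflectinmodel` docstring to confirm), the treatment of `ProductNode`s with a single key can be toggled:

```julia
using Flux, Mill

# a hypothetical single-key product sample
ds = ProductNode((; a = ArrayNode(randn(Float32, 2, 3))))

# assumption: `single_key_identity` controls whether single-key ProductNodes
# are reflected with an identity mapping (default) or a full feed-forward
# layer -- this keyword is taken from the reflectinmodel docstring
m = reflectinmodel(ds; single_key_identity=false)
```

Comparing `printtree` outputs of the two variants shows where the extra layer is inserted.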
## Float precision
Mill.jl is built on top of Flux.jl, which by default uses 32-bit precision for model parameters:
```julia-repl
julia> Dense(2, 2).weight |> eltype
Float32
```
If you attempt to process `Float64` data with a model using lower precision, you get a warning:
```julia-repl
julia> x = randn(2, 2)
2×2 Matrix{Float64}:
 -1.13853  -0.721216
 -2.21835   0.0331891

julia> eltype(x)
Float64

julia> m = Dense(2, 2)
Dense(2 => 2)       # 6 parameters

julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 2)  # 6 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
2×2 Matrix{Float32}:
 -2.01866  -0.022319
  1.11091  -0.0697713
```
Unless additional arguments are provided, `reflectinmodel` also instantiates all `Dense` layers using 32-bit precision:
```julia-repl
julia> x = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
  0.7041741  -0.45216417
 -0.3617286   1.4807765

julia> eltype(Mill.data(x))
Float32

julia> m = reflectinmodel(x)
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 200 bytes

julia> m.m.weight |> eltype
Float32
```
Because `reflectinmodel` evaluates (sub)models on parts of the input while building the model, if `Float64` values are passed in, the warning appears during construction as well as during evaluation:
```julia-repl
julia> x = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 -0.2617165161930634   -1.1169820345030674
  0.03973931389554643  -1.9057709581450335

julia> eltype(Mill.data(x))
Float64

julia> m = reflectinmodel(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 200 bytes

julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/HBF2N/src/layers/stateless.jl:60
10×2 Matrix{Float32}:
  0.178307   -0.207178
 -0.174379   -0.190231
  0.126416   -0.390108
  0.0908661  -0.823867
 -0.0826094  -0.324438
  0.0130696   0.325754
 -0.071541   -0.666409
  0.0525041   0.328029
  0.0982923  -0.174111
  0.178692    0.530059
```
To prevent this from happening, we recommend using the same precision for the input data and for the parameters of the model returned by `reflectinmodel`. For example:
```julia-repl
julia> x32 = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
 1.1421611  0.8822706
 2.514736   0.036934458

julia> m = reflectinmodel(x32)
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 200 bytes

julia> x64 = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 0.3560285043427309  0.9366766658560893
 0.6108965887101426  1.560294028451233

julia> m = reflectinmodel(x64, d -> f64(Dense(d, 5)), d -> f64(SegmentedMean(d)))
ArrayModel(Dense(2 => 5))  2 arrays, 15 params, 200 bytes
```
Functions `Flux.f64` and `Flux.f32` may come in handy.
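As a sketch of how these helpers work, `Flux.f64` converts all parameters of a layer (or of a whole model) to 64-bit precision, after which `Float64` inputs no longer trigger the conversion warning (the names `m32` and `m64` below are illustrative):

```julia
using Flux

m32 = Dense(2, 2)   # Float32 parameters by default
m64 = f64(m32)      # convert every parameter array to Float64

eltype(m64.weight)  # Float64

# input precision now matches parameter precision, so no warning is emitted
y = m64(randn(2, 2))
eltype(y)           # Float64
```

Since Mill.jl models are composed of Flux layers, the same conversion can be applied to an entire model produced by `reflectinmodel`, and `Flux.f32` performs the opposite conversion.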