Model reflection
Since the construction of large models can be a tedious and error-prone process, Mill.jl provides the reflectinmodel function to help automate it. The simplest definition accepts only one argument, a sample ds, and returns a compatible model:
julia> ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10), [1:2, 3:4, 5:5, 6:7, 8:10]), randn(Float32, 3, 5), BagNode(BagNode(randn(Float32, 2, 30), [i:i+1 for i in 1:2:30]), [1:3, 4:6, 7:9, 10:12, 13:15]), randn(Float32, 2, 5))), [1:1, 2:3, 4:5]);
julia> printtree(ds)
BagNode  3 obs
╰── ProductNode  5 obs
    ├── BagNode  5 obs
    │   ╰── ArrayNode(4×10 Array with Float32 elements)  10 obs
    ├── ArrayNode(3×5 Array with Float32 elements)  5 obs
    ├── BagNode  5 obs
    │   ╰── BagNode  15 obs
    │       ╰── ArrayNode(2×30 Array with Float32 elements)  30 obs
    ╰── ArrayNode(2×5 Array with Float32 elements)  5 obs
julia> m = reflectinmodel(ds, d -> Dense(d, 2));
julia> printtree(m)
BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
╰── ProductModel ↦ Dense(8 => 2)  2 arrays, 18 params, 160 bytes
    ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
    │   ╰── ArrayModel(Dense(4 => 2))  2 arrays, 10 params, 128 bytes
    ├── ArrayModel(Dense(3 => 2))  2 arrays, 8 params, 120 bytes
    ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
    │   ╰── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
    │       ╰── ArrayModel(Dense(2 => 2))  2 arrays, 6 params, 112 bytes
    ╰── ArrayModel(Dense(2 => 2))  2 arrays, 6 params, 112 bytes
julia> m(ds)
2×3 Matrix{Float32}:
  0.0231536   0.383077   -1.84346
 -1.30441    -0.0164543  -2.67783
The sample ds serves here as a specimen needed to specify the structure of the problem and to calculate dimensions.
Optional arguments
To gain better control over the topology, reflectinmodel accepts up to two more optional arguments and four keyword arguments:
- The first optional argument expects a function that returns a layer (or a set of layers) given input dimension d (defaults to d -> Flux.Dense(d, 10)).
- The second optional argument is a function returning an aggregation function for BagModel nodes (defaults to BagCount ∘ SegmentedMeanMax).
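For reference, the simplest one-argument call shown earlier should behave the same as spelling out both defaults explicitly. A minimal sketch on a small hypothetical sample:

```julia
using Mill, Flux

# a small bag sample: 6 instances with 4 features, grouped into 3 bags
ds = BagNode(ArrayNode(randn(Float32, 4, 6)), [1:2, 3:4, 5:6])

# reflectinmodel(ds) with the documented defaults written out:
# Dense layers of output dimension 10 and mean/max aggregation
# extended with bag counts
m = reflectinmodel(ds, d -> Dense(d, 10), BagCount ∘ SegmentedMeanMax)

m(ds)  # one 10-dimensional embedding per bag, i.e. a 10×3 matrix
```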
Compare the following example to the previous one:
julia> m = reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMax);
julia> printtree(m)
BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 508 bytes
    ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
    │   ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 188 bytes
    ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 168 bytes
    ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
    │   ╰── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
    │       ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
    ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
julia> m(ds)
5×3 Matrix{Float32}:
 0.0558981  0.258097  0.0
 0.220776   0.321497  0.48561
 0.0        0.0       0.412755
 0.569792   1.77116   0.0457244
 0.0        0.0       0.0
Keyword arguments
The reflectinmodel function allows even further customization. To index into the sample (or model), we can use printtree(ds; trav=true) from HierarchicalUtils.jl, which prints the sample together with identifiers of individual nodes:
using HierarchicalUtils
julia> printtree(ds; trav=true)
BagNode [""]  3 obs
╰── ProductNode ["U"]  5 obs
    ├── BagNode ["Y"]  5 obs
    │   ╰── ArrayNode(4×10 Array with Float32 elements) ["a"]  10 obs
    ├── ArrayNode(3×5 Array with Float32 elements) ["c"]  5 obs
    ├── BagNode ["g"]  5 obs
    │   ╰── BagNode ["i"]  15 obs
    │       ╰── ArrayNode(2×30 Array with Float32 elements) ["j"]  30 obs
    ╰── ArrayNode(2×5 Array with Float32 elements) ["k"]  5 obs
These identifiers can be used to override the default construction functions. Note that the output, i.e. the last feed-forward network of the whole model, is always tagged with an empty string "", which simplifies putting a linear layer with an appropriate output dimension at the end. Dictionaries with these overrides can be passed in as keyword arguments:
- fsm overrides construction of feed-forward models
- fsa overrides construction of aggregation functions
For example, to specify just the last feed-forward neural network:
julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax; fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12)))) |> printtree
BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  6 arrays, 482 params, 2.133 KiB
╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 508 bytes
    ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
    │   ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 188 bytes
    ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 168 bytes
    ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
    │   ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
    │       ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
    ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
Both keyword arguments in action:
julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax; fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12))), fsa = Dict("Y" => SegmentedMean, "g" => SegmentedPNorm)) |> printtree
BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  6 arrays, 482 params, 2.133 KiB
╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 508 bytes
    ├── BagModel ↦ SegmentedMean(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
    │   ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 188 bytes
    ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 168 bytes
    ├── BagModel ↦ SegmentedPNorm(5) ↦ Dense(5 => 5, relu)  5 arrays, 45 params, 388 bytes
    │   ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
    │       ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
    ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
There are even more ways to modify the reflection behavior; see the reflectinmodel API reference.
Float precision
Mill.jl is built on top of Flux.jl, which by default uses 32-bit precision for model parameters:
julia> Dense(2, 2).weight |> eltype
Float32
If you attempt to process Float64 data with a model that uses lower precision, you get a warning:
julia> x = randn(2, 2)
2×2 Matrix{Float64}:
 -1.13853  -0.721216
 -2.21835   0.0331891
julia> eltype(x)
Float64
julia> m = Dense(2, 2)
Dense(2 => 2) # 6 parameters
julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 2)  # 6 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
2×2 Matrix{Float32}:
 -2.01866  -0.022319
  1.11091  -0.0697713
Unless additional arguments are provided, reflectinmodel also instantiates all Dense layers using 32-bit precision:
julia> x = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
  0.7041741  -0.45216417
 -0.3617286   1.4807765
julia> eltype(Mill.data(x))
Float32
julia> m = reflectinmodel(x)
ArrayModel(Dense(2 => 10)) 2 arrays, 30 params, 208 bytes
julia> m.m.weight |> eltype
Float32
Because reflectinmodel evaluates (sub)models on parts of the input while building the model, if some Float64 values are passed in, the warning is shown during construction as well as during evaluation:
julia> x = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 -0.2617165161930634  -1.1169820345030674
  0.03973931389554643 -1.9057709581450335
julia> eltype(Mill.data(x))
Float64
julia> m = reflectinmodel(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
10×2 Matrix{Float32}:
  0.178307   -0.207178
 -0.174379   -0.190231
  0.126416   -0.390108
  0.0908661  -0.823867
 -0.0826094  -0.324438
  0.0130696   0.325754
 -0.071541   -0.666409
  0.0525041   0.328029
  0.0982923  -0.174111
  0.178692    0.530059
To prevent this from happening, we recommend making sure that the same precision is used for the input data and for the reflectinmodel parameters. For example:
julia> x32 = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
 1.1421611  0.8822706
 2.514736   0.036934458
julia> m = reflectinmodel(x32)
ArrayModel(Dense(2 => 10)) 2 arrays, 30 params, 208 bytes
julia> x64 = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 0.3560285043427309  0.9366766658560893
 0.6108965887101426  1.560294028451233
julia> m = reflectinmodel(x64, d -> f64(Dense(d, 5)), d -> f64(SegmentedMean(d)))
ArrayModel(Dense(2 => 5)) 2 arrays, 15 params, 208 bytes
Functions Flux.f64 and Flux.f32 may come in handy.
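As a sketch of this pattern (assuming the model tree supports Flux's parameter-type conversion, which works on Functors-compatible structures such as Mill models): instead of wrapping every constructor in f64, build the model at the default precision and convert all its parameters in a single call:

```julia
using Mill, Flux

x64 = ArrayNode(randn(2, 2))   # Float64 data

# build the model at the default 32-bit precision, then convert
# every parameter array in the tree to Float64 with Flux.f64
m = reflectinmodel(ArrayNode(randn(Float32, 2, 2))) |> f64

eltype(m.m.weight)   # Float64, so m(x64) runs without the warning
```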