Model reflection

Since constructing large models by hand can be a tedious and error-prone process, Mill.jl provides the reflectinmodel function, which helps to automate it. In its simplest form, it accepts a single argument, a sample ds, and returns a compatible model:

julia> ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10),
                                         [1:2, 3:4, 5:5, 6:7, 8:10]),
                                 randn(Float32, 3, 5),
                                 BagNode(BagNode(randn(Float32, 2, 30),
                                                 [i:i+1 for i in 1:2:30]),
                                         [1:3, 4:6, 7:9, 10:12, 13:15]),
                                 randn(Float32, 2, 5))),
                    [1:1, 2:3, 4:5]);
julia> printtree(ds)
BagNode  3 obs
  ╰── ProductNode  5 obs
        ├── BagNode  5 obs
        │     ╰── ArrayNode(4×10 Array with Float32 elements)  10 obs
        ├── ArrayNode(3×5 Array with Float32 elements)  5 obs
        ├── BagNode  5 obs
        │     ╰── BagNode  15 obs
        │           ╰── ArrayNode(2×30 Array with Float32 elements)  30 obs
        ╰── ArrayNode(2×5 Array with Float32 elements)  5 obs
julia> m = reflectinmodel(ds, d -> Dense(d, 2));
julia> printtree(m)
BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
  ╰── ProductModel ↦ Dense(8 => 2)  2 arrays, 18 params, 160 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
        │     ╰── ArrayModel(Dense(4 => 2))  2 arrays, 10 params, 128 bytes
        ├── ArrayModel(Dense(3 => 2))  2 arrays, 8 params, 120 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
        │     ╰── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  4 arrays, 16 params, 232 bytes
        │           ╰── ArrayModel(Dense(2 => 2))  2 arrays, 6 params, 112 bytes
        ╰── ArrayModel(Dense(2 => 2))  2 arrays, 6 params, 112 bytes
julia> m(ds)
2×3 Matrix{Float32}:
  0.0231536   0.383077   -1.84346
 -1.30441    -0.0164543  -2.67783

The sample ds serves here as a specimen that specifies the structure of the problem and is used to calculate the dimensions of all layers.

Optional arguments

To gain better control over the topology, reflectinmodel accepts up to two more optional arguments as well as several keyword arguments:

  • The first optional argument expects a function that returns a layer (or a set of layers) given the input dimension d (defaults to d -> Flux.Dense(d, 10)).
  • The second optional argument is a function returning an aggregation function for BagModel nodes (defaults to BagCount ∘ SegmentedMeanMax).
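With the documented defaults spelled out, the one-argument call shown earlier corresponds (to a first approximation) to the explicit form below; this is a sketch assuming Mill.jl and Flux.jl are loaded and ds is the sample constructed above:

```julia
# Equivalent to reflectinmodel(ds): the two optional arguments
# filled in with their documented default values.
m = reflectinmodel(ds,
    d -> Flux.Dense(d, 10),       # feed-forward layer constructor
    BagCount ∘ SegmentedMeanMax)  # aggregation constructor for BagModel nodes
```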

Compare the following example to the previous one:

julia> m = reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMax);
julia> printtree(m)
BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
  ╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 508 bytes
        ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
        │     ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 188 bytes
        ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 168 bytes
        ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
        │     ╰── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
        │           ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
        ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
julia> m(ds)
5×3 Matrix{Float32}:
 0.0558981  0.258097  0.0
 0.220776   0.321497  0.48561
 0.0        0.0       0.412755
 0.569792   1.77116   0.0457244
 0.0        0.0       0.0

Keyword arguments

reflectinmodel allows even further customization. To index into the sample (or the model), we can use printtree(ds; trav=true) from HierarchicalUtils.jl, which prints the sample together with identifiers of individual nodes:

julia> using HierarchicalUtils
julia> printtree(ds; trav=true)
BagNode [""]  3 obs
  ╰── ProductNode ["U"]  5 obs
        ├── BagNode ["Y"]  5 obs
        │     ╰── ArrayNode(4×10 Array with Float32 elements) ["a"]  10 obs
        ├── ArrayNode(3×5 Array with Float32 elements) ["c"]  5 obs
        ├── BagNode ["g"]  5 obs
        │     ╰── BagNode ["i"]  15 obs
        │           ╰── ArrayNode(2×30 Array with Float32 elements) ["j"]  30 obs
        ╰── ArrayNode(2×5 Array with Float32 elements) ["k"]  5 obs

These identifiers can be used to override the default construction functions. Note that the output, i.e. the last feed-forward network of the whole model, is always tagged with an empty string "", which makes it easy to put a linear layer with an appropriate output dimension at the end. Dictionaries with these overrides can be passed in as keyword arguments:

  • fsm overrides construction of feed-forward models.
  • fsa overrides construction of aggregation functions.

For example, to specify just the last feed-forward neural network:

julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax;
           fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12)))) |> printtreeBagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  6 arrays, 482 params, 2.133 KiB
  ╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 508 bytes
        ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
        │     ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 188 bytes
        ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 168 bytes
        ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
        │     ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
        │           ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
        ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes

Both keyword arguments in action:

julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax;
           fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12))),
           fsa = Dict("Y" => SegmentedMean, "g" => SegmentedPNorm)) |> printtreeBagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  6 arrays, 482 params, 2.133 KiB
  ╰── ProductModel ↦ Dense(20 => 5, relu)  2 arrays, 105 params, 508 bytes
        ├── BagModel ↦ SegmentedMean(5) ↦ Dense(5 => 5, relu)  3 arrays, 35 params, 268 bytes
        │     ╰── ArrayModel(Dense(4 => 5, relu))  2 arrays, 25 params, 188 bytes
        ├── ArrayModel(Dense(3 => 5, relu))  2 arrays, 20 params, 168 bytes
        ├── BagModel ↦ SegmentedPNorm(5) ↦ Dense(5 => 5, relu)  5 arrays, 45 params, 388 bytes
        │     ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  4 arrays, 65 params, 428 bytes
        │           ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes
        ╰── ArrayModel(Dense(2 => 5, relu))  2 arrays, 15 params, 148 bytes

There are even more ways to modify the reflection behavior; see the reflectinmodel API reference.

Float precision

Mill.jl is built on top of Flux.jl, which by default uses 32-bit precision for model parameters:

julia> Dense(2, 2).weight |> eltype
Float32

If you attempt to process Float64 data with a model using lower precision, you get a warning:

julia> x = randn(2, 2)
2×2 Matrix{Float64}:
 -1.13853  -0.721216
 -2.21835   0.0331891
julia> eltype(x)
Float64
julia> m = Dense(2, 2)
Dense(2 => 2)       # 6 parameters
julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 2)  # 6 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
2×2 Matrix{Float32}:
 -2.01866  -0.022319
  1.11091  -0.0697713

Unless additional arguments are provided, reflectinmodel also instantiates all Dense layers using 32-bit precision:

julia> x = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
  0.7041741  -0.45216417
 -0.3617286   1.4807765
julia> eltype(Mill.data(x))
Float32
julia> m = reflectinmodel(x)
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
julia> m.m.weight |> eltype
Float32

Because reflectinmodel evaluates (sub)models on parts of the input while building the model, if Float64 values are passed in, the warning appears both during construction and during evaluation:

julia> x = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 -0.2617165161930634   -1.1169820345030674
  0.03973931389554643  -1.9057709581450335
julia> eltype(Mill.data(x))
Float64
julia> m = reflectinmodel(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
10×2 Matrix{Float32}:
  0.178307   -0.207178
 -0.174379   -0.190231
  0.126416   -0.390108
  0.0908661  -0.823867
 -0.0826094  -0.324438
  0.0130696   0.325754
 -0.071541   -0.666409
  0.0525041   0.328029
  0.0982923  -0.174111
  0.178692    0.530059

To prevent this from happening, we recommend using the same precision for the input data and for the parameters of the model returned by reflectinmodel. For example:

julia> x32 = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
 1.1421611  0.8822706
 2.514736   0.036934458
julia> m = reflectinmodel(x32)
ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
julia> x64 = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 0.3560285043427309  0.9366766658560893
 0.6108965887101426  1.560294028451233
julia> m = reflectinmodel(x64, d -> f64(Dense(d, 5)), d -> f64(SegmentedMean(d)))
ArrayModel(Dense(2 => 5))  2 arrays, 15 params, 208 bytes

The functions Flux.f64 and Flux.f32 may come in handy here.
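As a sketch (assuming Mill.jl and Flux.jl are loaded and x64 is the Float64 ArrayNode from above), an already constructed model can also be converted wholesale instead of passing converted layers one by one:

```julia
# Build the model with default (Float32) layers, then convert every
# parameter array to Float64 with Flux.f64 so it matches Float64 data.
m64 = f64(reflectinmodel(x64))
eltype(m64.m.weight)  # Float64
```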