Model reflection

Since the construction of large models can be a tedious and error-prone process, Mill.jl provides the reflectinmodel function, which helps to automate it. In its simplest form, it accepts a single argument, a sample ds, and returns a compatible model:

julia> ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10),
                                         [1:2, 3:4, 5:5, 6:7, 8:10]),
                                 randn(Float32, 3, 5),
                                 BagNode(BagNode(randn(Float32, 2, 30),
                                                 [i:i+1 for i in 1:2:30]),
                                         [1:3, 4:6, 7:9, 10:12, 13:15]),
                                 randn(Float32, 2, 5))),
                    [1:1, 2:3, 4:5]);
julia> printtree(ds)
BagNode  # 3 obs, 160 bytes
  ╰── ProductNode  # 5 obs, 56 bytes
        ├── BagNode  # 5 obs, 144 bytes
        │     ╰── ArrayNode(4×10 Array with Float32 elements)  # 10 obs, 208 bytes
        ├── ArrayNode(3×5 Array with Float32 elements)  # 5 obs, 108 bytes
        ├── BagNode  # 5 obs, 152 bytes
        │     ╰── BagNode  # 15 obs, 304 bytes
        │           ╰── ArrayNode(2×30 Array with Float32 elements)  # 30 obs, 288 bytes
        ╰── ArrayNode(2×5 Array with Float32 elements)  # 5 obs, 88 bytes
julia> m = reflectinmodel(ds, d -> Dense(d, 2));
julia> printtree(m)
BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  # 4 arrays, 16 params, 224 bytes
  ╰── ProductModel ↦ Dense(8 => 2)  # 2 arrays, 18 params, 152 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  # 4 arrays, 16 params, 224 bytes
        │     ╰── ArrayModel(Dense(4 => 2))  # 2 arrays, 10 params, 120 bytes
        ├── ArrayModel(Dense(3 => 2))  # 2 arrays, 8 params, 112 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  # 4 arrays, 16 params, 224 bytes
        │     ╰── BagModel ↦ BagCount([SegmentedMean(2); SegmentedMax(2)]) ↦ Dense(5 => 2)  # 4 arrays, 16 params, 224 bytes
        │           ╰── ArrayModel(Dense(2 => 2))  # 2 arrays, 6 params, 104 bytes
        ╰── ArrayModel(Dense(2 => 2))  # 2 arrays, 6 params, 104 bytes
julia> m(ds)
2×3 Matrix{Float32}:
  0.0231536   0.383077   -1.84346
 -1.30441    -0.0164543  -2.67783

The sample ds serves here as a specimen needed to specify the structure of the problem and to calculate dimensions.
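
Only the structure of the specimen and the dimensions of its leaves matter, so the reflected model can be applied to any other sample with the same schema. A minimal sketch (the bag sizes and values below are arbitrary):

julia> ds2 = BagNode(ProductNode((BagNode(randn(Float32, 4, 3), [1:1, 2:2, 3:3]),
                                  randn(Float32, 3, 3),
                                  BagNode(BagNode(randn(Float32, 2, 6), [1:2, 3:4, 5:6]),
                                          [1:1, 2:2, 3:3]),
                                  randn(Float32, 2, 3))),
                     [1:1, 2:2, 3:3]);
julia> size(m(ds2))
(2, 3)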

Optional arguments

To have better control over the topology, reflectinmodel accepts up to two more optional arguments and several keyword arguments:

  • The first optional argument expects a function that returns a layer (or a stack of layers) given the input dimension d (defaults to d -> Flux.Dense(d, 10)).
  • The second optional argument is a function returning an aggregation function for BagModel nodes (defaults to BagCount ∘ SegmentedMeanMax).

Compare the following example to the previous one:

julia> m = reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMax);
julia> printtree(m)
BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  # 3 arrays, 35 params, 260 bytes
  ╰── ProductModel ↦ Dense(20 => 5, relu)  # 2 arrays, 105 params, 500 bytes
        ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  # 3 arrays, 35 params, 260 bytes
        │     ╰── ArrayModel(Dense(4 => 5, relu))  # 2 arrays, 25 params, 180 bytes
        ├── ArrayModel(Dense(3 => 5, relu))  # 2 arrays, 20 params, 160 bytes
        ├── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  # 3 arrays, 35 params, 260 bytes
        │     ╰── BagModel ↦ SegmentedMax(5) ↦ Dense(5 => 5, relu)  # 3 arrays, 35 params, 260 bytes
        │           ╰── ArrayModel(Dense(2 => 5, relu))  # 2 arrays, 15 params, 140 bytes
        ╰── ArrayModel(Dense(2 => 5, relu))  # 2 arrays, 15 params, 140 bytes
julia> m(ds)
5×3 Matrix{Float32}:
 0.0558981  0.258097  0.0
 0.220776   0.321497  0.48561
 0.0        0.0       0.412755
 0.569792   1.77116   0.0457244
 0.0        0.0       0.0

Keyword arguments

reflectinmodel allows even further customization. To index into the sample (or the model), we can use printtree(ds; trav=true) from HierarchicalUtils.jl, which prints the sample together with identifiers of individual nodes:

julia> using HierarchicalUtils
julia> printtree(ds; trav=true)
BagNode [""]  # 3 obs, 160 bytes
  ╰── ProductNode ["U"]  # 5 obs, 56 bytes
        ├── BagNode ["Y"]  # 5 obs, 144 bytes
        │     ╰── ArrayNode(4×10 Array with Float32 elements) ["a"]  # 10 obs, 208 bytes
        ├── ArrayNode(3×5 Array with Float32 elements) ["c"]  # 5 obs, 108 bytes
        ├── BagNode ["g"]  # 5 obs, 152 bytes
        │     ╰── BagNode ["i"]  # 15 obs, 304 bytes
        │           ╰── ArrayNode(2×30 Array with Float32 elements) ["j"]  # 30 obs, 288 bytes
        ╰── ArrayNode(2×5 Array with Float32 elements) ["k"]  # 5 obs, 88 bytes
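
The identifiers themselves can also be used to retrieve subtrees. A small illustration, assuming the string-based getindex that Mill.jl nodes support via HierarchicalUtils.jl:

julia> printtree(ds["Y"])
BagNode  # 5 obs, 144 bytes
  ╰── ArrayNode(4×10 Array with Float32 elements)  # 10 obs, 208 bytes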

These identifiers can be used to override the default construction functions. Note that the root of the model, i.e. the last feed-forward network of the whole model, is always tagged with the empty string "", which makes it easy to put a linear layer with an appropriate output dimension at the end. Dictionaries with these overrides can be passed in as keyword arguments:

  • fsm overrides the construction of feed-forward models,
  • fsa overrides the construction of aggregation functions.

For example, to specify just the last feed-forward network:

julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax;
           fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12)))) |> printtree
BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  # 6 arrays, 482 params, 2.117 KiB
  ╰── ProductModel ↦ Dense(20 => 5, relu)  # 2 arrays, 105 params, 500 bytes
        ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  # 4 arrays, 65 params, 420 bytes
        │     ╰── ArrayModel(Dense(4 => 5, relu))  # 2 arrays, 25 params, 180 bytes
        ├── ArrayModel(Dense(3 => 5, relu))  # 2 arrays, 20 params, 160 bytes
        ├── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  # 4 arrays, 65 params, 420 bytes
        │     ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  # 4 arrays, 65 params, 420 bytes
        │           ╰── ArrayModel(Dense(2 => 5, relu))  # 2 arrays, 15 params, 140 bytes
        ╰── ArrayModel(Dense(2 => 5, relu))  # 2 arrays, 15 params, 140 bytes

Both keyword arguments in action:

julia> reflectinmodel(ds, d -> Dense(d, 5, relu), SegmentedMeanMax;
           fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12))),
           fsa = Dict("Y" => SegmentedMean, "g" => SegmentedPNorm)) |> printtree
BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Chain(Dense(10 => 20, relu), Dense(20 => 12))  # 6 arrays, 482 params, 2.117 KiB
  ╰── ProductModel ↦ Dense(20 => 5, relu)  # 2 arrays, 105 params, 500 bytes
        ├── BagModel ↦ SegmentedMean(5) ↦ Dense(5 => 5, relu)  # 3 arrays, 35 params, 260 bytes
        │     ╰── ArrayModel(Dense(4 => 5, relu))  # 2 arrays, 25 params, 180 bytes
        ├── ArrayModel(Dense(3 => 5, relu))  # 2 arrays, 20 params, 160 bytes
        ├── BagModel ↦ SegmentedPNorm(5) ↦ Dense(5 => 5, relu)  # 5 arrays, 45 params, 380 bytes
        │     ╰── BagModel ↦ [SegmentedMean(5); SegmentedMax(5)] ↦ Dense(10 => 5, relu)  # 4 arrays, 65 params, 420 bytes
        │           ╰── ArrayModel(Dense(2 => 5, relu))  # 2 arrays, 15 params, 140 bytes
        ╰── ArrayModel(Dense(2 => 5, relu))  # 2 arrays, 15 params, 140 bytes

There are even more ways to modify the reflection behavior; see the reflectinmodel API reference.

Float precision

Mill.jl is built on top of Flux.jl, which by default uses 32-bit precision for model parameters:

julia> Dense(2, 2).weight |> eltype
Float32

If you attempt to process Float64 data with a model using lower precision, you get a warning:

julia> x = randn(2, 2)
2×2 Matrix{Float64}:
 -1.13853  -0.721216
 -2.21835   0.0331891
julia> eltype(x)
Float64
julia> m = Dense(2, 2)
Dense(2 => 2)  # 6 parameters
julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 2)  # 6 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/ljuc2/src/layers/stateless.jl:60
2×2 Matrix{Float32}:
 -2.01866  -0.022319
  1.11091  -0.0697713

Unless additional arguments are provided, reflectinmodel also instantiates all Dense layers using 32-bit precision:

julia> x = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
  0.7041741  -0.45216417
 -0.3617286   1.4807765
julia> eltype(Mill.data(x))
Float32
julia> m = reflectinmodel(x)
ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes
julia> m.m.weight |> eltype
Float32

Because reflectinmodel evaluates (sub)models on parts of the input while building the model, passing in Float64 values triggers the warning both during construction and during evaluation:

julia> x = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 -0.2617165161930634   -1.1169820345030674
  0.03973931389554643  -1.9057709581450335
julia> eltype(Mill.data(x))
Float64
julia> m = reflectinmodel(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/ljuc2/src/layers/stateless.jl:60
ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes
julia> m(x)
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(2 => 10)  # 30 parameters
│   summary(x) = "2×2 Matrix{Float64}"
└ @ Flux ~/.julia/packages/Flux/ljuc2/src/layers/stateless.jl:60
10×2 Matrix{Float32}:
  0.178307   -0.207178
 -0.174379   -0.190231
  0.126416   -0.390108
  0.0908661  -0.823867
 -0.0826094  -0.324438
  0.0130696   0.325754
 -0.071541   -0.666409
  0.0525041   0.328029
  0.0982923  -0.174111
  0.178692    0.530059

To prevent this, we recommend making sure that the input data and the model returned by reflectinmodel use the same precision. For example:

julia> x32 = randn(Float32, 2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float32}, Nothing}:
 1.1421611  0.8822706
 2.514736   0.036934458
julia> m = reflectinmodel(x32)
ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes
julia> x64 = randn(2, 2) |> ArrayNode
2×2 ArrayNode{Matrix{Float64}, Nothing}:
 0.3560285043427309  0.9366766658560893
 0.6108965887101426  1.560294028451233
julia> m = reflectinmodel(x64, d -> f64(Dense(d, 5)), d -> f64(SegmentedMean(d)))
ArrayModel(Dense(2 => 5))  # 2 arrays, 15 params, 200 bytes

The functions Flux.f64 and Flux.f32 may come in handy.
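
For instance, f64 converts parameters recursively (via Functors.jl), so it should also work on whole Mill.jl models, making it possible to switch an already constructed model to double precision. A brief sketch reusing x32 from above:

julia> m64 = f64(reflectinmodel(x32));   # convert all parameters to Float64
julia> eltype(m64.m.weight)
Float64
julia> m64(ArrayNode(randn(2, 2))) |> eltype
Float64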