More on nodes

Node nesting

The main advantage of the Mill.jl library is that it allows BagModels to be nested and combined in cross-products arbitrarily, as described in Theorem 5 in [6]. In other words, instances themselves may be represented in a much more complex way than in the BagNode and BagModel example.

Let's start the demonstration by nesting two MIL problems. The outer MIL model contains three samples (outer-level bags), whose instances are (inner-level) bags themselves. The first outer-level bag contains one inner-level bag with two inner-level instances, the second outer-level bag contains two inner-level bags with a total of three inner-level instances, and the third outer-level bag contains two inner-level bags with a total of five inner-level instances:

julia> ds = BagNode(BagNode(ArrayNode(randn(Float32, 4, 10)),
                            [1:2, 3:4, 5:5, 6:7, 8:10]),
                    [1:1, 2:3, 4:5])
BagNode  # 3 obs, 120 bytes
  ╰── BagNode  # 5 obs, 144 bytes
        ╰── ArrayNode(4×10 Array with Float32 elements)  # 10 obs, 208 bytes
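The bag indices alone determine how many inner-level bags and instances each outer-level bag covers. The following is a minimal sketch in plain Julia (no Mill calls); the variable names are illustrative and not part of the library:

```julia
# Inner-level bags: each range selects instances (columns of the 4×10 matrix).
inner_bags = [1:2, 3:4, 5:5, 6:7, 8:10]

# Outer-level bags: each range selects inner-level bags.
outer_bags = [1:1, 2:3, 4:5]

# Number of inner-level bags in each outer-level bag.
n_inner_bags = length.(outer_bags)

# Total number of inner-level instances reachable from each outer-level bag.
n_instances = [sum(length, inner_bags[b]) for b in outer_bags]

println(n_inner_bags)  # [1, 2, 2]
println(n_instances)   # [2, 3, 5]
```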

Here is one example of a model, which is appropriate for this hierarchy:

julia> using Flux: Dense, Chain, relu

julia> m = BagModel(
               BagModel(
                   ArrayModel(Dense(4, 3, relu)),
                   SegmentedMeanMax(3),
                   Dense(6, 3, relu)),
               SegmentedMeanMax(3),
               Chain(Dense(6, 3, relu), Dense(3, 2)))
BagModel ↦ [SegmentedMean(3); SegmentedMax(3)] ↦ Chain(Dense(6 => 3, relu), Dense(3 => 2))  # 6 arrays, 35 params, 380 byte ⋯
  ╰── BagModel ↦ [SegmentedMean(3); SegmentedMax(3)] ↦ Dense(6 => 3, relu)  # 4 arrays, 27 params, 268 bytes
        ╰── ArrayModel(Dense(4 => 3, relu))  # 2 arrays, 15 params, 140 bytes

and can be directly applied to obtain a result:

julia> m(ds)
2×3 Matrix{Float32}:
 0.0  -0.0165677  -0.0404657
 0.0   0.013503    0.0329802

Here we again make use of the property that even if each instance is represented by an arbitrarily complex structure, we always obtain a vector representation after applying the instance model im, regardless of the complexity of im and Mill.data(ds):

julia> m.im(Mill.data(ds))
3×5 Matrix{Float32}:
 0.0       0.545624  0.0  1.03625  0.018455
 0.0       0.708773  0.0  1.16774  0.0
 0.403142  0.740321  0.0  1.4657   0.556679
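The forward pass of the outer BagModel can be traced step by step to see where the vector representations come from. This is a sketch under stated assumptions: the field names `a` and `bm`, and the call form `m.a(h, bags)` for the aggregation, are assumed here (only `m.im` appears in the text above), and the instance model is assumed to return a plain matrix as in the output shown.

```julia
using Mill
using Flux: Dense, Chain, relu

ds = BagNode(BagNode(ArrayNode(randn(Float32, 4, 10)),
                     [1:2, 3:4, 5:5, 6:7, 8:10]),
             [1:1, 2:3, 4:5])

m = BagModel(
        BagModel(
            ArrayModel(Dense(4 => 3, relu)),
            SegmentedMeanMax(3),
            Dense(6 => 3, relu)),
        SegmentedMeanMax(3),
        Chain(Dense(6 => 3, relu), Dense(3 => 2)))

# Step 1: the instance model embeds each of the 5 inner-level bags as one column.
h = m.im(Mill.data(ds))

# Step 2: the aggregation pools the columns that belong to each outer-level bag.
# Mean and max are concatenated, so the pooled dimension is 2 * 3 = 6.
# Assumption: the aggregation is stored in field `a` and is called as `m.a(h, bags)`.
pooled = m.a(h, ds.bags)

# Step 3: the bag model maps each pooled vector to the final output.
# Assumption: the bag model is stored in field `bm`.
y = m.bm(pooled)

println(size(h), " ", size(pooled), " ", size(y))
println(y ≈ m(ds))
```

If the assumptions hold, the composed result agrees with calling `m(ds)` directly.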

In one final example we demonstrate a complex model consisting of all types of nodes introduced so far:

julia> ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10),
                                         [1:2, 3:4, 5:5, 6:7, 8:10]),
                                 randn(Float32, 3, 5),
                                 BagNode(BagNode(randn(Float32, 2, 30),
                                                 [i:i+1 for i in 1:2:30]),
                                         [1:3, 4:6, 7:9, 10:12, 13:15]),
                                 randn(Float32, 2, 5))),
                    [1:1, 2:3, 4:5])
BagNode  # 3 obs, 160 bytes
  ╰── ProductNode  # 5 obs, 56 bytes
        ├── BagNode  # 5 obs, 144 bytes
        │     ╰── ArrayNode(4×10 Array with Float32 elements)  # 10 obs, 208 bytes
        ├── ArrayNode(3×5 Array with Float32 elements)  # 5 obs, 108 bytes
        ├── BagNode  # 5 obs, 152 bytes
        │     ╰── BagNode  # 15 obs, 304 bytes
        │           ╰── ArrayNode(2×30 Array with Float32 elements)  # 30 obs, 288 bytes
        ╰── ArrayNode(2×5 Array with Float32 elements)  # 5 obs, 88 bytes

When data and model trees become complex, Mill limits the printing. To inspect the whole tree, use printtree:

julia> printtree(ds)
BagNode  # 3 obs, 160 bytes
  ╰── ProductNode  # 5 obs, 56 bytes
        ├── BagNode  # 5 obs, 144 bytes
        │     ╰── ArrayNode(4×10 Array with Float32 elements)  # 10 obs, 208 bytes
        ├── ArrayNode(3×5 Array with Float32 elements)  # 5 obs, 108 bytes
        ├── BagNode  # 5 obs, 152 bytes
        │     ╰── BagNode  # 15 obs, 304 bytes
        │           ╰── ArrayNode(2×30 Array with Float32 elements)  # 30 obs, 288 bytes
        ╰── ArrayNode(2×5 Array with Float32 elements)  # 5 obs, 88 bytes

Instead of defining a model manually, we can also make use of Model reflection, another Mill functionality, which simplifies model creation:

julia> m = reflectinmodel(ds, d -> Dense(d, 2), SegmentedMean)
BagModel ↦ SegmentedMean(2) ↦ Dense(2 => 2)  # 3 arrays, 8 params, 152 bytes
  ╰── ProductModel ↦ Dense(8 => 2)  # 2 arrays, 18 params, 152 bytes
        ├── BagModel ↦ SegmentedMean(2) ↦ Dense(2 => 2)  # 3 arrays, 8 params, 152 bytes
        │     ╰── ArrayModel(Dense(4 => 2))  # 2 arrays, 10 params, 120 bytes
        ├── ArrayModel(Dense(3 => 2))  # 2 arrays, 8 params, 112 bytes
        ├── BagModel ↦ SegmentedMean(2) ↦ Dense(2 => 2)  # 3 arrays, 8 params, 152 bytes
        │     ╰── BagModel ↦ SegmentedMean(2) ↦ Dense(2 => 2)  # 3 arrays, 8 params, 152 bytes
        │           ╰── ArrayModel(Dense(2 => 2))  # 2 arrays, 6 params, 104 bytes
        ╰── ArrayModel(Dense(2 => 2))  # 2 arrays, 6 params, 104 bytes

julia> m(ds)
2×3 Matrix{Float32}:
 -0.0760647  0.0462476  -0.0648969
  0.030561   0.0626038  -0.0322707
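The builder function passed to reflectinmodel receives the input dimension inferred at each position in the tree, so the width and activation of every layer can be changed in one place. The sketch below applies this to the simpler two-level dataset from the beginning of this section; the hidden size of 8 and the relu activation are arbitrary choices made for illustration.

```julia
using Mill
using Flux: Dense, relu

ds = BagNode(BagNode(ArrayNode(randn(Float32, 4, 10)),
                     [1:2, 3:4, 5:5, 6:7, 8:10]),
             [1:1, 2:3, 4:5])

# `d` is the input dimension that reflection infers for each layer.
# Every layer built here outputs 8 features (an illustrative choice).
m = reflectinmodel(ds, d -> Dense(d => 8, relu), SegmentedMean)

# One 8-dimensional output column per outer-level bag.
println(size(m(ds)))
```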

Node conveniences

To make the handling of data and model hierarchies easier, Mill provides several tools. Let's set up some data:

julia> AN = ArrayNode(Float32.([1 2 3 4; 5 6 7 8]))
2×4 ArrayNode{Matrix{Float32}, Nothing}:
 1.0  2.0  3.0  4.0
 5.0  6.0  7.0  8.0

julia> AM = reflectinmodel(AN)
ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes

julia> BN = BagNode(AN, [1:1, 2:3, 4:4])
BagNode  # 3 obs, 112 bytes
  ╰── ArrayNode(2×4 Array with Float32 elements)  # 4 obs, 80 bytes

julia> BM = reflectinmodel(BN)
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)  # 4 arrays, 240 params, 1.094 KiB
  ╰── ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes

julia> PN = ProductNode(a=Float32.([1 2 3; 4 5 6]), b=BN)
ProductNode  # 3 obs, 24 bytes
  ├── a: ArrayNode(2×3 Array with Float32 elements)  # 3 obs, 72 bytes
  ╰── b: BagNode  # 3 obs, 112 bytes
           ╰── ArrayNode(2×4 Array with Float32 elements)  # 4 obs, 80 bytes

julia> PM = reflectinmodel(PN)
ProductModel ↦ Dense(20 => 10)  # 2 arrays, 210 params, 920 bytes
  ├── a: ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes
  ╰── b: BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)  # 4 arrays, 240 params, 1.094 KiB
           ╰── ArrayModel(Dense(2 => 10))  # 2 arrays, 30 params, 200 bytes

Function: numobs

The numobs function from MLUtils.jl returns the number of samples from the point of view of the current level. This number usually increases as we go down the tree when BagNodes are involved, as each bag may contain more than one instance.

julia> numobs(AN)
4

julia> numobs(BN)
3

julia> numobs(PN)
3
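The increase in the number of observations can be checked by descending one level with Mill.data. This is a small sketch; it imports numobs from MLUtils explicitly, which assumes that MLUtils.jl is available in the active environment, and the local variable `bags` is introduced only for illustration.

```julia
using Mill
using MLUtils: numobs  # assumes MLUtils.jl is installed in the active environment

bags = [1:1, 2:3, 4:4]
AN = ArrayNode(Float32.([1 2 3 4; 5 6 7 8]))
BN = BagNode(AN, bags)

# Three bags at the top level.
println(numobs(BN))

# Four instances one level below; this equals the summed bag lengths.
println(numobs(Mill.data(BN)))
println(sum(length, bags))
```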

Indexing and Slicing

Indexing in Mill operates on the level of observations:

julia> AN[1]
2×1 ArrayNode{Matrix{Float32}, Nothing}:
 1.0
 5.0

julia> numobs(ans)
1

julia> BN[2]
BagNode  # 1 obs, 80 bytes
  ╰── ArrayNode(2×2 Array with Float32 elements)  # 2 obs, 64 bytes

julia> numobs(ans)
1

julia> PN[3]
ProductNode  # 1 obs, 24 bytes
  ├── a: ArrayNode(2×1 Array with Float32 elements)  # 1 obs, 56 bytes
  ╰── b: BagNode  # 1 obs, 80 bytes
           ╰── ArrayNode(2×1 Array with Float32 elements)  # 1 obs, 56 bytes

julia> numobs(ans)
1

julia> AN[[1, 4]]
2×2 ArrayNode{Matrix{Float32}, Nothing}:
 1.0  4.0
 5.0  8.0

julia> numobs(ans)
2

julia> BN[1:2]
BagNode  # 2 obs, 96 bytes
  ╰── ArrayNode(2×3 Array with Float32 elements)  # 3 obs, 72 bytes

julia> numobs(ans)
2

julia> PN[[2, 3]]
ProductNode  # 2 obs, 24 bytes
  ├── a: ArrayNode(2×2 Array with Float32 elements)  # 2 obs, 64 bytes
  ╰── b: BagNode  # 2 obs, 96 bytes
           ╰── ArrayNode(2×3 Array with Float32 elements)  # 3 obs, 72 bytes

julia> numobs(ans)
2

julia> PN[Int[]]
ProductNode  # 0 obs, 24 bytes
  ├── a: ArrayNode(2×0 Array with Float32 elements)  # 0 obs, 48 bytes
  ╰── b: BagNode  # 0 obs, 64 bytes
           ╰── ArrayNode(2×0 Array with Float32 elements)  # 0 obs, 48 bytes

julia> numobs(ans)
0

This may be useful for creating minibatches and their permutations.

Note that indexing is applied recursively to the whole subtree, and that it also requires other implicit actions, such as properly recomputing bag indices:

julia> BN.bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:3, 4:4])

julia> BN[[1, 3]].bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:2])
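Creating shuffled minibatches by indexing, as mentioned above, might look like the following sketch. The batch size is a hypothetical value, numobs is imported from MLUtils explicitly (assuming MLUtils.jl is available in the active environment), and each index chunk is collected into a plain vector before indexing to stay with the index types used in the examples above.

```julia
using Mill
using Random: randperm
using MLUtils: numobs  # assumes MLUtils.jl is installed in the active environment

AN = ArrayNode(Float32.([1 2 3 4; 5 6 7 8]))
BN = BagNode(AN, [1:1, 2:3, 4:4])

batchsize = 2                # hypothetical batch size
perm = randperm(numobs(BN))  # random order of the three bags

# Each chunk of the permutation selects one minibatch of whole bags.
batches = [BN[collect(idx)] for idx in Iterators.partition(perm, batchsize)]

println(numobs.(batches))  # [2, 1]
```

Because indexing works on whole observations, instances are never separated from their bag, whatever the permutation.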

Function: catobs

The catobs function concatenates several samples (datasets) together:

julia> catobs(AN[1], AN[4])
2×2 ArrayNode{Matrix{Float32}, Nothing}:
 1.0  4.0
 5.0  8.0

julia> catobs(BN[3], BN[[2, 1]])
BagNode  # 3 obs, 112 bytes
  ╰── ArrayNode(2×4 Array with Float32 elements)  # 4 obs, 80 bytes

julia> catobs(PN[[1, 2]], PN[3:3]) == PN
true

Again, the operation is applied recursively and everything is appropriately recomputed:

julia> BN.bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:3, 4:4])

julia> catobs(BN[3], BN[[1]]).bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:2])

This operation is an analogy to what is usually done in the classical setting. If every observation is represented as a vector of features, each (mini)batch of samples is first concatenated into one matrix and the whole matrix is run through the neural network using fast matrix multiplication procedures. The same reasoning applies here, but instead of Base.cat, catobs is needed.

Equipped with everything mentioned above, there are two different ways to construct minibatches from data. The first option, applicable mainly to smaller datasets, is to load all available data into memory, store it as one big data node containing all observations, and use Indexing and Slicing to obtain minibatches. Such an approach is demonstrated in the Musk example. The other option is to read all observations into memory separately (or load them on demand) and construct minibatches with catobs.
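The second option can be sketched as follows. The per-observation storage `samples` and the chosen indices are hypothetical and stand in for whatever loading mechanism is used; numobs is imported from MLUtils explicitly, assuming MLUtils.jl is available in the active environment.

```julia
using Mill
using MLUtils: numobs  # assumes MLUtils.jl is installed in the active environment

# Hypothetical storage: every element is a BagNode holding a single bag of n instances.
samples = [BagNode(ArrayNode(randn(Float32, 2, n)), [1:n]) for n in (1, 2, 1, 3)]

# Observations chosen for this minibatch (illustrative).
idx = [4, 2]

# Concatenate the selected observations into one data node.
batch = reduce(catobs, samples[idx])

println(numobs(batch))             # 2 bags
println(numobs(Mill.data(batch)))  # 3 + 2 = 5 instances
```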

More tips

For more tips for handling datasets and models, see External tools.

Metadata

Each AbstractMillNode can also carry arbitrary metadata (defaulting to nothing). Metadata is provided upon construction of the node and accessed with Mill.metadata:

julia> n1 = ArrayNode(randn(2, 2), ["metadata"])
2×2 ArrayNode{Matrix{Float64}, Vector{String}}:
 -0.22684270643341806   1.3321657770216255
  0.015328464950175435  1.059974135058343

julia> Mill.metadata(n1)
1-element Vector{String}:
 "metadata"

julia> n2 = ProductNode(n1, [1 3; 2 4])
ProductNode  # 2 obs, 96 bytes
  ╰── ArrayNode(2×2 Array with Float64 elements)  # 2 obs, 152 bytes

julia> Mill.metadata(n2)
2×2 Matrix{Int64}:
 1  3
 2  4
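The default value can be checked in the same way. This short sketch contrasts a node constructed without metadata with one that carries a vector of labels; the label strings are illustrative.

```julia
using Mill

# No metadata given: the node stores `nothing`.
plain = ArrayNode(randn(Float32, 2, 2))

# Metadata passed as the second argument, as in the example above.
tagged = ArrayNode(randn(Float32, 2, 2), ["first", "second"])

println(Mill.metadata(plain) === nothing)
println(Mill.metadata(tagged))
```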