HierarchicalUtils.jl

Mill.jl uses HierarchicalUtils.jl, which brings a number of additional features.

using HierarchicalUtils

Printing

For instance, Base.show with the text/plain MIME type calls HierarchicalUtils.printtree:

julia> ds = BagNode(ProductNode((BagNode(randn(4, 10),
                                         [1:2, 3:4, 5:5, 6:7, 8:10]),
                                 randn(3, 5),
                                 BagNode(BagNode(randn(2, 30),
                                                 [i:i+1 for i in 1:2:30]),
                                         [1:3, 4:6, 7:9, 10:12, 13:15]),
                                 randn(2, 5))),
                    [1:1, 2:3, 4:5])
BagNode  3 obs
  ╰── ProductNode  5 obs
        ├── BagNode  5 obs
        │     ╰── ArrayNode(4×10 Array with Float64 elements)  10 obs
        ├── ArrayNode(3×5 Array with Float64 elements)  5 obs
        ├── BagNode  5 obs
        │     ╰── BagNode  15 obs
        │           ╰── ArrayNode(2×30 Array with Float64 elements)  30 obs
        ╰── ArrayNode(2×5 Array with Float64 elements)  5 obs
julia> printtree(ds; htrunc=3)
BagNode  3 obs
  ╰── ProductNode  5 obs
        ├── BagNode  5 obs
        │     ┊
        ├── ArrayNode(3×5 Array with Float64 elements)  5 obs
        ├── BagNode  5 obs
        │     ┊
        ╰── ArrayNode(2×5 Array with Float64 elements)  5 obs

Called without truncation keywords, printtree prints the complete, non-truncated tree:

julia> printtree(ds)
BagNode  3 obs
  ╰── ProductNode  5 obs
        ├── BagNode  5 obs
        │     ╰── ArrayNode(4×10 Array with Float64 elements)  10 obs
        ├── ArrayNode(3×5 Array with Float64 elements)  5 obs
        ├── BagNode  5 obs
        │     ╰── BagNode  15 obs
        │           ╰── ArrayNode(2×30 Array with Float64 elements)  30 obs
        ╰── ArrayNode(2×5 Array with Float64 elements)  5 obs
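The REPL sessions above assume that Mill.jl is already loaded. A minimal standalone script reproducing them might look as follows. It is a sketch assuming Mill.jl, Flux.jl and HierarchicalUtils.jl are installed; sprint is plain Base Julia and only serves to capture the text/plain rendering as a String.

```julia
# Sketch: assumes Mill.jl, Flux.jl and HierarchicalUtils.jl are installed.
using Mill, Flux, HierarchicalUtils

# The same nested sample as in the REPL session above.
ds = BagNode(ProductNode((BagNode(randn(4, 10), [1:2, 3:4, 5:5, 6:7, 8:10]),
                          randn(3, 5),
                          BagNode(BagNode(randn(2, 30), [i:i+1 for i in 1:2:30]),
                                  [1:3, 4:6, 7:9, 10:12, 13:15]),
                          randn(2, 5))),
             [1:1, 2:3, 4:5])

# Base.show with the text/plain MIME type, captured into a String
# (the REPL display calls the same method, possibly cutting long lines).
shown = sprint(show, MIME("text/plain"), ds)
println(shown)

# Only the top three levels of the tree, then the complete tree.
printtree(ds; htrunc=3)
printtree(ds)
```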

Traversal encoding

Calling printtree with trav=true additionally prints a traversal code next to every node, which enables convenient string indexing:

julia> m = reflectinmodel(Flux.f32(ds))
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)   
  ╰── ProductModel ↦ Dense(40 => 10)  2 arrays, 410 params, 1.688 KiB
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense 
        │     ╰── ArrayModel(Dense(4 => 10))  2 arrays, 50 params, 288 bytes
        ├── ArrayModel(Dense(3 => 10))  2 arrays, 40 params, 248 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense 
        │     ╰── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ 
        │           ╰── ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 b 
        ╰── ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
julia> printtree(m; trav=true)
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) [""]  4 arrays, 240 params, 1.102 KiB
  ╰── ProductModel ↦ Dense(40 => 10) ["U"]  2 arrays, 410 params, 1.688 KiB
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ["Y"]  4 arrays, 240 params, 1.102 KiB
        │     ╰── ArrayModel(Dense(4 => 10)) ["a"]  2 arrays, 50 params, 288 bytes
        ├── ArrayModel(Dense(3 => 10)) ["c"]  2 arrays, 40 params, 248 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ["g"]  4 arrays, 240 params, 1.102 KiB
        │     ╰── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ["i"]  4 arrays, 240 params, 1.102 KiB
        │           ╰── ArrayModel(Dense(2 => 10)) ["j"]  2 arrays, 30 params, 208 bytes
        ╰── ArrayModel(Dense(2 => 10)) ["k"]  2 arrays, 30 params, 208 bytes

This way, any node in the model tree is swiftly accessible, which comes in handy when inspecting model parameters or when deleting, replacing, or inserting nodes in the tree (for instance, when constructing adversarial samples). Every node of the tree can be accessed by indexing with its traversal code:

julia> m["Y"]
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)   
  ╰── ArrayModel(Dense(4 => 10))  2 arrays, 50 params, 288 bytes

The following two approaches return the very same object:

julia> m["Y"] ≡ m.im.ms[1]
true
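Because a data sample and the model reflected from it share the same tree shape, the traversal codes printed above are expected to address corresponding nodes in both trees. The sketch below relies on that assumption and on string indexing applying to Mill data nodes as well as to models. It also reads the field m of ArrayModel and the field weight of Flux's Dense directly, which are assumed implementation details of current Mill.jl and Flux.jl rather than documented API. Float32 data is generated directly instead of calling Flux.f32.

```julia
# Sketch: assumes Mill.jl, Flux.jl and HierarchicalUtils.jl are installed.
using Mill, Flux, HierarchicalUtils

# Same structure as ds above, generated directly in Float32.
ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10), [1:2, 3:4, 5:5, 6:7, 8:10]),
                          randn(Float32, 3, 5),
                          BagNode(BagNode(randn(Float32, 2, 30), [i:i+1 for i in 1:2:30]),
                                  [1:3, 4:6, 7:9, 10:12, 13:15]),
                          randn(Float32, 2, 5))),
             [1:1, 2:3, 4:5])
m = reflectinmodel(ds)

# Print the data tree with its traversal codes to compare with the model tree.
printtree(ds; trav=true)

# "c" addresses the second child of the ProductNode/ProductModel:
# the 3×5 feature matrix and the ArrayModel(Dense(3 => 10)) that embeds it.
x = ds["c"]
leaf_model = m["c"]
println(size(x.data))                # 3 features × 5 instances
println(size(leaf_model.m.weight))   # Dense(3 => 10) stores a 10×3 weight matrix

# String indexing returns the same objects as walking the fields by hand.
println(x ≡ ds.data.data[2], " ", leaf_model ≡ m.im.ms[2])
```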

Counting functions and iterators

Other functions provided by HierarchicalUtils.jl:

julia> nnodes(ds)
9
julia> nleafs(ds)
4
julia> NodeIterator(ds) |> collect
9-element Vector{AbstractMillNode}:
 BagNode
 ProductNode
 BagNode
 ArrayNode(4×10 Array with Float64 elements)
 ArrayNode(3×5 Array with Float64 elements)
 BagNode
 BagNode
 ArrayNode(2×30 Array with Float64 elements)
 ArrayNode(2×5 Array with Float64 elements)
julia> NodeIterator(ds, m) |> collect
9-element Vector{Tuple{AbstractMillNode, AbstractMillModel}}:
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (ProductNode, ProductModel ↦ Dense(40 => 10))
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (ArrayNode(4×10 Array with Float64 elements), ArrayModel(Dense(4 => 10)))
 (ArrayNode(3×5 Array with Float64 elements), ArrayModel(Dense(3 => 10)))
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (ArrayNode(2×30 Array with Float64 elements), ArrayModel(Dense(2 => 10)))
 (ArrayNode(2×5 Array with Float64 elements), ArrayModel(Dense(2 => 10)))
julia> LeafIterator(ds) |> collect
4-element Vector{ArrayNode{Matrix{Float64}, Nothing}}:
 ArrayNode(4×10 Array with Float64 elements)
 ArrayNode(3×5 Array with Float64 elements)
 ArrayNode(2×30 Array with Float64 elements)
 ArrayNode(2×5 Array with Float64 elements)
julia> TypeIterator(BagModel, m) |> collect
4-element Vector{BagModel{T, BagCount{AggregationStack{Tuple{SegmentedMean{Vector{Float32}}, SegmentedMax{Vector{Float32}}}}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}} where T<:AbstractMillModel}:
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
julia> PredicateIterator(x -> numobs(x) ≥ 10, ds) |> collect
3-element Vector{AbstractMillNode}:
 ArrayNode(4×10 Array with Float64 elements)
 BagNode
 ArrayNode(2×30 Array with Float64 elements)
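The counting functions and iterators compose with ordinary Julia code. The sketch below, assuming the same packages as above and Float32 data generated directly, checks that the counts agree with the iterators, lists leaf shapes, walks the data and the model in lockstep, and sums the parameters of all leaf ArrayModels. The parameter sum again reads the fields m, weight and bias directly, an assumption about current Mill.jl and Flux.jl internals. The predicate uses the column count of ArrayNode data instead of numobs so that it depends only on names shown elsewhere in this sketch.

```julia
# Sketch: assumes Mill.jl, Flux.jl and HierarchicalUtils.jl are installed.
using Mill, Flux, HierarchicalUtils

ds = BagNode(ProductNode((BagNode(randn(Float32, 4, 10), [1:2, 3:4, 5:5, 6:7, 8:10]),
                          randn(Float32, 3, 5),
                          BagNode(BagNode(randn(Float32, 2, 30), [i:i+1 for i in 1:2:30]),
                                  [1:3, 4:6, 7:9, 10:12, 13:15]),
                          randn(Float32, 2, 5))),
             [1:1, 2:3, 4:5])
m = reflectinmodel(ds)

# nnodes/nleafs count exactly what NodeIterator/LeafIterator yield.
println(nnodes(ds) == length(collect(NodeIterator(ds))))   # true
println(nleafs(ds) == length(collect(LeafIterator(ds))))   # true

# Leaves come out in the same depth-first order as in printtree.
leaf_sizes = [size(leaf.data) for leaf in LeafIterator(ds)]

# Walk data and model together: every BagNode meets a BagModel and vice versa.
paired = all(((node, model),) -> (node isa BagNode) == (model isa BagModel),
             NodeIterator(ds, m))

# Parameters of the Dense layers inside all leaf ArrayModels.
leaf_params = sum(length(am.m.weight) + length(am.m.bias)
                  for am in TypeIterator(ArrayModel, m))

# Leaves holding at least 10 instances (columns).
big_leaves = collect(PredicateIterator(x -> x isa ArrayNode && size(x.data, 2) ≥ 10, ds))
```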

For a complete showcase of the possibilities, refer to HierarchicalUtils.jl and this notebook.