HierarchicalUtils.jl
Mill.jl uses HierarchicalUtils.jl, which brings a lot of additional features.
using HierarchicalUtils
Printing
For instance, Base.show with text/plain MIME calls HierarchicalUtils.printtree:
julia> ds = BagNode(ProductNode((BagNode(randn(4, 10), [1:2, 3:4, 5:5, 6:7, 8:10]), randn(3, 5), BagNode(BagNode(randn(2, 30), [i:i+1 for i in 1:2:30]), [1:3, 4:6, 7:9, 10:12, 13:15]), randn(2, 5))), [1:1, 2:3, 4:5])
BagNode 3 obs
  ╰── ProductNode 5 obs
        ├── BagNode 5 obs
        │     ╰── ArrayNode(4×10 Array with Float64 elements) 10 obs
        ├── ArrayNode(3×5 Array with Float64 elements) 5 obs
        ├── BagNode 5 obs
        │     ╰── BagNode 15 obs
        │           ╰── ArrayNode(2×30 Array with Float64 elements) 30 obs
        ╰── ArrayNode(2×5 Array with Float64 elements) 5 obs
julia> printtree(ds; htrunc=3)
BagNode 3 obs
  ╰── ProductNode 5 obs
        ├── BagNode 5 obs
        │     ┊
        ├── ArrayNode(3×5 Array with Float64 elements) 5 obs
        ├── BagNode 5 obs
        │     ┊
        ╰── ArrayNode(2×5 Array with Float64 elements) 5 obs
Calling printtree without htrunc prints a non-truncated version of the tree, which works for models as well as for data:
julia> printtree(ds)
BagNode 3 obs
  ╰── ProductNode 5 obs
        ├── BagNode 5 obs
        │     ╰── ArrayNode(4×10 Array with Float64 elements) 10 obs
        ├── ArrayNode(3×5 Array with Float64 elements) 5 obs
        ├── BagNode 5 obs
        │     ╰── BagNode 15 obs
        │           ╰── ArrayNode(2×30 Array with Float64 elements) 30 obs
        ╰── ArrayNode(2×5 Array with Float64 elements) 5 obs
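The effect of htrunc can be illustrated with a minimal plain-Julia sketch. It assumes a hypothetical toy `ToyNode` type and a hypothetical `printtree_sketch` function, neither of which exists in Mill.jl or HierarchicalUtils.jl, and its output only approximates the real printtree layout. Nodes at depth `htrunc` (counting the root as depth 1) are still printed, but their children are hidden behind a single `┊` marker.

```julia
# Hypothetical toy tree type (NOT part of Mill.jl or HierarchicalUtils.jl):
# a name, a number of observations and a vector of children.
struct ToyNode
    name::String
    nobs::Int
    children::Vector{ToyNode}
end
ToyNode(name::String, nobs::Int) = ToyNode(name, nobs, ToyNode[])  # leaf constructor

# Hypothetical printer: prints `n` and its subtree to `io`. Children of nodes
# at depth `htrunc` (root = depth 1) are replaced by a single "┊" line.
function printtree_sketch(io::IO, n::ToyNode; htrunc::Int=typemax(Int),
                          prefix::String="", depth::Int=1)
    println(io, n.name, " ", n.nobs, " obs")
    isempty(n.children) && return nothing
    if depth >= htrunc
        println(io, prefix, "┊")
        return nothing
    end
    for (i, c) in enumerate(n.children)
        islast = i == length(n.children)
        print(io, prefix, islast ? "╰── " : "├── ")
        printtree_sketch(io, c; htrunc=htrunc,
                         prefix=prefix * (islast ? "    " : "│   "), depth=depth + 1)
    end
    return nothing
end

# A toy tree with the same shape as `ds` above.
toy = ToyNode("BagNode", 3, [ToyNode("ProductNode", 5, [
    ToyNode("BagNode", 5, [ToyNode("ArrayNode 4×10", 10)]),
    ToyNode("ArrayNode 3×5", 5),
    ToyNode("BagNode", 5, [ToyNode("BagNode", 15, [ToyNode("ArrayNode 2×30", 30)])]),
    ToyNode("ArrayNode 2×5", 5)])])

printtree_sketch(stdout, toy; htrunc=3)
```

Passing an `IOBuffer` instead of `stdout` captures the printed tree as a string, which is convenient for comparing truncated and full output.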
Traversal encoding
Calling printtree with trav=true enables convenient traversal functionality with string indexing:
julia> m = reflectinmodel(Flux.f32(ds))
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ⋯
  ╰── ProductModel ↦ Dense(40 => 10) 2 arrays, 410 params, 1.688 KiB
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense ⋯
        │     ╰── ArrayModel(Dense(4 => 10)) 2 arrays, 50 params, 288 bytes
        ├── ArrayModel(Dense(3 => 10)) 2 arrays, 40 params, 248 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense ⋯
        │     ╰── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ ⋯
        │           ╰── ArrayModel(Dense(2 => 10)) 2 arrays, 30 params, 208 b ⋯
        ╰── ArrayModel(Dense(2 => 10)) 2 arrays, 30 params, 208 bytes
julia> printtree(m; trav=true)
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) [""] 4 arrays, 240 params, 1.102 KiB
  ╰── ProductModel ↦ Dense(40 => 10) ["U"] 2 arrays, 410 params, 1.688 KiB
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ["Y"] 4 arrays, 240 params, 1.102 KiB
        │     ╰── ArrayModel(Dense(4 => 10)) ["a"] 2 arrays, 50 params, 288 bytes
        ├── ArrayModel(Dense(3 => 10)) ["c"] 2 arrays, 40 params, 248 bytes
        ├── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ["g"] 4 arrays, 240 params, 1.102 KiB
        │     ╰── BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ["i"] 4 arrays, 240 params, 1.102 KiB
        │           ╰── ArrayModel(Dense(2 => 10)) ["j"] 2 arrays, 30 params, 208 bytes
        ╰── ArrayModel(Dense(2 => 10)) ["k"] 2 arrays, 30 params, 208 bytes
This way, any node in the model tree is quickly accessible, which may come in handy when inspecting model parameters or when deleting, replacing, or inserting nodes in the tree (for instance, when constructing adversarial samples). Every node in the tree can be accessed by indexing with its traversal code:
julia> m["Y"]
BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10) ⋯
  ╰── ArrayModel(Dense(4 => 10)) 2 arrays, 50 params, 288 bytes
The following two approaches give the same result:
julia> m["Y"] ≡ m.im.ms[1]
true
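The string codes produced with trav=true are HierarchicalUtils.jl's own compact encoding of a node's position. The underlying idea, addressing a node by its position and rebuilding the tree around a replaced node, can be sketched in plain Julia with a simpler, hypothetical scheme in which a path is a vector of child indices. The `ToyNode` type and the `getnode`, `allpaths` and `replacenode` functions are illustrative assumptions; this is not how HierarchicalUtils.jl encodes or resolves traversal codes.

```julia
# Hypothetical toy tree type (NOT part of Mill.jl or HierarchicalUtils.jl).
struct ToyNode
    name::String
    nobs::Int
    children::Vector{ToyNode}
end
ToyNode(name::String, nobs::Int) = ToyNode(name, nobs, ToyNode[])

# Hypothetical path scheme: Int[] is the root, [1, 3] is the third child
# of the root's first child, and so on.
getnode(n::ToyNode, path::Vector{Int}) = foldl((node, i) -> node.children[i], path; init=n)

# All (path => node) pairs in depth-first pre-order.
function allpaths(n::ToyNode, path::Vector{Int}=Int[])
    out = Pair{Vector{Int},ToyNode}[path => n]
    for (i, c) in enumerate(n.children)
        append!(out, allpaths(c, [path; i]))
    end
    return out
end

# Return a new tree in which the node at `path` is replaced by `new`;
# the original tree is left untouched (only the spine along `path` is copied).
function replacenode(n::ToyNode, path::Vector{Int}, new::ToyNode)
    isempty(path) && return new
    kids = copy(n.children)
    kids[first(path)] = replacenode(kids[first(path)], path[2:end], new)
    return ToyNode(n.name, n.nobs, kids)
end

toy = ToyNode("BagNode", 3, [ToyNode("ProductNode", 5, [
    ToyNode("BagNode", 5, [ToyNode("ArrayNode 4×10", 10)]),
    ToyNode("ArrayNode 3×5", 5),
    ToyNode("BagNode", 5, [ToyNode("BagNode", 15, [ToyNode("ArrayNode 2×30", 30)])]),
    ToyNode("ArrayNode 2×5", 5)])])

# Analogous to m["Y"] ≡ m.im.ms[1]: path [1, 1] is the first child of the
# root's only child.
getnode(toy, [1, 1]) === toy.children[1].children[1]
```

Returning a new tree instead of mutating in place keeps the original intact, which is one reasonable choice when generating many modified variants of the same tree.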
Counting functions and iterators
Other functions provided by HierarchicalUtils.jl:
julia> nnodes(ds)
9
julia> nleafs(ds)
4
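Both counts boil down to simple recursions over the tree. The following minimal sketch uses a hypothetical `ToyNode` type and hypothetical `nnodes_sketch`/`nleafs_sketch` functions (assumptions, not Mill.jl or HierarchicalUtils.jl code) and treats any node without children as a leaf; HierarchicalUtils.jl decides what counts as a leaf through its own node-type interface, which is not reproduced here.

```julia
# Hypothetical toy tree type (NOT part of Mill.jl or HierarchicalUtils.jl).
struct ToyNode
    name::String
    nobs::Int
    children::Vector{ToyNode}
end
ToyNode(name::String, nobs::Int) = ToyNode(name, nobs, ToyNode[])

# Count all nodes: the node itself plus every node in its subtrees.
nnodes_sketch(n::ToyNode) = 1 + sum(nnodes_sketch, n.children; init=0)

# Count leaves: in this sketch, a node without children is a leaf.
nleafs_sketch(n::ToyNode) = isempty(n.children) ? 1 : sum(nleafs_sketch, n.children; init=0)

toy = ToyNode("BagNode", 3, [ToyNode("ProductNode", 5, [
    ToyNode("BagNode", 5, [ToyNode("ArrayNode 4×10", 10)]),
    ToyNode("ArrayNode 3×5", 5),
    ToyNode("BagNode", 5, [ToyNode("BagNode", 15, [ToyNode("ArrayNode 2×30", 30)])]),
    ToyNode("ArrayNode 2×5", 5)])])

(nnodes_sketch(toy), nleafs_sketch(toy))  # (9, 4), matching nnodes(ds) and nleafs(ds)
```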
julia> NodeIterator(ds) |> collect
9-element Vector{AbstractMillNode}:
 BagNode
 ProductNode
 BagNode
 ArrayNode(4×10 Array with Float64 elements)
 ArrayNode(3×5 Array with Float64 elements)
 BagNode
 BagNode
 ArrayNode(2×30 Array with Float64 elements)
 ArrayNode(2×5 Array with Float64 elements)
julia> NodeIterator(ds, m) |> collect
9-element Vector{Tuple{AbstractMillNode, AbstractMillModel}}:
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (ProductNode, ProductModel ↦ Dense(40 => 10))
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (ArrayNode(4×10 Array with Float64 elements), ArrayModel(Dense(4 => 10)))
 (ArrayNode(3×5 Array with Float64 elements), ArrayModel(Dense(3 => 10)))
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (BagNode, BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10))
 (ArrayNode(2×30 Array with Float64 elements), ArrayModel(Dense(2 => 10)))
 (ArrayNode(2×5 Array with Float64 elements), ArrayModel(Dense(2 => 10)))
julia> LeafIterator(ds) |> collect
4-element Vector{ArrayNode{Matrix{Float64}, Nothing}}:
 ArrayNode(4×10 Array with Float64 elements)
 ArrayNode(3×5 Array with Float64 elements)
 ArrayNode(2×30 Array with Float64 elements)
 ArrayNode(2×5 Array with Float64 elements)
julia> TypeIterator(BagModel, m) |> collect
4-element Vector{BagModel{T, BagCount{AggregationStack{Tuple{SegmentedMean{Vector{Float32}}, SegmentedMax{Vector{Float32}}}}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}} where T<:AbstractMillModel}:
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
 BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)
julia> PredicateIterator(x -> numobs(x) ≥ 10, ds) |> collect
3-element Vector{AbstractMillNode}:
 ArrayNode(4×10 Array with Float64 elements)
 BagNode
 ArrayNode(2×30 Array with Float64 elements)
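All of these iterators visit nodes in depth-first pre-order, which is visible in the order of the results above. The following plain-Julia sketch reproduces the same orderings on a hypothetical `ToyNode` tree. The `preorder` and `maptree` functions are illustrative assumptions, not HierarchicalUtils.jl API; the sketch collects eagerly into vectors (the HierarchicalUtils.jl iterators are collected explicitly above) and matches node "types" by name only.

```julia
# Hypothetical toy tree type (NOT part of Mill.jl or HierarchicalUtils.jl).
struct ToyNode
    name::String
    nobs::Int
    children::Vector{ToyNode}
end
ToyNode(name::String, nobs::Int) = ToyNode(name, nobs, ToyNode[])

# Eager depth-first pre-order traversal with an explicit stack.
function preorder(n::ToyNode)
    out = ToyNode[]
    stack = [n]
    while !isempty(stack)
        x = pop!(stack)
        push!(out, x)
        # Push children in reverse so that the first child is visited next.
        append!(stack, reverse(x.children))
    end
    return out
end

# Build a tree of the same shape with transformed names (a stand-in "model").
maptree(f, n::ToyNode) = ToyNode(f(n.name), n.nobs, ToyNode[maptree(f, c) for c in n.children])

toy = ToyNode("BagNode", 3, [ToyNode("ProductNode", 5, [
    ToyNode("BagNode", 5, [ToyNode("ArrayNode 4×10", 10)]),
    ToyNode("ArrayNode 3×5", 5),
    ToyNode("BagNode", 5, [ToyNode("BagNode", 15, [ToyNode("ArrayNode 2×30", 30)])]),
    ToyNode("ArrayNode 2×5", 5)])])
mtoy = maptree(s -> replace(s, "Node" => "Model"), toy)

nodes  = preorder(toy)                              # cf. NodeIterator(ds)
leaves = filter(x -> isempty(x.children), nodes)    # cf. LeafIterator(ds)
bags   = filter(x -> x.name == "BagNode", nodes)    # loosely cf. TypeIterator
big    = filter(x -> x.nobs ≥ 10, nodes)            # cf. PredicateIterator(x -> numobs(x) ≥ 10, ds)
paired = collect(zip(preorder(toy), preorder(mtoy)))  # cf. NodeIterator(ds, m)
```

Pairing two trees with `zip` only works because both traversals visit nodes of identically shaped trees in the same order, the same property that lets data and model nodes be iterated together.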
For a complete showcase of the possibilities, refer to HierarchicalUtils.jl and this notebook.