Musk

The Musk dataset is a classic problem of the multiple instance learning (MIL) field, introduced in [5]. Below we demonstrate how to solve it using Mill.jl.

Jupyter notebook

This example is also available as a Jupyter notebook, together with the environment needed to run it.

We load all dependencies and fix the seed:

using FileIO, JLD2, Statistics, Mill, Flux, OneHotArrays

using Random; Random.seed!(42);

Loading the data

Now we load the dataset and transform it into a Mill structure. The musk.jld2 file contains...

  • a matrix of features, where each column corresponds to one instance:
fMat = load("musk.jld2", "fMat")
166×476 Matrix{Float32}:
   42.0    42.0    42.0    42.0    42.0    42.0    42.0    42.0  …    43.0    52.0    49.0    38.0    43.0    39.0    52.0
 -198.0  -191.0  -191.0  -198.0  -198.0  -191.0  -190.0  -199.0     -104.0  -123.0  -199.0  -123.0  -102.0   -58.0  -121.0
 -109.0  -142.0  -142.0  -110.0  -102.0  -142.0  -142.0  -102.0      -20.0   -24.0  -161.0  -139.0   -20.0    27.0   -24.0
  -75.0   -65.0   -75.0   -65.0   -75.0   -65.0   -75.0   -65.0       23.0   -43.0    29.0    30.0  -101.0    31.0  -104.0
 -117.0  -117.0  -117.0  -117.0  -117.0  -117.0  -117.0  -117.0     -117.0  -117.0   -95.0  -117.0  -116.0  -117.0  -116.0
   11.0    55.0    11.0    55.0    10.0    55.0    12.0    55.0  …   -76.0   -82.0   -86.0   -88.0   200.0   -92.0   195.0
   23.0    49.0    49.0    23.0    24.0    49.0    49.0    23.0     -167.0  -163.0   -48.0   214.0  -166.0    85.0  -162.0
  -88.0  -170.0  -161.0   -95.0   -87.0  -170.0  -161.0   -94.0       48.0    60.0     2.0   -13.0    66.0    21.0    76.0
  -28.0   -45.0   -45.0   -28.0   -28.0   -45.0   -45.0   -29.0     -229.0  -234.0   112.0   -74.0  -222.0   -73.0  -226.0
  -27.0     5.0   -28.0     5.0   -28.0     6.0   -29.0     6.0        6.0   -13.0   -79.0  -129.0   -49.0   -68.0   -56.0
    ⋮                                       ⋮                    ⋱             ⋮                                       ⋮
 -238.0  -238.0  -238.0  -238.0  -238.0  -238.0  -238.0  -238.0     -236.0  -247.0  -220.0  -236.0   114.0  -228.0    99.0
  -74.0  -302.0   -73.0  -302.0   -73.0  -300.0   -72.0  -300.0     -260.0  -285.0  -246.0  -226.0    32.0  -232.0    34.0
 -129.0    60.0  -127.0    60.0  -127.0    61.0  -125.0    61.0     -204.0  -212.0  -209.0  -210.0   136.0  -206.0   133.0
 -120.0  -120.0  -120.0  -120.0    51.0    51.0    51.0    51.0      -16.0   -20.0    33.0    20.0   -15.0    13.0   -20.0
  -38.0   -39.0   -38.0   -39.0   128.0   127.0   124.0   127.0  …   143.0   -44.0   152.0    55.0   143.0    45.0   -46.0
   30.0    31.0    30.0    30.0   144.0   143.0   143.0   144.0      120.0    96.0   134.0   119.0   121.0   116.0    95.0
   48.0    48.0    48.0    48.0    43.0    42.0    44.0    42.0       55.0    99.0    47.0    79.0    55.0    79.0    98.0
  -37.0   -37.0   -37.0   -37.0   -30.0   -31.0   -30.0   -30.0      -37.0   -13.0   -43.0   -28.0   -37.0   -28.0   -14.0
    6.0     5.0     5.0     6.0    14.0    14.0    14.0    14.0      -19.0    11.0   -15.0     4.0   -19.0     3.0    12.0
   30.0    30.0    31.0    30.0    26.0    26.0    29.0    25.0  …   -36.0    97.0   -10.0    74.0   -36.0    74.0    96.0
  • the ids of samples (bags in MIL terminology) specifying to which bag each instance (column in fMat) belongs:
bagids = load("musk.jld2", "bagids")
476-element Vector{Int64}:
  1
  1
  1
  1
  2
  2
  2
  2
  3
  3
  ⋮
 91
 91
 92
 92
 92
 92
 92
 92
 92
 92
  • and labels defined on the level of instances:
y = load("musk.jld2", "y")
476-element Vector{Int64}:
 1
 1
 1
 1
 1
 1
 1
 1
 1
 1
 ⋮
 0
 0
 0
 0
 0
 0
 0
 0
 0
 0

We create a BagNode structure, which holds:

  1. the feature matrix, and
  2. ranges identifying which columns of the feature matrix each bag spans.
ds = BagNode(ArrayNode(fMat), bagids)
BagNode  # 92 obs, 1.500 KiB
  ╰── ArrayNode(166×476 Array with Float32 elements)  # 476 obs, 308.703 KiB

This representation ensures that feed-forward networks do not need to deal with bag boundaries and always process full contiguous matrices.
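The ranges themselves can be inspected through the bags field of the BagNode (the same field is used below when computing bag labels); this is just a quick sketch and the exact printed form may differ between Mill.jl versions:

ds.bags  # ranges of columns in fMat; based on bagids above, the first bag should span columns 1:4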

We also compute labels on the level of bags. In the Musk problem, the bag label is defined as the maximum of its instance labels (i.e. a bag is positive if at least one of its instances is positive). We additionally shift the labels by one so that they take values 1 and 2, matching the one-hot encoding below:

y = map(i -> maximum(y[i]) + 1, ds.bags)
y_oh = onehotbatch(y, 1:2)
2×92 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  …  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
 1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1     ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅  ⋅

Model construction

Once the data are in the Mill.jl internal format, we manually create a model. BagModel is designed to implement a basic multiple instance learning model utilizing two feed-forward networks with an aggregation operator in between:

model = BagModel(
    Dense(166, 50, Flux.tanh),
    SegmentedMeanMax(50),
    Chain(Dense(100, 50, Flux.tanh), Dense(50, 2)))
BagModel ↦ [SegmentedMean(50); SegmentedMax(50)] ↦ Chain(Dense(100 => 50, tanh), Dense(50 => 2))  # 6 arrays, 5_252 params, ⋯
  ╰── ArrayModel(Dense(166 => 50, tanh))  # 2 arrays, 8_350 params, 32.695 KiB

Instances are first passed through a single layer with 50 neurons (the input dimension is 166) with tanh non-linearity. Then the mean and max aggregation functions are applied simultaneously (for some problems max works better than mean, so we use both) and their outputs are concatenated into a 100-dimensional bag representation. Finally, we apply one layer with 50 neurons and tanh non-linearity, followed by a linear layer with 2 neurons (the output dimension). We check that the forward pass works:

model(ds)
2×92 Matrix{Float32}:
  0.890284  1.01132    0.88188   0.917716  1.20012    0.548941  …   0.238126    1.30782   1.46193  -0.254     -1.00311
 -0.864341  0.303749  -0.324059  0.144588  0.107678  -1.17662      -0.0126107  -0.97219  -0.71427   0.569847  -0.302701
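The outputs are raw per-bag scores (logits), which is also why we train with Flux.logitcrossentropy below. If class probabilities are needed, they can for example be obtained by applying softmax (exported by Flux) to the output:

softmax(model(ds))  # per-bag class probabilities; each column sums to one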
Easier model construction

Note that the model can be obtained in a more straightforward way using Model reflection.
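For illustration only, a structurally similar model could be constructed directly from the data sample, roughly as sketched below (the variable name reflected_model is ours, and an output layer for the two classes would still have to be appended; see the Model reflection section for the exact interface and options):

# Sketch, not used in the training below: the second argument constructs the
# per-leaf feed-forward layers, the third the aggregation operators.
reflected_model = reflectinmodel(ds, d -> Dense(d, 50, tanh), SegmentedMeanMax)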

Training

Since Mill is entirely compatible with Flux.jl, we can use its Adam optimizer:

opt_state = Flux.setup(Adam(), model)
(im = (m = (weight = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()),), a = (fs = ((ψ = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))),), (ψ = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))),)),), bm = (layers = ((weight = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0; … ; 0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0.9, 0.999))), σ = ()), (weight = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0; 0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.001, (0.9, 0.999), 1.0e-8), (Float32[0.0, 0.0], Float32[0.0, 0.0], (0.9, 0.999))), σ = ())),))

...define a loss function as Flux.logitcrossentropy:

loss(m, x, y) = Flux.logitcrossentropy(m(x), y)
loss (generic function with 1 method)

...and run a simple training loop using the Flux.train! function:

for e in 1:100
    if e % 10 == 1
        @info "Epoch $e" training_loss=loss(model, ds, y_oh)
    end
    Flux.train!(loss, model, [(ds, y_oh)], opt_state)
end
┌ Info: Epoch 1
└   training_loss = 0.79128903f0
┌ Info: Epoch 11
└   training_loss = 0.3943807f0
┌ Info: Epoch 21
└   training_loss = 0.26313362f0
┌ Info: Epoch 31
└   training_loss = 0.1771421f0
┌ Info: Epoch 41
└   training_loss = 0.12149703f0
┌ Info: Epoch 51
└   training_loss = 0.08492497f0
┌ Info: Epoch 61
└   training_loss = 0.058849733f0
┌ Info: Epoch 71
└   training_loss = 0.043363024f0
┌ Info: Epoch 81
└   training_loss = 0.030421462f0
┌ Info: Epoch 91
└   training_loss = 0.023233365f0

Finally, we calculate the (training) accuracy:

mean(Flux.onecold(model(ds), 1:2) .== y)
1.0
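
Predictions for individual bags can be obtained in the same way; for instance, assuming that Mill.jl node indexing by observation works as expected (ds[1:5] selecting the first five bags):

Flux.onecold(model(ds[1:5]), 1:2)  # predicted classes (1 or 2) for the first five bags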