Mutagenesis

This example demonstrates how to predict the mutagenicity of chemical compounds on Salmonella typhimurium from their structure described in JSON documents.

Jupyter notebook

This example is also available as a Jupyter notebook, together with the environment and the data.

We load all dependencies and fix the seed:

using JsonGrinder, Mill, Flux, JSON, MLUtils, Statistics

using Random; Random.seed!(42);

Loading the data

We load the dataset and split it into training and testing sets:

dataset = JSON.parsefile("mutagenesis.json");
jss_train, jss_test = dataset[1:100], dataset[101:end];
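
A fixed split works here, but if the order of documents in the file carried any structure, we would want to shuffle first. A minimal sketch using randperm from the Random module loaded above (this shuffled split is not part of the original example):

# hypothetical alternative: shuffle the documents before splitting
perm = randperm(length(dataset));
jss_train, jss_test = dataset[perm[1:100]], dataset[perm[101:end]];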

jss_train and jss_test are just lists of parsed JSONs:

jss_train[1]
Dict{String, Any} with 6 entries:
  "ind1"      => 1
  "lumo"      => -1.246
  "inda"      => 0
  "logp"      => 4.23
  "mutagenic" => 1
  "atoms"     => Any[Dict{String, Any}("element"=>"c", "atom_type"=>22, "bonds"…

We also extract binary labels, which are stored in the "mutagenic" key:

y_train = getindex.(jss_train, "mutagenic");
y_test = getindex.(jss_test, "mutagenic");
y_train
100-element Vector{Int64}:
 1
 1
 0
 1
 1
 1
 1
 1
 1
 1
 ⋮
 0
 0
 1
 1
 0
 0
 1
 0
 0
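
Before training, it is worth checking the class balance. Since the labels are 0/1 and Statistics is already loaded, the fraction of mutagenic compounds in the training set is simply:

mean(y_train)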

We first create the schema of the training data, which is the first important step in using JsonGrinder.jl. The schema infers both the hierarchical structure of the documents and basic statistics of the individual values:

sch = schema(jss_train)
DictEntry 100x updated
  ├────── atoms: ArrayEntry 100x updated
  │                ╰── DictEntry 2529x updated
  │                      ├── atom_type: LeafEntry (28 unique `Real` values) 2529x updated
  │                      ├────── bonds: ArrayEntry 2529x updated
  │                      │                ╰── DictEntry 5402x updated
  │                      │                      ⋮
  │                      ├───── charge: LeafEntry (318 unique `Real` values) 2529x updated
  │                      ╰──── element: LeafEntry (6 unique `String` values) 2529x updated
  ├─────── ind1: LeafEntry (2 unique `Real` values) 100x updated
  ├─────── inda: LeafEntry (1 unique `Real` values) 100x updated
  ├─────── logp: LeafEntry (62 unique `Real` values) 100x updated
  ├─────── lumo: LeafEntry (98 unique `Real` values) 100x updated
  ╰── mutagenic: LeafEntry (2 unique `Real` values) 100x updated
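
Note that show truncates deep trees (hence the ⋮ under bonds). To see the whole schema, we can use printtree from HierarchicalUtils.jl, which Mill.jl re-exports (an aside, not part of the original example):

printtree(sch)   # prints the complete tree without truncation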

Of course, we have to remove the "mutagenic" key from the schema: it is the label we want to predict, so it must not leak into the input data:

delete!(sch, :mutagenic);
sch
DictEntry 100x updated
  ├── atoms: ArrayEntry 100x updated
  │            ╰── DictEntry 2529x updated
  │                  ├── atom_type: LeafEntry (28 unique `Real` values) 2529x updated
  │                  ├────── bonds: ArrayEntry 2529x updated
  │                  │                ╰── DictEntry 5402x updated
  │                  │                      ⋮
  │                  ├───── charge: LeafEntry (318 unique `Real` values) 2529x updated
  │                  ╰──── element: LeafEntry (6 unique `String` values) 2529x updated
  ├─── ind1: LeafEntry (2 unique `Real` values) 100x updated
  ├─── inda: LeafEntry (1 unique `Real` values) 100x updated
  ├─── logp: LeafEntry (62 unique `Real` values) 100x updated
  ╰─── lumo: LeafEntry (98 unique `Real` values) 100x updated

Now we create an extractor capable of converting JSONs to Mill.jl structures. We use the suggestextractor function with its default settings:

e = suggestextractor(sch)
DictExtractor
  ├─── lumo: CategoricalExtractor(n=99)
  ├─── inda: CategoricalExtractor(n=2)
  ├─── logp: CategoricalExtractor(n=63)
  ├─── ind1: CategoricalExtractor(n=3)
  ╰── atoms: ArrayExtractor
               ╰── DictExtractor
                     ├──── element: CategoricalExtractor(n=7)
                     ├────── bonds: ArrayExtractor
                     │                ╰── DictExtractor
                     │                      ⋮
                     ├───── charge: ScalarExtractor(c=-0.781, s=0.60790277)
                     ╰── atom_type: CategoricalExtractor(n=29)

We also need to convert the JSONs into Mill.jl data samples. The extractor e is callable, so we can use it to extract a single document as follows:

x_single = e(jss_train[1])
ProductNode  1 obs
  ├─── lumo: ArrayNode(99×1 OneHotArray with Bool elements)  1 obs
  ├─── inda: ArrayNode(2×1 OneHotArray with Bool elements)  1 obs
  ├─── logp: ArrayNode(63×1 OneHotArray with Bool elements)  1 obs
  ├─── ind1: ArrayNode(3×1 OneHotArray with Bool elements)  1 obs
  ╰── atoms: BagNode  1 obs
               ╰── ProductNode  26 obs
                     ├──── element: ArrayNode(7×26 OneHotArray with Bool elements)  26 obs
                     ├────── bonds: BagNode  26 obs
                     │                ╰── ProductNode  56 obs
                     │                      ⋮
                     ├───── charge: ArrayNode(1×26 Array with Float32 elements)  26 obs
                     ╰── atom_type: ArrayNode(29×26 OneHotArray with Bool elements)  26 obs

To extract a batch of 10 documents, we can extract the individual documents and then concatenate them with Mill.catobs:

x_batch = reduce(catobs, e.(jss_train[1:10]))
ProductNode  10 obs
  ├─── lumo: ArrayNode(99×10 OneHotArray with Bool elements)  10 obs
  ├─── inda: ArrayNode(2×10 OneHotArray with Bool elements)  10 obs
  ├─── logp: ArrayNode(63×10 OneHotArray with Bool elements)  10 obs
  ├─── ind1: ArrayNode(3×10 OneHotArray with Bool elements)  10 obs
  ╰── atoms: BagNode  10 obs
               ╰── ProductNode  299 obs
                     ├──── element: ArrayNode(7×299 OneHotArray with Bool elements)  299 obs
                     ├────── bonds: BagNode  299 obs
                     │                ╰── ProductNode  650 obs
                     │                      ⋮
                     ├───── charge: ArrayNode(1×299 Array with Float32 elements)  299 obs
                     ╰── atom_type: ArrayNode(29×299 OneHotArray with Bool elements)  299 obs

Alternatively, we can use the much more efficient extract function, which operates on a list of documents. Because the dataset is small, we extract all data at once and keep it in memory:

x_train = extract(e, jss_train);
x_test = extract(e, jss_test);
x_train
ProductNode  100 obs
  ├─── lumo: ArrayNode(99×100 OneHotArray with Bool elements)  100 obs
  ├─── inda: ArrayNode(2×100 OneHotArray with Bool elements)  100 obs
  ├─── logp: ArrayNode(63×100 OneHotArray with Bool elements)  100 obs
  ├─── ind1: ArrayNode(3×100 OneHotArray with Bool elements)  100 obs
  ╰── atoms: BagNode  100 obs
               ╰── ProductNode  2529 obs
                     ├──── element: ArrayNode(7×2529 OneHotArray with Bool elements)  2529 obs
                     ├────── bonds: BagNode  2529 obs
                     │                ╰── ProductNode  5402 obs
                     │                      ⋮
                     ├───── charge: ArrayNode(1×2529 Array with Float32 elements)  2529 obs
                     ╰── atom_type: ArrayNode(29×2529 OneHotArray with Bool elements)  2529 obs
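
For datasets too large to extract into memory at once, one could instead iterate over the raw JSONs and extract each minibatch on the fly. A sketch of this variant (not part of the original example; it reuses e and the labels from above):

raw_loader = Flux.DataLoader((jss_train, y_train), batchsize=32, shuffle=true);
for (jss, y) in raw_loader
    x = extract(e, jss)   # a Mill.jl batch for just these documents
    # ... compute the loss and update the model on (x, y)
end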

Then we create an encoding model capable of embedding each JSON document into a fixed-size vector:

encoder = reflectinmodel(sch, e)
ProductModel ↦ Dense(50 => 10)  2 arrays, 510 params, 2.078 KiB
  ├─── lumo: ArrayModel(Dense(99 => 10))  2 arrays, 1_000 params, 3.992 KiB
  ├─── inda: ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
  ├─── logp: ArrayModel(Dense(63 => 10))  2 arrays, 640 params, 2.586 KiB
  ├─── ind1: ArrayModel(Dense(3 => 10))  2 arrays, 40 params, 248 bytes
  ╰── atoms: BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)  …
               ╰── ProductModel ↦ Dense(31 => 10)  2 arrays, 320 params, 1.336 KiB
                     ├──── element: ArrayModel(Dense(7 => 10))  2 arrays, 80 params, …
                     ├────── bonds: BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dense(21 => 10)  …
                     │                ╰── ProductModel ↦ Dense(31 => 10)  2 arrays, …
                     │                      ⋮
                     ├───── charge: ArrayModel(identity)
                     ╰── atom_type: ArrayModel(Dense(29 => 10))  2 arrays, 300 params, …

Further reading

For further details about reflectinmodel, see the Mill.jl documentation.
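
For instance, reflectinmodel accepts a function mapping an input dimension to a layer, so we can request wider or nonlinear submodels. A sketch (the 20-unit width and relu activation are arbitrary choices, not from this example):

encoder_wide = reflectinmodel(sch, e, d -> Dense(d, 20, relu))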

Finally, we chain the encoder with one more dense layer that computes the logit of the probability that a compound is mutagenic:

model = vec ∘ Dense(10, 1) ∘ encoder
vec ∘ Dense(10 => 1) ∘ ProductModel ↦ Dense(50 => 10)
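
The model maps a Mill.jl sample to a vector of logits, one per observation:

model(x_single)   # 1-element Vector{Float32}
model(x_batch)    # 10-element Vector{Float32}, one logit per document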

We can train the model in the standard Flux.jl way. We define a prediction function, the loss function, the optimizer state, and a minibatch iterator:

pred(m, x) = σ.(m(x))
loss(m, x, y) = Flux.Losses.logitbinarycrossentropy(m(x), y);
opt_state = Flux.setup(Flux.Optimise.Descent(), model);
minibatch_iterator = Flux.DataLoader((x_train, y_train), batchsize=32, shuffle=true);
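
Internally, Flux.train! performs the following explicit-gradient step for every minibatch; writing one step out by hand can be useful for debugging (a sketch, not part of the original example):

x, y = first(minibatch_iterator);
l, grads = Flux.withgradient(m -> loss(m, x, y), model);
Flux.update!(opt_state, model, grads[1]);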

We train for 10 epochs, and after each epoch we report the training accuracy:

accuracy(p, y) = mean((p .> 0.5) .== y)
for i in 1:10
    Flux.train!(loss, model, minibatch_iterator, opt_state)
    @info "Epoch $i" accuracy=accuracy(pred(model, x_train), y_train)
end
┌ Info: Epoch 1
└   accuracy = 0.61
┌ Info: Epoch 2
└   accuracy = 0.63
┌ Info: Epoch 3
└   accuracy = 0.64
┌ Info: Epoch 4
└   accuracy = 0.74
┌ Info: Epoch 5
└   accuracy = 0.61
┌ Info: Epoch 6
└   accuracy = 0.82
┌ Info: Epoch 7
└   accuracy = 0.82
┌ Info: Epoch 8
└   accuracy = 0.84
┌ Info: Epoch 9
└   accuracy = 0.82
┌ Info: Epoch 10
└   accuracy = 0.82

Finally, we compute the accuracy on the testing set:

accuracy(pred(model, x_test), y_test)
0.8636363636363636
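
A single accuracy number hides per-class behavior; a quick breakdown into sensitivity and specificity (not part of the original example, assuming the convention 1 = mutagenic):

p_test = pred(model, x_test) .> 0.5;
mean(p_test[y_test .== 1])     # sensitivity: mutagenic compounds detected
mean(.!p_test[y_test .== 0])   # specificity: non-mutagenic correctly rejected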