Mutagenesis
This example demonstrates how to predict the mutagenicity of chemical compounds on Salmonella typhimurium.
This example is also available as a Jupyter notebook, together with the environment and the data.
We load all dependencies and fix the seed:
using JsonGrinder, Mill, Flux, JSON, MLUtils, Statistics
using Random; Random.seed!(42);
Loading the data
We load the dataset and split it into training and testing sets:
dataset = JSON.parsefile("mutagenesis.json");
jss_train, jss_test = dataset[1:100], dataset[101:end];
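As a quick sanity check: the dataset contains 188 molecules in total, so we train on the first 100 and test on the remaining 88:
length(dataset), length(jss_train), length(jss_test)
(188, 100, 88)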
jss_train and jss_test are just lists of parsed JSONs:
jss_train[1]
Dict{String, Any} with 6 entries:
"ind1" => 1
"lumo" => -1.246
"inda" => 0
"logp" => 4.23
"mutagenic" => 1
"atoms" => Any[Dict{String, Any}("element"=>"c", "atom_type"=>22, "bonds"…
We also extract binary labels, which are stored in the "mutagenic" key:
y_train = getindex.(jss_train, "mutagenic");
y_test = getindex.(jss_test, "mutagenic");
y_train
100-element Vector{Int64}:
1
1
0
1
1
1
1
1
1
1
⋮
0
0
1
1
0
0
1
0
0
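Since Statistics is already loaded, a one-liner gives us the class balance of the training labels, a useful baseline check before training (illustrative; we make no assumption about the exact value):
mean(y_train)   # fraction of mutagenic (label 1) molecules in the training set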
We first create the schema of the training data, which is the first important step in using JsonGrinder.jl. This infers both the hierarchical structure of the documents and basic statistics of individual values.
sch = schema(jss_train)
DictEntry 100x updated
├────── atoms: ArrayEntry 100x updated
│ ╰── DictEntry 2529x updated
│ ├── atom_type: LeafEntry (28 unique `Real` values) 25 ⋯
│ ├────── bonds: ArrayEntry 2529x updated
│ │ ╰── DictEntry 5402x updated
│ │ ┊
│ ├───── charge: LeafEntry (318 unique `Real` values) 2 ⋯
│ ╰──── element: LeafEntry (6 unique `String` values) 2 ⋯
├─────── ind1: LeafEntry (2 unique `Real` values) 100x updated
├─────── inda: LeafEntry (1 unique `Real` values) 100x updated
├─────── logp: LeafEntry (62 unique `Real` values) 100x updated
├─────── lumo: LeafEntry (98 unique `Real` values) 100x updated
╰── mutagenic: LeafEntry (2 unique `Real` values) 100x updated
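The printed tree is truncated (the ⋯ and ┊ marks). To see more of it, we can use printtree from HierarchicalUtils.jl, the package Mill.jl and JsonGrinder.jl use for tree printing (a sketch, assuming HierarchicalUtils.jl is available in the environment as a dependency of Mill.jl):
using HierarchicalUtils: printtree
printtree(sch; htrunc=20, vtrunc=20)   # htrunc/vtrunc control how much of the tree is shown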
Of course, we have to remove the "mutagenic" key from the schema, as we don't want the labels to be part of the input data:
delete!(sch, :mutagenic);
sch
DictEntry 100x updated
├── atoms: ArrayEntry 100x updated
│ ╰── DictEntry 2529x updated
│ ├── atom_type: LeafEntry (28 unique `Real` values) 2529x ⋯
│ ├────── bonds: ArrayEntry 2529x updated
│ │ ╰── DictEntry 5402x updated
│ │ ┊
│ ├───── charge: LeafEntry (318 unique `Real` values) 2529x ⋯
│ ╰──── element: LeafEntry (6 unique `String` values) 2529x ⋯
├─── ind1: LeafEntry (2 unique `Real` values) 100x updated
├─── inda: LeafEntry (1 unique `Real` values) 100x updated
├─── logp: LeafEntry (62 unique `Real` values) 100x updated
╰─── lumo: LeafEntry (98 unique `Real` values) 100x updated
Now we create an extractor capable of converting JSONs to Mill.jl structures. We use the suggestextractor function with its default settings:
e = suggestextractor(sch)
DictExtractor
├─── lumo: CategoricalExtractor(n=99)
├─── inda: CategoricalExtractor(n=2)
├─── logp: CategoricalExtractor(n=63)
├─── ind1: CategoricalExtractor(n=3)
╰── atoms: ArrayExtractor
╰── DictExtractor
├──── element: CategoricalExtractor(n=7)
├────── bonds: ArrayExtractor
│ ╰── DictExtractor
│ ┊
├───── charge: ScalarExtractor(c=-0.781, s=0.60790277)
╰── atom_type: CategoricalExtractor(n=29)
Note that each CategoricalExtractor reserves one extra category for values not present in the schema; this is why lumo with 98 unique values gets n=99. We also need to convert JSONs to Mill.jl data samples. Extractor e is callable, so we can use it to extract a single document as follows:
x_single = e(jss_train[1])
ProductNode 1 obs
├─── lumo: ArrayNode(99×1 OneHotArray with Bool elements) 1 obs
├─── inda: ArrayNode(2×1 OneHotArray with Bool elements) 1 obs
├─── logp: ArrayNode(63×1 OneHotArray with Bool elements) 1 obs
├─── ind1: ArrayNode(3×1 OneHotArray with Bool elements) 1 obs
╰── atoms: BagNode 1 obs
╰── ProductNode 26 obs
├──── element: ArrayNode(7×26 OneHotArray with Bool eleme ⋯
├────── bonds: BagNode 26 obs
│ ╰── ProductNode 56 obs
│ ┊
├───── charge: ArrayNode(1×26 Array with Float32 elements ⋯
╰── atom_type: ArrayNode(29×26 OneHotArray with Bool elem ⋯
To extract a batch of 10 documents, we can extract the individual documents and then concatenate them with Mill.catobs:
x_batch = reduce(catobs, e.(jss_train[1:10]))
ProductNode 10 obs
├─── lumo: ArrayNode(99×10 OneHotArray with Bool elements) 10 obs
├─── inda: ArrayNode(2×10 OneHotArray with Bool elements) 10 obs
├─── logp: ArrayNode(63×10 OneHotArray with Bool elements) 10 obs
├─── ind1: ArrayNode(3×10 OneHotArray with Bool elements) 10 obs
╰── atoms: BagNode 10 obs
╰── ProductNode 299 obs
├──── element: ArrayNode(7×299 OneHotArray with Bool elem ⋯
├────── bonds: BagNode 299 obs
│ ╰── ProductNode 650 obs
│ ┊
├───── charge: ArrayNode(1×299 Array with Float32 element ⋯
╰── atom_type: ArrayNode(29×299 OneHotArray with Bool ele ⋯
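Mill.jl nodes implement the MLUtils.jl observation interface (the same interface the Flux.DataLoader below relies on), so we can verify the number of observations directly:
numobs(x_single), numobs(x_batch)
(1, 10)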
Alternatively, we can use the much more efficient extract function, which operates on a list of documents. Because the dataset is small, we can extract all data at once and keep it in memory:
x_train = extract(e, jss_train);
x_test = extract(e, jss_test);
x_train
ProductNode 100 obs
├─── lumo: ArrayNode(99×100 OneHotArray with Bool elements) 100 obs
├─── inda: ArrayNode(2×100 OneHotArray with Bool elements) 100 obs
├─── logp: ArrayNode(63×100 OneHotArray with Bool elements) 100 obs
├─── ind1: ArrayNode(3×100 OneHotArray with Bool elements) 100 obs
╰── atoms: BagNode 100 obs
╰── ProductNode 2529 obs
├──── element: ArrayNode(7×2529 OneHotArray with Bool ele ⋯
├────── bonds: BagNode 2529 obs
│ ╰── ProductNode 5402 obs
│ ┊
├───── charge: ArrayNode(1×2529 Array with Float32 elemen ⋯
╰── atom_type: ArrayNode(29×2529 OneHotArray with Bool el ⋯
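Extracted nodes can also be indexed by observation, so subsets can be sliced out without re-extraction. For instance, the following is equivalent to the x_batch constructed earlier (a small illustration):
x_train[1:10]   # ProductNode with 10 obs, matching the catobs of the first 10 documents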
Then we create an encoding model capable of embedding each JSON document into a fixed-size vector:
encoder = reflectinmodel(sch, e)
ProductModel ↦ Dense(50 => 10) 2 arrays, 510 params, 2.078 KiB
├─── lumo: ArrayModel(Dense(99 => 10)) 2 arrays, 1_000 params, 3.992 KiB
├─── inda: ArrayModel(Dense(2 => 10)) 2 arrays, 30 params, 208 bytes
├─── logp: ArrayModel(Dense(63 => 10)) 2 arrays, 640 params, 2.586 KiB
├─── ind1: ArrayModel(Dense(3 => 10)) 2 arrays, 40 params, 248 bytes
╰── atoms: BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dens ⋯
╰── ProductModel ↦ Dense(31 => 10) 2 arrays, 320 params, 1.336 ⋯
├──── element: ArrayModel(Dense(7 => 10)) 2 arrays, 80 p ⋯
├────── bonds: BagModel ↦ BagCount([SegmentedMean(10); Se ⋯
│ ╰── ProductModel ↦ Dense(31 => 10) 2 ar ⋯
│ ┊
├───── charge: ArrayModel(identity)
╰── atom_type: ArrayModel(Dense(29 => 10)) 2 arrays, 300 ⋯
For further details about reflectinmodel, see the Mill.jl documentation.
Finally, we chain the encoder with one more dense layer computing the logit of the mutagenic probability:
model = vec ∘ Dense(10, 1) ∘ encoder
vec ∘ Dense(10 => 1) ∘ ProductModel ↦ Dense(50 => 10)
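The composed model maps a batch of documents straight to a vector of logits, one per document; we can sanity-check the output shape on the batch from before:
length(model(x_batch))   # one logit per document
10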
We can train the model in the standard Flux.jl way. We define the loss function, the optimizer state, and a minibatch iterator:
pred(m, x) = σ.(m(x));  # predicted probabilities from logits
loss(m, x, y) = Flux.Losses.logitbinarycrossentropy(m(x), y);  # numerically stable BCE on logits
opt_state = Flux.setup(Flux.Optimise.Descent(), model);  # plain gradient descent
minibatch_iterator = Flux.DataLoader((x_train, y_train), batchsize=32, shuffle=true);
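For clarity, a single update step equivalent to what Flux.train! performs inside the loop below might look as follows (a sketch using the standard Flux.gradient and Flux.update! API):
xb, yb = first(minibatch_iterator);                  # draw one minibatch
grads = Flux.gradient(m -> loss(m, xb, yb), model);  # differentiate the loss w.r.t. the model
Flux.update!(opt_state, model, grads[1]);            # apply one gradient-descent step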
We train for 10 epochs, and after each epoch we report the training accuracy:
accuracy(p, y) = mean((p .> 0.5) .== y)   # threshold predicted probabilities at 0.5
for i in 1:10
Flux.train!(loss, model, minibatch_iterator, opt_state)
@info "Epoch $i" accuracy=accuracy(pred(model, x_train), y_train)
end
┌ Info: Epoch 1
└ accuracy = 0.61
┌ Info: Epoch 2
└ accuracy = 0.63
┌ Info: Epoch 3
└ accuracy = 0.64
┌ Info: Epoch 4
└ accuracy = 0.74
┌ Info: Epoch 5
└ accuracy = 0.61
┌ Info: Epoch 6
└ accuracy = 0.82
┌ Info: Epoch 7
└ accuracy = 0.82
┌ Info: Epoch 8
└ accuracy = 0.84
┌ Info: Epoch 9
└ accuracy = 0.82
┌ Info: Epoch 10
└ accuracy = 0.82
Finally, we compute the accuracy on the testing set:
accuracy(pred(model, x_test), y_test)
0.8636363636363636
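To classify a single new molecule, we extract it with e and reuse the pred helper defined above; anything above the 0.5 threshold is labeled mutagenic (an illustrative snippet reusing names from this example):
p = only(pred(model, e(jss_test[1])))   # predicted probability of mutagenicity
p > 0.5                                 # binary decision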