Commit

Merge pull request #50 from aced-differentiate/example_cleanup
tidying up examples
rkurchin authored Feb 22, 2021
2 parents 4a9ae29 + 06ce74d commit 5a5633f
Showing 5 changed files with 7 additions and 109 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
CSV = "0.7, 0.8"
-ChemistryFeaturization = "0.2"
+ChemistryFeaturization = "0.2.2"
DataFrames = "0.21, 0.22"
Flux = "0.11"
LightGraphs = "1.3"
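The compat bump above tightens the lower bound on ChemistryFeaturization. Julia `[compat]` entries are caret-style specifiers, so `"0.2.2"` admits versions in `[0.2.2, 0.3.0)`. One way to check this interactively is Pkg's `semver_spec` helper (note: it lives in `Pkg.Types` and is an internal API, so this is a sketch rather than a guaranteed-stable interface):

```julia
using Pkg

# Parse the compat string into a version range
spec = Pkg.Types.semver_spec("0.2.2")

# Caret semantics for 0.x releases: [0.2.2, 0.3.0)
println(v"0.2.1" in spec)  # false: below the new lower bound
println(v"0.2.9" in spec)  # true: same 0.2.x series
println(v"0.3.0" in spec)  # false: breaking for a 0.x package
```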
19 changes: 3 additions & 16 deletions examples/1_formation_energy/formation_energy.jl
@@ -1,8 +1,6 @@
#=
Train a simple network to predict formation energy per atom (downloaded from Materials Project).
=#
-#using Pkg
-#Pkg.activate("../../")
using CSV, DataFrames
using Random, Statistics
using Flux
@@ -14,7 +12,7 @@ using AtomicGraphNets
println("Setting things up...")

# data-related options
-num_pts = 100 # how many points to use? Up to 32530 in the formation energy case as of 2020/04/01
+num_pts = 100 # how many points to use?
train_frac = 0.8 # what fraction for training?
num_epochs = 5 # how many epochs to train?
num_train = Int32(round(train_frac * num_pts))
@@ -48,13 +46,8 @@ output = y[indices]

# next, make graphs and build input features (matrices of dimension (# features, # nodes))
println("Building graphs and feature vectors from structures...")
-#graphs = SimpleWeightedGraph{Int32, Float32}[]
-#element_lists = Array{String}[]
-#inputs = Tuple{Array{Float32,2},SparseArrays.SparseMatrixCSC{Float32,Int64}}[]
inputs = AtomGraph[]

-#TODO: this with bulk processing fcn

for r in eachrow(info)
cifpath = string(datadir, prop, "_cifs/", r[Symbol(id)], ".cif")
gr = build_graph(cifpath)
@@ -71,21 +64,15 @@ train_input = inputs[1:num_train]
test_input = inputs[num_train+1:end]
train_data = zip(train_input, train_output)

-# build the network (basically just copied from CGCNN.py for now): the convolutional layers, a mean pooling function, some dense layers, then fully connected output to one value for prediction

+# build the model
println("Building the network...")
-#model = Chain([AGNConv(num_features=>num_features) for i in 1:num_conv]..., AGNMeanPool(crys_fea_len, 0.1), [Dense(crys_fea_len, crys_fea_len, softplus) for i in 1:num_hidden_layers]..., Dense(crys_fea_len, 1))
model = Xie_model(num_features, num_conv=num_conv, atom_conv_feature_length=crys_fea_len, num_hidden_layers=1)

-# MaxPool might make more sense?

-# define loss function
+# define loss function and a callback to monitor progress
loss(x,y) = Flux.mse(model(x), y)
-# and a callback to see training progress
evalcb() = @show(mean(loss.(test_input, test_output)))
evalcb()

# train
println("Training!")
-#Flux.train!(loss, params(model), train_data, opt)
@epochs num_epochs Flux.train!(loss, params(model), train_data, opt, cb = Flux.throttle(evalcb, 5))
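The script above uses the implicit-parameter training API of the Flux 0.11 era pinned in Project.toml (`Flux.train!` with `params`, `@epochs`, and a throttled callback). A minimal self-contained sketch of the same pattern on synthetic data, with a plain `Dense` layer standing in for `Xie_model` (the data and model here are illustrative only, not the repository's):

```julia
using Flux
using Flux: @epochs, throttle
using Statistics: mean

# Synthetic regression data: 32 feature vectors and scalar targets
inputs  = [rand(Float32, 4) for _ in 1:32]
targets = [sum(x) for x in inputs]

# Toy stand-in for the graph network: one dense layer to a scalar
model = Dense(4, 1)
loss(x, y) = Flux.mse(model(x)[1], y)
opt = ADAM(0.001)

# Callback reports mean test-style loss, throttled to every 5 seconds
evalcb() = @show mean(loss.(inputs, targets))
@epochs 2 Flux.train!(loss, Flux.params(model), zip(inputs, targets), opt,
                      cb = throttle(evalcb, 5))
```

Note that newer Flux releases replace `@epochs` and implicit `params` with explicit training loops, so this sketch matches the pinned 0.11-style API rather than current Flux.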
4 changes: 3 additions & 1 deletion examples/2_qm9/README.md
@@ -1,3 +1,5 @@
+# Example 2: QM9

-In this example we will use the same network architecture but use data from the QM9 dataset. Note that the .xyz files provided within the QM9 dataset are not parsable directly by ASE, you need to remove the last couple lines, which is easy enough to script yourself, but I've included a small set of them here for demonstration purposes.
+In this example we will use the same network architecture but use data from the QM9 dataset. Note that the .xyz files provided within the QM9 dataset are not directly parsable by ASE; you need to remove the last couple of lines, which is easy enough to script yourself, but I've included a small set of them here for demonstration purposes.

NB: the actual model performance on QM9 is not that great because we're currently not encoding a variety of important features for organic molecules. This is provided mainly to show the processing of a different dataset and demonstrate batch processing capabilities.
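The README says the extra trailing lines must be stripped from each QM9 .xyz file before ASE can parse it. A hedged sketch of such a helper (the function name is invented here, and the exact number of lines to drop is an assumption; adjust `n` to match the files you have):

```julia
# Hypothetical helper: write a copy of a QM9 .xyz file with its last
# `n` lines removed. The count n = 2 is an assumption, not a spec.
function strip_trailing_lines(inpath::AbstractString, outpath::AbstractString; n::Int = 2)
    lines = readlines(inpath)
    open(outpath, "w") do io
        for line in lines[1:max(end - n, 0)]
            println(io, line)
        end
    end
    return outpath
end
```

Usage would be a simple loop over the dataset directory, e.g. `strip_trailing_lines(f, f * ".clean")` for each file returned by `readdir`.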
5 changes: 0 additions & 5 deletions examples/3_deq/README.md

This file was deleted.

86 changes: 0 additions & 86 deletions examples/3_deq/deq.jl

This file was deleted.
