-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy path05_fit_mlp_model.R
55 lines (39 loc) · 1.23 KB
/
05_fit_mlp_model.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
library(mxnet)
# Devices ----
# To train on the CPU instead, replace the GPU setup below with:
# devs <- mx.cpu()
# For GPUs, run `nvidia-smi` in a shell to check which devices are free.
#
# Number of GPUs to train on; GPUs with ids 0, 1, ..., n_gpus - 1 are used.
# (Previously `0:gpus_cnt` was used, so `gpus_cnt` acted as the highest GPU
# id rather than a count: setting it to 2 silently requested 3 devices.)
n_gpus <- 1L
# Highest GPU id in use; kept only for code that still reads `gpus_cnt`.
gpus_cnt <- n_gpus - 1L
devs <- lapply(seq_len(n_gpus) - 1L, mx.gpu)

# Training hyperparameters ----
# Number of training rounds (epochs) over the training data.
num_round <- 50
# Learning rate: how large a step the optimizer takes along the gradient.
learning_rate <- 0.1
# mxnet key-value store type, controlling how gradients are synchronized
# across devices/workers; "local" keeps it on this machine.
kv_store <- "local"

# Checkpointing ----
# For mxnet's built-in checkpoint callback and the full list of training
# options, see the help pages ?mx.callback.save.checkpoint and
# ?mx.model.FeedForward.create. They are referenced here rather than called,
# so the script runs cleanly under Rscript / source().
# Epoch-end callback for mx.model.FeedForward.create: logs a timestamp for
# each iteration and saves a model checkpoint every `period` iterations.
#
# Arguments:
#   iteration: iteration (epoch) number supplied by mxnet.
#   nbatch:    batch counter supplied by mxnet; unused here.
#   env:       training environment; `env$model` is the model that is saved.
#   verbose:   when FALSE, nothing is written to the console. NOTE(review):
#              presumably mxnet forwards FeedForward.create's `verbose` here.
#   prefix:    checkpoint file prefix; the parameters file is reported as
#              "<prefix>-<iteration, 4-digit zero-padded>.params".
#   period:    save a checkpoint whenever `iteration %% period == 0`.
# Returns TRUE. NOTE(review): mxnet appears to stop training when an
# epoch-end callback returns FALSE; confirm before changing this.
checkpoint_mlp <- function(iteration, nbatch, env, verbose = TRUE,
                           prefix = "mlp", period = 5) {
  if (verbose) {
    cat(sprintf(
      "Iteration= %s %s\n", iteration, format(Sys.time(), "%H:%M:%S")
    ))
  }
  if (iteration %% period == 0) {
    mx.model.save(env$model, prefix, iteration)
    if (verbose) {
      cat(sprintf(
        "Model checkpoint saved to %s-%04d.params\n", prefix, iteration
      ))
    }
  }
  TRUE
}
# Train the MLP ----
# Fits the `mlp` symbol on `mnist_train`, evaluating accuracy on
# `mnist_validate`, and runs `checkpoint_mlp` as the epoch-end callback.
# NOTE(review): assumes `mlp`, `mnist_train`, `mnist_validate` and
# `batch_size` were created by earlier scripts in this series; confirm.
model <- mx.model.FeedForward.create(
  # Network and data
  symbol = mlp,
  X = mnist_train,
  eval.data = mnist_validate,
  eval.metric = mx.metric.accuracy,
  # Optimization settings
  num.round = num_round,
  learning.rate = learning_rate,
  momentum = 0.9,
  wd = 1e-5,
  array.batch.size = batch_size,
  # Devices and gradient synchronization
  ctx = devs,
  kvstore = kv_store,
  # Callbacks: per-epoch checkpointing and (presumably every 150 batches)
  # training-metric logging
  epoch.end.callback = checkpoint_mlp,
  batch.end.callback = mx.callback.log.train.metric(150)
)