1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
| library(mxnet)
# Download and unpack the MNIST idx files into `data_dir`, unless all four
# files are already present there.
#
# Unlike a setwd()-based approach, this never changes the caller's working
# directory, so a failed download cannot leave the session in `data_dir`.
#
# Arguments:
#   data_dir  Directory that should hold the idx files; created (including
#             any missing parents) if it does not exist.
# Returns:
#   `data_dir`, invisibly.
download_ <- function(data_dir) {
  dir.create(data_dir, showWarnings = FALSE, recursive = TRUE)

  needed <- c(
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte"
  )

  if (!all(file.exists(file.path(data_dir, needed)))) {
    zip_path <- file.path(data_dir, "mnist.zip")
    # Remove the archive even if the download or extraction fails part-way.
    on.exit(unlink(zip_path), add = TRUE)
    # Binary mode so the zip is not mangled by text-mode transfer on Windows.
    download.file(
      url = "http://data.mxnet.io/mxnet/data/mnist.zip",
      destfile = zip_path,
      mode = "wb"
    )
    unzip(zip_path, exdir = data_dir)
  }

  invisible(data_dir)
}
# Build a factory for the MNIST train/validation iterators.
#
# The iterators are only created when the returned function is called, so the
# idx files need to exist by then, not when get_iterator() itself is called.
#
# Arguments:
#   data_shape  Passed to mx.io.MNISTIter as `input_shape`. A length-3 shape
#               (e.g. c(28, 28, 1)) requests unflattened images (flat = FALSE);
#               any other length requests flat = TRUE.
#   data_dir    Directory containing the unpacked MNIST idx files; a trailing
#               "/" is optional.
#   batch_size  Mini-batch size used for both iterators.
# Returns:
#   A zero-argument function returning list(train = , value = ), where
#   `value` is the validation (t10k) iterator.
get_iterator <- function(data_shape, data_dir = "mnist/", batch_size = 128) {
  # Evaluate now so the returned closure captures the values passed today,
  # not whatever the caller's variables hold when it is eventually invoked.
  force(data_shape)
  force(data_dir)
  force(batch_size)

  get_iterator_impl <- function() {
    flat <- length(data_shape) != 3
    data_path <- function(file) file.path(sub("/+$", "", data_dir), file)

    train <- mx.io.MNISTIter(
      image = data_path("train-images-idx3-ubyte"),
      label = data_path("train-labels-idx1-ubyte"),
      input_shape = data_shape,
      batch_size = batch_size,
      shuffle = TRUE,
      flat = flat
    )
    # `shuffle` is intentionally left at the iterator's own default here.
    val <- mx.io.MNISTIter(
      image = data_path("t10k-images-idx3-ubyte"),
      label = data_path("t10k-labels-idx1-ubyte"),
      input_shape = data_shape,
      batch_size = batch_size,
      flat = flat
    )
    # The key "value" is kept because callers read `data$value`.
    list(train = train, value = val)
  }

  get_iterator_impl
}
# Symbol graph for a small multilayer perceptron:
#   data -> fc1 (128, ReLU) -> fc2 (64, ReLU) -> fc3 (10) -> softmax output.
#
# Returns:
#   The output symbol, named "softmax".
get_mlp <- function() {
  # One fully connected layer followed by a ReLU, with explicit layer names.
  dense_relu <- function(input, fc_name, act_name, units) {
    dense <- mx.symbol.FullyConnected(
      data = input, name = fc_name, num_hidden = units
    )
    mx.symbol.Activation(data = dense, name = act_name, act_type = "relu")
  }

  x <- mx.symbol.Variable("data")
  x <- dense_relu(x, "fc1", "relu1", 128)
  x <- dense_relu(x, "fc2", "relu2", 64)
  logits <- mx.symbol.FullyConnected(data = x, name = "fc3", num_hidden = 10)

  mx.symbol.SoftmaxOutput(data = logits, name = "softmax")
}
# Symbol graph for a LeNet-style convolutional network:
#   two [5x5 conv -> tanh -> 2x2/2 max-pool] stages (20 then 50 filters),
#   flatten, a 500-unit tanh dense layer, a 10-unit dense layer, softmax.
# Only the output symbol is named ("softmax"); other layers are left unnamed.
#
# Returns:
#   The output symbol.
get_lenet <- function() {
  # One convolution stage; the symbols are created in conv, act, pool order.
  conv_stage <- function(input, filters) {
    convolved <- mx.symbol.Convolution(
      data = input, kernel = c(5, 5), num_filter = filters
    )
    activated <- mx.symbol.Activation(data = convolved, act_type = "tanh")
    mx.symbol.Pooling(
      data = activated, pool_type = "max", kernel = c(2, 2), stride = c(2, 2)
    )
  }

  features <- mx.symbol.Variable("data")
  features <- conv_stage(features, 20)
  features <- conv_stage(features, 50)

  flattened <- mx.symbol.Flatten(data = features)
  hidden <- mx.symbol.FullyConnected(data = flattened, num_hidden = 500)
  hidden <- mx.symbol.Activation(data = hidden, act_type = "tanh")
  scores <- mx.symbol.FullyConnected(data = hidden, num_hidden = 10)

  mx.symbol.SoftmaxOutput(data = scores, name = "softmax")
}
# Train LeNet on MNIST ----

# Fetch the idx files first; the iterators are built lazily and need them.
download_("mnist/")
data_loader <- get_iterator(c(28, 28, 1))
net <- get_lenet()
data <- data_loader()
train <- data$train
val <- data$value
# NOTE(review): hard-coded to the first GPU; use mx.cpu() on CPU-only hosts.
devs <- mx.gpu(0)

model <- mx.model.FeedForward.create(
  X = train,
  eval.data = val,
  ctx = devs,
  symbol = net,
  # Rounds run from begin.round through num.round inclusive, so starting at 0
  # would train one extra epoch (11 instead of num.round = 10).
  begin.round = 1,
  # NOTE(review): top-k accuracy on a 10-class problem is a lenient metric;
  # consider mx.metric.accuracy if plain accuracy is what should be reported.
  eval.metric = mx.metric.top_k_accuracy,
  num.round = 10,
  learning.rate = 0.05,
  array.batch.size = 128,
  optimizer = "sgd",
  initializer = mx.init.Xavier(factor_type = "in", magnitude = 2),
  batch.end.callback = mx.callback.log.train.metric(50)
)
|