1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
|
require(mxnet) require(data.table)
# Build the frame-difference LeNet used for both the systole and diastole
# models.
#
# The "data" input is rescaled with (x - 128) / 128 and split into
# `num.frames` slices with SliceChannel. The slices are replaced by the
# `num.frames - 1` differences between consecutive slices, which are
# concatenated and fed through:
#   - two Convolution / BatchNorm / ReLU / 2x2 max-pool stages,
#   - dropout,
#   - a fully-connected layer with a LogisticRegressionOutput head.
#
# Args:
#   num.frames: number of slices the input is split into (default 30; this
#     presumably has to match the 30 in the iterators' data.shape -- confirm).
#   num.hidden: number of outputs of the final fully-connected layer
#     (default 600, the label.shape used by the training iterators).
# Returns:
#   The output symbol: a LogisticRegressionOutput named "softmax".
get.lenet <- function(num.frames = 30, num.hidden = 600) {
  # At least two slices are needed to form one difference.
  stopifnot(num.frames >= 2, num.hidden >= 1)

  input <- mx.symbol.Variable("data")
  input <- (input - 128) / 128
  frames <- mx.symbol.SliceChannel(input, num.outputs = num.frames)

  # Consecutive-slice differences, built in one pass in the same order as
  # before (slice 2 - slice 1, slice 3 - slice 2, ...) instead of growing a
  # list with c() inside a loop.
  diffs <- lapply(seq_len(num.frames - 1), function(i) {
    frames[[i + 1]] - frames[[i]]
  })
  # The variadic Concat wrapper is handed the input count via `num.args`.
  diffs$num.args <- num.frames - 1
  net <- mxnet:::mx.varg.symbol.Concat(diffs)

  # Stage 1: 5x5 convolution.
  net <- mx.symbol.Convolution(net, kernel = c(5, 5), num.filter = 40)
  net <- mx.symbol.BatchNorm(net, fix.gamma = TRUE)
  net <- mx.symbol.Activation(net, act.type = "relu")
  net <- mx.symbol.Pooling(
    net,
    pool.type = "max",
    kernel = c(2, 2),
    stride = c(2, 2)
  )

  # Stage 2: 3x3 convolution.
  net <- mx.symbol.Convolution(net, kernel = c(3, 3), num.filter = 40)
  net <- mx.symbol.BatchNorm(net, fix.gamma = TRUE)
  net <- mx.symbol.Activation(net, act.type = "relu")
  net <- mx.symbol.Pooling(
    net,
    pool.type = "max",
    kernel = c(2, 2),
    stride = c(2, 2)
  )

  flatten <- mx.symbol.Flatten(net)
  flatten <- mx.symbol.Dropout(flatten)
  fc1 <- mx.symbol.FullyConnected(data = flatten, num.hidden = num.hidden)
  mx.symbol.LogisticRegressionOutput(data = fc1, name = "softmax")
}
# ---- Network and data iterators for the systole model ----
batch_size <- 32
network <- get.lenet()

# Systole training data: image stacks from train-64x64-data.csv paired with
# 600-value label rows from train-systole.csv.
data_train <- mx.io.CSVIter(
  data.csv = "train-64x64-data.csv",
  data.shape = c(64, 64, 30),
  label.csv = "train-systole.csv",
  label.shape = 600,
  batch.size = batch_size
)

# Validation data: images only (no label file), batch size 1.
data_validate <- mx.io.CSVIter(
  data.csv = "validate-64x64-data.csv",
  data.shape = c(64, 64, 30),
  batch.size = 1
)
# Custom evaluation metric named "CRPS".
#
# Each column of `pred` is treated as a cumulative distribution. It is first
# forced to be non-decreasing down the rows (running maximum). The metric is
# then the mean squared difference to `label` over all elements.
mx.metric.CRPS <- mx.metric.custom("CRPS", function(label, pred) {
  pred <- as.array(pred)
  label <- as.array(label)
  # Vectorised running max down every column. Unlike the old index loop, this
  # also works when there is a single row. apply() drops the dims for 1-row
  # input, so matrix() restores the original row count.
  pred <- matrix(apply(pred, 2, cummax), nrow = nrow(pred))
  sum((label - pred)^2) / length(label)
})
# ---- Systole model ----
# Seed once before the first fit; the diastole fit below is not re-seeded.
mx.set.seed(0)
stytole_model <- mx.model.FeedForward.create(
  X = data_train,
  ctx = mx.gpu(0),
  symbol = network,
  num.round = 65,
  learning.rate = 1e-3,
  wd = 1e-5,
  momentum = 0.9,
  eval.metric = mx.metric.CRPS
)
stytole_prob <- predict(stytole_model, data_validate)

# ---- Diastole model ----
# Fresh network and a training iterator over the diastole labels; the
# optimiser settings are the same as for the systole model.
network <- get.lenet()
batch_size <- 32
data_train <- mx.io.CSVIter(
  data.csv = "./train-64x64-data.csv",
  data.shape = c(64, 64, 30),
  label.csv = "./train-diastole.csv",
  label.shape = 600,
  batch.size = batch_size
)
diastole_model <- mx.model.FeedForward.create(
  X = data_train,
  ctx = mx.gpu(0),
  symbol = network,
  num.round = 65,
  learning.rate = 1e-3,
  wd = 1e-5,
  momentum = 0.9,
  eval.metric = mx.metric.CRPS
)
diastole_prob <- predict(diastole_model, data_validate)
# Average prediction rows that share the same id.
#
# Args:
#   validate_lst: path to a comma-separated file (read without a header).
#     Row k's first column is the id of prediction column k of `prob`.
#   prob: prediction matrix with one column per row of `validate_lst`.
# Returns:
#   A data.table whose column V1 is the id, followed by one column per row of
#   `prob`. Each holds the mean over all entries with that id.
# Raises:
#   An error if the number of ids and the number of prediction columns differ.
#   Previously cbind() would silently recycle (or only warn).
accumulate_result <- function(validate_lst, prob) {
  # Named `lst` rather than `t` so it no longer shadows base::t(), which is
  # called right below.
  lst <- read.table(validate_lst, sep = ",")
  pred <- t(prob)
  if (nrow(lst) != nrow(pred)) {
    stop(
      sprintf(
        "accumulate_result: %d ids in '%s' but %d prediction columns",
        nrow(lst), validate_lst, nrow(pred)
      ),
      call. = FALSE
    )
  }
  p <- cbind(lst[, 1], pred)
  dt <- as.data.table(p)
  dt[, lapply(.SD, mean), by = V1]
}
# Collapse the validation predictions to one averaged row per id, as plain
# data.frames (column V1 = id, columns 2:601 = predicted values).
stytole_result <- as.data.frame(
  accumulate_result("./validate-label.csv", stytole_prob)
)
diastole_result <- as.data.frame(
  accumulate_result("./validate-label.csv", diastole_prob)
)

# Training labels. Column 2 feeds the systole histogram and column 3 the
# diastole one.
train_csv <- read.table("./train-label.csv", sep = ",")
# Empirical cumulative distribution of `data` on the integer grid 1..num.bins.
#
# Element j of the result is the fraction of the rounded, non-missing values
# in `data` that are <= j.
#
# Args:
#   data: numeric vector of observations. NAs are dropped; the previous loop
#     stopped with an error on them.
#   num.bins: length of the result (default 600, the number of value columns
#     filled in the submission).
# Returns:
#   A non-decreasing numeric vector of length `num.bins` with values in [0, 1].
#   It is always exactly `num.bins` long. Values above `num.bins` simply never
#   count, whereas the old `round(x):600` loop ran backwards for them and grew
#   the vector past 600 entries.
# Raises:
#   An error if `data` has no non-missing values.
doHist <- function(data, num.bins = 600) {
  vals <- round(data[!is.na(data)])
  if (length(vals) == 0) {
    stop("doHist: `data` has no non-missing values", call. = FALSE)
  }
  vapply(seq_len(num.bins), function(j) mean(vals <= j), numeric(1))
}
# Training-label histograms, used as fallback rows in the submission.
hSystole <- doHist(train_csv[[2]])
hDiastole <- doHist(train_csv[[3]])

# Submission template whose Id column drives the fill-in loop.
res <- read.table(
  "data/sample_submission_validate.csv",
  header = TRUE,
  sep = ",",
  stringsAsFactors = FALSE
)
# Force a predicted distribution to be non-decreasing by replacing each value
# with the running maximum of the values up to and including it.
#
# Args:
#   pred: a numeric vector, or a one-row data.frame of numeric columns (the
#     script passes a single row sliced from a result data.frame).
# Returns:
#   The same kind of object as `pred`, with column names / row name kept for
#   a data.frame.
# Raises:
#   An error for a data.frame that does not have exactly one row.
submission_helper <- function(pred) {
  if (is.data.frame(pred)) {
    if (nrow(pred) != 1L) {
      stop("submission_helper: expected a single row, got ", nrow(pred),
           call. = FALSE)
    }
    # cummax() on a data.frame works column by column (a no-op for one row),
    # so flatten the row, take the running max, and write it back in place.
    pred[] <- as.list(cummax(unlist(pred, use.names = FALSE)))
    return(pred)
  }
  # Base cummax handles any length, including the length-1 case where the
  # old `2:length(pred)` loop indexed out of bounds.
  cummax(pred)
}
# ---- Fill in and write the submission ----
# Each template Id is split on "_" into a case key and a target. A target of
# "Diastole" selects the diastole predictions and histogram; anything else
# selects the systole ones. Cases found in the matching result table get
# their averaged prediction made non-decreasing by submission_helper().
# All other cases fall back to the training-label histogram for that target.
# seq_len() makes an empty template a no-op instead of iterating over 1:0.
for (i in seq_len(nrow(res))) {
  # Split each Id once instead of twice.
  id_parts <- strsplit(res$Id[i], "_", fixed = TRUE)[[1]]
  key <- id_parts[1]
  target <- id_parts[2]
  is_diastole <- identical(target, "Diastole")
  # Look the key up in the table the values are actually taken from, rather
  # than always checking the systole table.
  model_result <- if (is_diastole) diastole_result else stytole_result
  hit <- which(model_result$V1 == key)
  if (length(hit) > 0) {
    res[i, 2:601] <- submission_helper(model_result[hit[1], 2:601])
  } else if (is_diastole) {
    res[i, 2:601] <- hDiastole
  } else {
    res[i, 2:601] <- hSystole
  }
}

write.table(
  res,
  file = "submission.csv",
  sep = ",",
  quote = FALSE,
  row.names = FALSE
)
|