1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
| library(data.table) library(pipeR) library(rpart)
# Fix the RNG seed so the random train/test split below is reproducible.
set.seed(245)

# Copy mtcars into a data.table and turn `am` into a factor whose two levels
# are labelled "Automatic" and "Manual" (in sorted order of the raw values).
dataDT <- data.table(mtcars)
dataDT[, am := factor(am, labels = c("Automatic", "Manual"))]

# Draw 80% of the row numbers (rounded down) without replacement and flag
# those rows as training rows; every other row gets trainFlag == FALSE.
dataDT[
  ,
  trainFlag := .I %in% sample.int(nrow(mtcars), floor(.8 * nrow(mtcars)))
]
# Fit an rpart tree that predicts `am` from `hp` and `mpg`.
#
# Args:
#   s:    passed to rpart.control() as `minsplit`.
#   d:    passed to rpart.control() as `maxdepth`.
#   data: data to fit on. Defaults to the rows of the global `dataDT` with
#         trainFlag == TRUE, which is what this function used unconditionally
#         before the argument existed, so two-argument calls behave as before.
#         The default is evaluated lazily, so `dataDT` is only required when
#         `data` is not supplied.
#
# Returns:
#   The fitted object returned by rpart().
buildMdl <- function(s, d, data = dataDT[trainFlag == TRUE]) {
  rpart(
    am ~ hp + mpg,
    data = data,
    control = rpart.control(minsplit = s, maxdepth = d)
  )
}
# Build the hyper-parameter grid: every combination of the three minsplit
# values with the three maxdepth values.
gs <- CJ(minsplit = c(2, 5, 10), maxdepth = c(1, 3, 8))

# Fit one model per grid row and store the fitted objects in the list
# column `mod`.
gs[, mod := list(Map(buildMdl, minsplit, maxdepth))]
print(gs)
# Accuracy of a fitted model: the share of rows in `testData` whose predicted
# class (predict(..., type = "class")) equals the matching entry of
# `testLabel`.
#
# Args:
#   mod:       fitted model with a predict() method that accepts type = "class".
#   testData:  data to predict on.
#   testLabel: true labels, compared element-wise with the predictions.
#
# Returns:
#   A single number: the mean of the element-wise equality comparison.
calAccu <- function(mod, testData, testLabel) {
  predicted <- predict(mod, testData, type = "class")
  hits <- predicted == testLabel
  mean(hits)
}
# Accuracy of every fitted model on the training rows of dataDT.
gs[, trainAccu := vapply(mod, function(m) {
  trainRows <- dataDT[trainFlag == TRUE]
  calAccu(m, trainRows, trainRows$am)
}, numeric(1))]

# Accuracy of every fitted model on the held-out (non-training) rows.
gs[, testAccu := vapply(mod, function(m) {
  testRows <- dataDT[trainFlag == FALSE]
  calAccu(m, testRows, testRows$am)
}, numeric(1))]

# Sort in place: best test accuracy first, ties broken by training accuracy.
setorder(gs, -testAccu, -trainAccu)
print(gs)
|