Ching-Chuan Chen's Blogger

Statistics, Machine Learning and Programming


Grid search in data.table

Someone sent me a blog post on doing grid search with the tidyverse (Grid search in the tidyverse).

So I figured I would write a data.table version.

Here is the code:

```r
library(data.table)
library(pipeR)
library(rpart)

set.seed(245)
# Convert mtcars to a data.table and recode am (0/1) as a labelled factor
dataDT <- data.table(mtcars) %>>% `[`(j = am := factor(am, labels = c("Automatic", "Manual")))
# Flag a random 80% of the rows as the training set
dataDT[ , trainFlag := !is.na(match(.I, sample.int(nrow(mtcars), floor(.8 * nrow(mtcars)))))]

# Fit one classification tree for a given (minsplit, maxdepth) pair
buildMdl <- function(s, d) {
  rpart(am ~ hp + mpg, dataDT[trainFlag == TRUE], control = rpart.control(minsplit = s, maxdepth = d))
}

# Full Cartesian grid of tuning parameters; the fitted models are stored in a list column
gs <- CJ(minsplit = c(2, 5, 10), maxdepth = c(1, 3, 8))
gs[ , mod := list(mapply(buildMdl, minsplit, maxdepth, SIMPLIFY = FALSE))]
print(gs)
#    minsplit maxdepth     mod
# 1:        2        1 <rpart>
# 2:        2        3 <rpart>
# 3:        2        8 <rpart>
# 4:        5        1 <rpart>
# 5:        5        3 <rpart>
# 6:        5        8 <rpart>
# 7:       10        1 <rpart>
# 8:       10        3 <rpart>
# 9:       10        8 <rpart>

# Share of rows whose predicted class matches the true label
calAccu <- function(mod, testData, testLabel) {
  mean(predict(mod, testData, type = "class") == testLabel)
}

# Training and test accuracy for every model in the grid
gs[ , `:=`(trainAccu = mapply(function(m) dataDT[trainFlag == TRUE] %>>% {calAccu(m, ., .$am)}, mod),
           testAccu = mapply(function(m) dataDT[trainFlag == FALSE] %>>% {calAccu(m, ., .$am)}, mod))]
# Best test accuracy first, ties broken by training accuracy
setorder(gs, -testAccu, -trainAccu)
print(gs)
#    minsplit maxdepth     mod trainAccu  testAccu
# 1:        2        8 <rpart>      1.00 0.8571429
# 2:        2        3 <rpart>      0.92 0.8571429
# 3:        5        3 <rpart>      0.88 0.8571429
# 4:        5        8 <rpart>      0.88 0.8571429
# 5:        2        1 <rpart>      0.84 0.7142857
# 6:        5        1 <rpart>      0.84 0.7142857
# 7:       10        1 <rpart>      0.84 0.7142857
# 8:       10        3 <rpart>      0.84 0.7142857
# 9:       10        8 <rpart>      0.84 0.7142857
```
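The data.table version above hinges on one idiom: fitted model objects can live in a list column if the `mapply(..., SIMPLIFY = FALSE)` result is wrapped in `list()` before `:=` assigns it. The sketch below isolates that idiom without rpart. The `degree` grid and the `lm(mpg ~ poly(hp, d))` models are hypothetical choices made only for illustration; they are not part of the original grid search.

```r
library(data.table)

# Hypothetical one-parameter grid (illustration only): polynomial degree of a linear model
gsLm <- data.table(degree = 1:3)

# mapply(..., SIMPLIFY = FALSE) returns a plain list of lm objects;
# wrapping it in list() makes := store that list as a single list column
gsLm[ , mod := list(mapply(function(d) lm(mpg ~ poly(hp, d), data = mtcars),
                           degree, SIMPLIFY = FALSE))]

# Each cell of the list column is a complete model object that can be queried later
gsLm[ , r2 := vapply(mod, function(m) summary(m)$r.squared, numeric(1))]
print(gsLm)
```

Keeping the models in the table means any later metric (accuracy, R squared, and so on) is just another `:=` over the `mod` column, which is exactly how `trainAccu` and `testAccu` are added above.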

Personally I find this a bit tedious to write XD. Switching to foreach + iterators makes it much cleaner, in my opinion. Here is the code:

```r
library(data.table)
library(pipeR)
library(rpart)
library(foreach)
library(iterators)

set.seed(245)
dataDT <- data.table(mtcars) %>>% `[`(j = am := factor(am, labels = c("Automatic", "Manual")))
dataDT[ , trainFlag := !is.na(match(.I, sample.int(nrow(mtcars), floor(.8 * nrow(mtcars)))))]

resDT <- CJ(minsplit = c(2, 5, 10), maxdepth = c(1, 3, 8)) %>>%
  {
    # Split the grid by all of its columns, so each iteration receives
    # one parameter combination as a one-row data.table in i$value
    foreach(i = isplit(., as.list(.)), .final = rbindlist) %do% {
      # The one-row data.table doubles as the named argument list for rpart.control
      mod <- rpart(am ~ hp + mpg, dataDT[trainFlag == TRUE],
                   control = do.call(rpart.control, i$value))
      trainAccu <- dataDT[trainFlag == TRUE] %>>% {mean(predict(mod, ., type = "class") == .$am)}
      testAccu <- dataDT[trainFlag == FALSE] %>>% {mean(predict(mod, ., type = "class") == .$am)}
      # One result row per combination; .final = rbindlist stacks them into one table
      cbind(i$value, data.table(mod = list(mod), trainAccu = trainAccu, testAccu = testAccu))
    }
  } %>>% setorder(-testAccu, -trainAccu)
print(resDT)
#    minsplit maxdepth     mod trainAccu  testAccu
# 1:        2        8 <rpart>      1.00 0.8571429
# 2:        2        3 <rpart>      0.92 0.8571429
# 3:        5        3 <rpart>      0.88 0.8571429
# 4:        5        8 <rpart>      0.88 0.8571429
# 5:        2        1 <rpart>      0.84 0.7142857
# 6:        5        1 <rpart>      0.84 0.7142857
# 7:       10        1 <rpart>      0.84 0.7142857
# 8:       10        3 <rpart>      0.84 0.7142857
# 9:       10        8 <rpart>      0.84 0.7142857
```

The `isplit` step makes only a shallow copy, so it should not cause any problems.
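To see what the iterator hands to each iteration of the foreach loop, the sketch below splits a smaller 2 x 2 grid (a hypothetical example, not from the post) the same way, collects `i$value` from every step, and passes each chunk to `rpart.control` through `do.call`. It assumes, as the post's own code does, that `isplit` splits a data.table by rows when given a list of grouping columns; the iteration order is not relied upon.

```r
library(data.table)
library(foreach)
library(iterators)
library(rpart)

# Hypothetical smaller grid, built the same way as in the post
gsDemo <- CJ(minsplit = c(2, 5), maxdepth = c(1, 3))

# Splitting by every column makes each combination its own group,
# so every chunk the iterator yields should contain exactly one row
chunks <- foreach(i = isplit(gsDemo, as.list(gsDemo))) %do% i$value
print(vapply(chunks, nrow, integer(1)))

# A one-row data.table is a named list, so do.call() turns its columns into
# named arguments, i.e. rpart.control(minsplit = ..., maxdepth = ...)
ctrls <- lapply(chunks, function(v) do.call(rpart.control, v))
print(t(vapply(ctrls, function(ct) c(minsplit = ct$minsplit, maxdepth = ct$maxdepth), numeric(2))))
```

This is why the loop body never has to name the tuning parameters: adding a column to the `CJ` grid whose name matches an `rpart.control` argument is enough for it to be passed through.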

Written this way, the code is more intuitive than either the tidyverse version or plain data.table + mapply, and considerably more concise.