데이터마이닝 연습으로 Random Forest을 사용한 예제이다.
먼저 전처리에서 사용한 packages이다.
library(tidyverse)
library(skimr)
그 다음 모형을 만들고자 하는 데이터를 불러왔다. 해당 데이터는 카글에서 가지고 온 데이터다.
raw_data=read_csv('./data/ausraindata.csv',col_types = cols(Evaporation = col_double(),
Sunshine = col_double()) )
해당 데이터는 오늘의 날씨를 바탕으로 내일의 강수유무를 예측해 보는 데이터 이다. 이 자료는 수많은 Australian weather stations의 자료가 포함되어 있다. 총 24개의 열로 구성되어 있으며 각 열별 데이터의 의미는 아래와 같다.
Date : The date of observation
Location : weather station의 위치
MinTemp : minimum temperature in degrees Celsius
MaxTemp : The maximum temperature in degrees Celsius
Rainfall : 당일 강수량 (mm)
Evaporation : Class A pan evaporation (mm) in the 24 hours to 9am
Sunshine : The number of hours of bright sunshine in the day.
WindGustDir : The direction of the strongest wind gust in the 24 hours to midnight
WindGustSpeed : The speed (km/h) of the strongest wind gust in the 24 hours to midnight
WindDir9am : Direction of the wind at 9am
WindDir3pm : Direction of the wind at 3pm
WindSpeed9am : Wind speed (km/hr) averaged over 10 minutes prior to 9am
WindSpeed3pm : Wind speed (km/hr) averaged over 10 minutes prior to 3pm
Humidity9am : Humidity (percent) at 9am
Humidity3pm : Humidity (percent) at 3pm
Pressure9am : Atmospheric pressure (hpa) reduced to mean sea level at 9am
Pressure3pm : Atmospheric pressure (hpa) reduced to mean sea level at 3pm
Cloud9am : Fraction of sky obscured by cloud at 9am. This is measured in “oktas”, which are a unit of eigths. It records how many eigths of the sky are obscured by cloud. A 0 measure indicates completely clear sky whilst an 8 indicates that it is completely overcast.
Cloud3pm : Fraction of sky obscured by cloud at 3pm. This is measured in “oktas”, which are a unit of eigths. It records how many eigths of the sky are obscured by cloud. A 0 measure indicates completely clear sky whilst an 8 indicates that it is completely overcast.
Temp9am : Temperature (degrees C) at 9am
Temp3pm : Temperature (degrees C) at 3pm
RainToday : Boolean: 1 if precipitation (mm) in the 24 hours to 9am exceeds 1mm, otherwise 0
RISK_MM : 다음날의 강수량(mm)
RainTomorrow : 우리의 목표변수로 다음날 강수 유무
전처리
분석에 앞서, 실제 데이터 구성을 직접 보면서 전처리를 시행한다.
glimpse(raw_data)
## Rows: 142,193
## Columns: 24
## $ Date <date> 2008-12-01, 2008-12-02, 2008-12-03, 2008-12-04, 2008...
## $ Location <chr> "Albury", "Albury", "Albury", "Albury", "Albury", "Al...
## $ MinTemp <dbl> 13.4, 7.4, 12.9, 9.2, 17.5, 14.6, 14.3, 7.7, 9.7, 13....
## $ MaxTemp <dbl> 22.9, 25.1, 25.7, 28.0, 32.3, 29.7, 25.0, 26.7, 31.9,...
## $ Rainfall <dbl> 0.6, 0.0, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 0.0, 1.4, 0.0...
## $ Evaporation <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N...
## $ Sunshine <dbl> NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, NA, N...
## $ WindGustDir <chr> "W", "WNW", "WSW", "NE", "W", "WNW", "W", "W", "NNW",...
## $ WindGustSpeed <dbl> 44, 44, 46, 24, 41, 56, 50, 35, 80, 28, 30, 31, 61, 4...
## $ WindDir9am <chr> "W", "NNW", "W", "SE", "ENE", "W", "SW", "SSE", "SE",...
## $ WindDir3pm <chr> "WNW", "WSW", "WSW", "E", "NW", "W", "W", "W", "NW", ...
## $ WindSpeed9am <dbl> 20, 4, 19, 11, 7, 19, 20, 6, 7, 15, 17, 15, 28, 24, N...
## $ WindSpeed3pm <dbl> 24, 22, 26, 9, 20, 24, 24, 17, 28, 11, 6, 13, 28, 20,...
## $ Humidity9am <dbl> 71, 44, 38, 45, 82, 55, 49, 48, 42, 58, 48, 89, 76, 6...
## $ Humidity3pm <dbl> 22, 25, 30, 16, 33, 23, 19, 19, 9, 27, 22, 91, 93, 43...
## $ Pressure9am <dbl> 1007.7, 1010.6, 1007.6, 1017.6, 1010.8, 1009.2, 1009....
## $ Pressure3pm <dbl> 1007.1, 1007.8, 1008.7, 1012.8, 1006.0, 1005.4, 1008....
## $ Cloud9am <dbl> 8, NA, NA, NA, 7, NA, 1, NA, NA, NA, NA, 8, 8, NA, 0,...
## $ Cloud3pm <dbl> NA, NA, 2, NA, 8, NA, NA, NA, NA, NA, NA, 8, 8, 7, NA...
## $ Temp9am <dbl> 16.9, 17.2, 21.0, 18.1, 17.8, 20.6, 18.1, 16.3, 18.3,...
## $ Temp3pm <dbl> 21.8, 24.3, 23.2, 26.5, 29.7, 28.9, 24.6, 25.5, 30.2,...
## $ RainToday <chr> "No", "No", "No", "No", "No", "No", "No", "No", "No",...
## $ RISK_MM <dbl> 0.0, 0.0, 0.0, 1.0, 0.2, 0.0, 0.0, 0.0, 1.4, 0.0, 2.2...
## $ RainTomorrow <chr> "No", "No", "No", "No", "No", "No", "No", "No", "Yes"...
먼저 이 데이터는 2008-12-01 부터 2017-06-24 까지의 데이터이다. 해당 날짜의 모든 데이터를 가지고 있는 시계열 자료로 볼 수 있으나, 뉴럴네트워크를 사용하기에 해당 Date 변수는 사용하지 않는다.
또한 우리는 분류를 해야하므로 내일 강수량이 아닌 강수유무를 목적변수로 한다. 따라서 내일 강수량인 RISK_MM도 제외한다.
raw_data %>% filter(is.na(Evaporation)) %>%
count() / nrow(raw_data)
## n
## 1 0.4278903
raw_data %>% filter(is.na(Sunshine)) %>%
count() / nrow(raw_data)
## n
## 1 0.4769292
raw_data %>% filter(is.na(Cloud9am)) %>%
count() / nrow(raw_data)
## n
## 1 0.3773533
raw_data %>% filter(is.na(Cloud3pm)) %>%
count() / nrow(raw_data)
## n
## 1 0.4015247
또한 위에서 볼 수있듯이 4개의 열들은 결측치의 값이 높은 것을 알 수 있다. 해당 행을 `na.omit을 이용해 삭제하였다.
nrow( raw_data %>% distinct(Location) )
## [1] 49
nrow( raw_data %>% distinct(WindGustDir) )
## [1] 17
nrow( raw_data %>% distinct(WindDir9am) )
## [1] 17
nrow( raw_data %>% distinct(WindDir3pm) )
## [1] 17
또한 49개로 그 범주의 개수가 다양한 Location과 16개(NA 1개포함) 방향변수들도 제거하였다.
앞서 언급한 열들을 제거하고, na인 행을 제거하고, 비가 온 경우(‘Yes’)를 1로 변환하였다. factor형으로 만들었다.
temp_data = raw_data %>%
select(
-Date, -RISK_MM,
# -Evaporation, -Sunshine, -Cloud9am, -Cloud3pm,
-Location, -WindGustDir,-WindDir9am, -WindDir3pm
) %>%
na.omit() %>%
mutate(
RainToday = as.numeric(RainToday=='Yes'),
RainTomorrow = as.factor(as.numeric(RainTomorrow=='Yes'))
)
총 58090개의 관측지가 있다. 그 결과를 skimr 패키지를 이용해 간략히 보였다.
skim(temp_data) # skimr 패키지
Name | temp_data |
Number of rows | 58090 |
Number of columns | 18 |
_______________________ | |
Column type frequency: | |
factor | 1 |
numeric | 17 |
________________________ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
RainTomorrow | 0 | 1 | FALSE | 2 | 0: 45361, 1: 12729 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
MinTemp | 0 | 1 | 13.34 | 6.47 | -6.7 | 8.4 | 13.1 | 18.3 | 31.4 | ▁▅▇▆▁ |
MaxTemp | 0 | 1 | 24.13 | 6.97 | 4.1 | 18.6 | 23.8 | 29.6 | 48.1 | ▁▇▇▅▁ |
Rainfall | 0 | 1 | 2.12 | 6.99 | 0.0 | 0.0 | 0.0 | 0.6 | 206.2 | ▇▁▁▁▁ |
Evaporation | 0 | 1 | 5.45 | 3.69 | 0.0 | 2.8 | 4.8 | 7.4 | 81.2 | ▇▁▁▁▁ |
Sunshine | 0 | 1 | 7.70 | 3.77 | 0.0 | 5.0 | 8.6 | 10.7 | 14.5 | ▃▃▅▇▃ |
WindGustSpeed | 0 | 1 | 40.56 | 13.38 | 9.0 | 31.0 | 39.0 | 48.0 | 124.0 | ▃▇▂▁▁ |
WindSpeed9am | 0 | 1 | 15.24 | 8.58 | 0.0 | 9.0 | 15.0 | 20.0 | 67.0 | ▇▆▂▁▁ |
WindSpeed3pm | 0 | 1 | 19.58 | 8.56 | 0.0 | 13.0 | 19.0 | 24.0 | 76.0 | ▆▇▂▁▁ |
Humidity9am | 0 | 1 | 66.22 | 18.63 | 0.0 | 55.0 | 67.0 | 80.0 | 100.0 | ▁▂▅▇▅ |
Humidity3pm | 0 | 1 | 49.70 | 20.22 | 0.0 | 36.0 | 51.0 | 63.0 | 100.0 | ▂▅▇▅▂ |
Pressure9am | 0 | 1 | 1017.33 | 6.94 | 980.5 | 1012.7 | 1017.3 | 1022.0 | 1040.4 | ▁▁▇▇▁ |
Pressure3pm | 0 | 1 | 1014.88 | 6.90 | 977.1 | 1010.1 | 1014.8 | 1019.5 | 1038.9 | ▁▁▇▇▁ |
Cloud9am | 0 | 1 | 4.25 | 2.80 | 0.0 | 1.0 | 5.0 | 7.0 | 8.0 | ▆▃▁▃▇ |
Cloud3pm | 0 | 1 | 4.33 | 2.65 | 0.0 | 2.0 | 5.0 | 7.0 | 9.0 | ▆▃▃▇▂ |
Temp9am | 0 | 1 | 18.09 | 6.60 | -0.9 | 12.9 | 17.7 | 23.2 | 39.4 | ▁▇▇▅▁ |
Temp3pm | 0 | 1 | 22.63 | 6.84 | 3.7 | 17.3 | 22.3 | 27.8 | 46.1 | ▁▇▇▃▁ |
RainToday | 0 | 1 | 0.22 | 0.41 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | ▇▁▁▁▂ |
이상 전처리를 완료하였다.
1. Training 데이터와 Test 데이터를 50:50의 비율로 분할하시오.
set.seed(1234)
sp_n = sample(1:nrow(temp_data), round(nrow(temp_data)/2))
train = temp_data[sp_n,]
test = temp_data[-sp_n,]
nrow(train);nrow(test)
## [1] 29045
## [1] 29045
2. R 프로그램의 ‘randomForest’ 명령어를 사용하여 랜덤포레스트 분석을 수행하고자 한다. 단, hyper-parameter는 아래와 같이 조정한다.
A. ntree=100 을 사용하고,
B. mtry = (0.1p, 0.2p, … , 1.0*p)을 내림하여 사용한다.
C. nodesize = (1, 0.01n, 0.02n, … , 0.10*n) 을 반올림하여 사용한다.
D. 그외 parameter 값들은 default 값을 사용한다.
먼저 17개의 변수로 RainTomorrow를 예측하므로 p는 17이 된다. 따라서 B의 조건에 따른 mtry를 구할 수 있다. 또한 train 데이터의 수는 29045개 이므로 n은 29045가 된다. 따라서 C조건에 따른 nodesize를 구할 수 있다. 각각은 다음과 같다.
( para_mtry = floor(17 * 1:10 / 10) )
## [1] 1 3 5 6 8 10 11 13 15 17
( para_ndsz = c(1, round( nrow(train) * 1:10 / 100 ) ) )
## [1] 1 290 581 871 1162 1452 1743 2033 2324 2614 2904
이제 코드를 돌리면 아래와 같다.
library(randomForest)
## randomForest 4.6-14
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
for(mt in para_mtry){
for(nd in para_ndsz){
print(paste('rf',mt,nd,sep='_'))
set.seed(2010)
assign(
paste('rf',mt,nd,sep='_'),
randomForest(RainTomorrow ~ ., data = train, ntree=100, mtry=mt, nodesize=nd)
)
}
}
## [1] "rf_1_1"
## [1] "rf_1_290"
## [1] "rf_1_581"
## [1] "rf_1_871"
## [1] "rf_1_1162"
## [1] "rf_1_1452"
## [1] "rf_1_1743"
## [1] "rf_1_2033"
## [1] "rf_1_2324"
## [1] "rf_1_2614"
## [1] "rf_1_2904"
## [1] "rf_3_1"
## [1] "rf_3_290"
## [1] "rf_3_581"
## [1] "rf_3_871"
## [1] "rf_3_1162"
## [1] "rf_3_1452"
## [1] "rf_3_1743"
## [1] "rf_3_2033"
## [1] "rf_3_2324"
## [1] "rf_3_2614"
## [1] "rf_3_2904"
## [1] "rf_5_1"
## [1] "rf_5_290"
## [1] "rf_5_581"
## [1] "rf_5_871"
## [1] "rf_5_1162"
## [1] "rf_5_1452"
## [1] "rf_5_1743"
## [1] "rf_5_2033"
## [1] "rf_5_2324"
## [1] "rf_5_2614"
## [1] "rf_5_2904"
## [1] "rf_6_1"
## [1] "rf_6_290"
## [1] "rf_6_581"
## [1] "rf_6_871"
## [1] "rf_6_1162"
## [1] "rf_6_1452"
## [1] "rf_6_1743"
## [1] "rf_6_2033"
## [1] "rf_6_2324"
## [1] "rf_6_2614"
## [1] "rf_6_2904"
## [1] "rf_8_1"
## [1] "rf_8_290"
## [1] "rf_8_581"
## [1] "rf_8_871"
## [1] "rf_8_1162"
## [1] "rf_8_1452"
## [1] "rf_8_1743"
## [1] "rf_8_2033"
## [1] "rf_8_2324"
## [1] "rf_8_2614"
## [1] "rf_8_2904"
## [1] "rf_10_1"
## [1] "rf_10_290"
## [1] "rf_10_581"
## [1] "rf_10_871"
## [1] "rf_10_1162"
## [1] "rf_10_1452"
## [1] "rf_10_1743"
## [1] "rf_10_2033"
## [1] "rf_10_2324"
## [1] "rf_10_2614"
## [1] "rf_10_2904"
## [1] "rf_11_1"
## [1] "rf_11_290"
## [1] "rf_11_581"
## [1] "rf_11_871"
## [1] "rf_11_1162"
## [1] "rf_11_1452"
## [1] "rf_11_1743"
## [1] "rf_11_2033"
## [1] "rf_11_2324"
## [1] "rf_11_2614"
## [1] "rf_11_2904"
## [1] "rf_13_1"
## [1] "rf_13_290"
## [1] "rf_13_581"
## [1] "rf_13_871"
## [1] "rf_13_1162"
## [1] "rf_13_1452"
## [1] "rf_13_1743"
## [1] "rf_13_2033"
## [1] "rf_13_2324"
## [1] "rf_13_2614"
## [1] "rf_13_2904"
## [1] "rf_15_1"
## [1] "rf_15_290"
## [1] "rf_15_581"
## [1] "rf_15_871"
## [1] "rf_15_1162"
## [1] "rf_15_1452"
## [1] "rf_15_1743"
## [1] "rf_15_2033"
## [1] "rf_15_2324"
## [1] "rf_15_2614"
## [1] "rf_15_2904"
## [1] "rf_17_1"
## [1] "rf_17_290"
## [1] "rf_17_581"
## [1] "rf_17_871"
## [1] "rf_17_1162"
## [1] "rf_17_1452"
## [1] "rf_17_1743"
## [1] "rf_17_2033"
## [1] "rf_17_2324"
## [1] "rf_17_2614"
## [1] "rf_17_2904"
각각의 parameter별 모형은 **rf_(mtry값)_(nodesize값)** 라는 객체에 저장되었다. 즉 mtry가 1이고 nodesize가 290인 모형은 rf_1_290라는 객체 저장되어 있다.
3. 위 2번의 조건에 맞는 랜덤포레스트를 training 데이터를 이용하여 생성하고, test 데이터를 이용하여 예측 정확도를 계산하고자 한다. 이때 예측정확도는 AUROC 값을 사용한다.
앞서 random forest를 생성하였다. auroc를 test 데이터를 이용해 계산하면 아래와 같다. (message=F)
for( md in ls(pattern = "^rf_") ){
assign(
paste('pred',md ,sep='_'),
predict( get(md), newdata=test , type="prob")
)
assign(
paste('roc',md ,sep='_'),
roc(test$RainTomorrow ~ get(paste('pred',md ,sep='_'))[,2])
)
}
test 데이터를 이용하여 각각의 예측 정확도auroc값을 계산한 결과는 아래와 같다.
auc_matrix = matrix(nrow=10,ncol=11)
for( md in ls(pattern = "^roc_") ){
temp_mt = strsplit(md,split='_')[[1]][3]
temp_ns = strsplit(md,split='_')[[1]][4]
cat('mtry값:',temp_mt,'\t',
'nodesize값:',temp_ns,'\n',
sep ="")
temp_auc = get(md)$auc
print(temp_auc)
cat('\n')
auc_matrix[which(para_mtry ==temp_mt),which(para_ndsz ==temp_ns)] = temp_auc
}
## mtry값:1 nodesize값:1
## Area under the curve: 0.8798
##
## mtry값:1 nodesize값:1162
## Area under the curve: 0.853
##
## mtry값:1 nodesize값:1452
## Area under the curve: 0.8509
##
## mtry값:1 nodesize값:1743
## Area under the curve: 0.8466
##
## mtry값:1 nodesize값:2033
## Area under the curve: 0.849
##
## mtry값:1 nodesize값:2324
## Area under the curve: 0.8509
##
## mtry값:1 nodesize값:2614
## Area under the curve: 0.8466
##
## mtry값:1 nodesize값:290
## Area under the curve: 0.8572
##
## mtry값:1 nodesize값:2904
## Area under the curve: 0.8424
##
## mtry값:1 nodesize값:581
## Area under the curve: 0.8548
##
## mtry값:1 nodesize값:871
## Area under the curve: 0.8542
##
## mtry값:10 nodesize값:1
## Area under the curve: 0.8886
##
## mtry값:10 nodesize값:1162
## Area under the curve: 0.8393
##
## mtry값:10 nodesize값:1452
## Area under the curve: 0.8436
##
## mtry값:10 nodesize값:1743
## Area under the curve: 0.8351
##
## mtry값:10 nodesize값:2033
## Area under the curve: 0.8312
##
## mtry값:10 nodesize값:2324
## Area under the curve: 0.8316
##
## mtry값:10 nodesize값:2614
## Area under the curve: 0.8313
##
## mtry값:10 nodesize값:290
## Area under the curve: 0.859
##
## mtry값:10 nodesize값:2904
## Area under the curve: 0.8306
##
## mtry값:10 nodesize값:581
## Area under the curve: 0.8465
##
## mtry값:10 nodesize값:871
## Area under the curve: 0.8466
##
## mtry값:11 nodesize값:1
## Area under the curve: 0.888
##
## mtry값:11 nodesize값:1162
## Area under the curve: 0.8392
##
## mtry값:11 nodesize값:1452
## Area under the curve: 0.8365
##
## mtry값:11 nodesize값:1743
## Area under the curve: 0.8294
##
## mtry값:11 nodesize값:2033
## Area under the curve: 0.8265
##
## mtry값:11 nodesize값:2324
## Area under the curve: 0.8334
##
## mtry값:11 nodesize값:2614
## Area under the curve: 0.8274
##
## mtry값:11 nodesize값:290
## Area under the curve: 0.8599
##
## mtry값:11 nodesize값:2904
## Area under the curve: 0.8289
##
## mtry값:11 nodesize값:581
## Area under the curve: 0.8479
##
## mtry값:11 nodesize값:871
## Area under the curve: 0.8373
##
## mtry값:13 nodesize값:1
## Area under the curve: 0.887
##
## mtry값:13 nodesize값:1162
## Area under the curve: 0.8391
##
## mtry값:13 nodesize값:1452
## Area under the curve: 0.831
##
## mtry값:13 nodesize값:1743
## Area under the curve: 0.8314
##
## mtry값:13 nodesize값:2033
## Area under the curve: 0.8231
##
## mtry값:13 nodesize값:2324
## Area under the curve: 0.823
##
## mtry값:13 nodesize값:2614
## Area under the curve: 0.8182
##
## mtry값:13 nodesize값:290
## Area under the curve: 0.8557
##
## mtry값:13 nodesize값:2904
## Area under the curve: 0.8207
##
## mtry값:13 nodesize값:581
## Area under the curve: 0.8423
##
## mtry값:13 nodesize값:871
## Area under the curve: 0.8393
##
## mtry값:15 nodesize값:1
## Area under the curve: 0.8856
##
## mtry값:15 nodesize값:1162
## Area under the curve: 0.8333
##
## mtry값:15 nodesize값:1452
## Area under the curve: 0.8302
##
## mtry값:15 nodesize값:1743
## Area under the curve: 0.8219
##
## mtry값:15 nodesize값:2033
## Area under the curve: 0.8201
##
## mtry값:15 nodesize값:2324
## Area under the curve: 0.8177
##
## mtry값:15 nodesize값:2614
## Area under the curve: 0.8135
##
## mtry값:15 nodesize값:290
## Area under the curve: 0.8576
##
## mtry값:15 nodesize값:2904
## Area under the curve: 0.8134
##
## mtry값:15 nodesize값:581
## Area under the curve: 0.8329
##
## mtry값:15 nodesize값:871
## Area under the curve: 0.8343
##
## mtry값:17 nodesize값:1
## Area under the curve: 0.8852
##
## mtry값:17 nodesize값:1162
## Area under the curve: 0.8248
##
## mtry값:17 nodesize값:1452
## Area under the curve: 0.8245
##
## mtry값:17 nodesize값:1743
## Area under the curve: 0.8235
##
## mtry값:17 nodesize값:2033
## Area under the curve: 0.815
##
## mtry값:17 nodesize값:2324
## Area under the curve: 0.8085
##
## mtry값:17 nodesize값:2614
## Area under the curve: 0.8077
##
## mtry값:17 nodesize값:290
## Area under the curve: 0.8545
##
## mtry값:17 nodesize값:2904
## Area under the curve: 0.7965
##
## mtry값:17 nodesize값:581
## Area under the curve: 0.8312
##
## mtry값:17 nodesize값:871
## Area under the curve: 0.827
##
## mtry값:3 nodesize값:1
## Area under the curve: 0.89
##
## mtry값:3 nodesize값:1162
## Area under the curve: 0.8507
##
## mtry값:3 nodesize값:1452
## Area under the curve: 0.8467
##
## mtry값:3 nodesize값:1743
## Area under the curve: 0.8435
##
## mtry값:3 nodesize값:2033
## Area under the curve: 0.8476
##
## mtry값:3 nodesize값:2324
## Area under the curve: 0.8467
##
## mtry값:3 nodesize값:2614
## Area under the curve: 0.8396
##
## mtry값:3 nodesize값:290
## Area under the curve: 0.8589
##
## mtry값:3 nodesize값:2904
## Area under the curve: 0.8432
##
## mtry값:3 nodesize값:581
## Area under the curve: 0.8547
##
## mtry값:3 nodesize값:871
## Area under the curve: 0.8531
##
## mtry값:5 nodesize값:1
## Area under the curve: 0.8902
##
## mtry값:5 nodesize값:1162
## Area under the curve: 0.842
##
## mtry값:5 nodesize값:1452
## Area under the curve: 0.8445
##
## mtry값:5 nodesize값:1743
## Area under the curve: 0.8416
##
## mtry값:5 nodesize값:2033
## Area under the curve: 0.8336
##
## mtry값:5 nodesize값:2324
## Area under the curve: 0.8358
##
## mtry값:5 nodesize값:2614
## Area under the curve: 0.8394
##
## mtry값:5 nodesize값:290
## Area under the curve: 0.8624
##
## mtry값:5 nodesize값:2904
## Area under the curve: 0.837
##
## mtry값:5 nodesize값:581
## Area under the curve: 0.8525
##
## mtry값:5 nodesize값:871
## Area under the curve: 0.8487
##
## mtry값:6 nodesize값:1
## Area under the curve: 0.8896
##
## mtry값:6 nodesize값:1162
## Area under the curve: 0.8495
##
## mtry값:6 nodesize값:1452
## Area under the curve: 0.8422
##
## mtry값:6 nodesize값:1743
## Area under the curve: 0.8427
##
## mtry값:6 nodesize값:2033
## Area under the curve: 0.8414
##
## mtry값:6 nodesize값:2324
## Area under the curve: 0.8394
##
## mtry값:6 nodesize값:2614
## Area under the curve: 0.8389
##
## mtry값:6 nodesize값:290
## Area under the curve: 0.8598
##
## mtry값:6 nodesize값:2904
## Area under the curve: 0.8303
##
## mtry값:6 nodesize값:581
## Area under the curve: 0.8477
##
## mtry값:6 nodesize값:871
## Area under the curve: 0.8476
##
## mtry값:8 nodesize값:1
## Area under the curve: 0.8891
##
## mtry값:8 nodesize값:1162
## Area under the curve: 0.844
##
## mtry값:8 nodesize값:1452
## Area under the curve: 0.8441
##
## mtry값:8 nodesize값:1743
## Area under the curve: 0.8379
##
## mtry값:8 nodesize값:2033
## Area under the curve: 0.8389
##
## mtry값:8 nodesize값:2324
## Area under the curve: 0.8368
##
## mtry값:8 nodesize값:2614
## Area under the curve: 0.8314
##
## mtry값:8 nodesize값:290
## Area under the curve: 0.8596
##
## mtry값:8 nodesize값:2904
## Area under the curve: 0.8304
##
## mtry값:8 nodesize값:581
## Area under the curve: 0.8494
##
## mtry값:8 nodesize값:871
## Area under the curve: 0.8464
행렬로도 보여주면 아래와 같다.
rownames(auc_matrix) = para_mtry
colnames(auc_matrix) = para_ndsz
auc_matrix
## 1 290 581 871 1162 1452 1743
## 1 0.8797539 0.8572112 0.8548120 0.8541534 0.8530406 0.8508504 0.8465923
## 3 0.8899524 0.8589102 0.8546948 0.8531171 0.8507239 0.8466689 0.8434589
## 5 0.8901695 0.8624353 0.8525360 0.8486703 0.8419568 0.8445128 0.8416349
## 6 0.8895977 0.8597663 0.8477312 0.8476448 0.8495143 0.8421653 0.8426893
## 8 0.8890539 0.8596478 0.8494105 0.8464140 0.8439507 0.8441229 0.8379413
## 10 0.8885989 0.8589943 0.8464877 0.8465859 0.8392590 0.8435724 0.8351362
## 11 0.8880276 0.8599148 0.8479312 0.8372760 0.8392385 0.8365081 0.8294359
## 13 0.8869692 0.8557121 0.8423401 0.8392755 0.8390583 0.8309899 0.8313972
## 15 0.8855846 0.8575909 0.8329103 0.8342663 0.8333442 0.8302059 0.8218959
## 17 0.8851665 0.8544948 0.8311615 0.8269814 0.8248338 0.8244663 0.8234599
## 2033 2324 2614 2904
## 1 0.8489772 0.8509397 0.8465913 0.8424482
## 3 0.8475778 0.8466515 0.8395814 0.8431646
## 5 0.8336063 0.8358320 0.8394180 0.8369673
## 6 0.8413764 0.8393648 0.8389086 0.8302772
## 8 0.8388660 0.8367690 0.8313858 0.8303713
## 10 0.8311746 0.8315563 0.8313170 0.8306044
## 11 0.8264946 0.8333505 0.8274361 0.8288907
## 13 0.8231113 0.8229738 0.8182062 0.8207216
## 15 0.8200801 0.8176567 0.8134933 0.8133673
## 17 0.8149581 0.8084856 0.8076712 0.7964592
4. 3번의 결과, 총 110개의 AUROC 값을 구할 수 있다. 이를 mtry 값과 nodesize 값의 조합에 따라 AUROC 값으로 3차원 포물선 그래프를 생성하시오. (3D surface plot)
해당 auroc를 이용하여 3d surface plot을 그리면 아래와 같다.
library(plotly)
##
## Attaching package: 'plotly'
## The following object is masked from 'package:ggplot2':
##
## last_plot
## The following object is masked from 'package:stats':
##
## filter
## The following object is masked from 'package:graphics':
##
## layout
plot_ly(z = auc_matrix,x=para_mtry,y=para_ndsz) %>%
add_surface() %>%
layout(scene = list(
xaxis = list(title = 'mtry',tickvals=para_mtry),
yaxis = list(title = 'nodesize',tickvals=para_ndsz),
zaxis = list(title = 'auc'))
)
5. 4번의 결과에서 예측정확도가 가장 높은 최적의 hyper-parameter 조합은 무엇인지 밝히시오.
auc_matrix == max(auc_matrix)
## 1 290 581 871 1162 1452 1743 2033 2324 2614 2904
## 1 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 3 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 5 TRUE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 6 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 8 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 10 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 11 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 13 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 15 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
## 17 FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE FALSE
즉 mtry가 3이고, node size 가 1일 떄 가장 예측정확도가 높은 조합이다.