이제 가장 앞에서 했던 예제를 이용해서 nnet과 neuralnet의 성능을 다시 비교해보겠습니다. 결과가 몇 가지 카테고리로 나눠지는 분류형 작업에서는 nnet이 나은 성능을 보여줬습니다. 가격처럼 연속 변수를 추정하는데는 어느 쪽이 더 나은지 아직 비교를 하지 않았습니다. diamond 데이터에서 1000개를 뽑아 900개를 훈련 데이터로, 100개를 테스트 데이터로 삼고 비교를 해보겠습니다. 이번에는 scale을 사용해서 가격과 캐럿을 정규화시켜 같은 데이터를 두 인공지능 신경망 패키지에 넣습니다.
library("ggplot2")
library("neuralnet")
library(Metrics)
library(nnet)
diamonds$cut2[diamonds$cut=="Fair"]=0
diamonds$cut2[diamonds$cut=="Good"]=1
diamonds$cut2[diamonds$cut=="Very Good"]=2
diamonds$cut2[diamonds$cut=="Premium"]=3
diamonds$cut2[diamonds$cut=="Ideal"]=4
diamonds$color2[diamonds$color=="D"]=0
diamonds$color2[diamonds$color=="E"]=1
diamonds$color2[diamonds$color=="F"]=2
diamonds$color2[diamonds$color=="G"]=3
diamonds$color2[diamonds$color=="H"]=4
diamonds$color2[diamonds$color=="I"]=5
diamonds$color2[diamonds$color=="J"]=6
diamonds1<-diamonds c="" carat="" color2="" cut2="" price="" span="">-diamonds>
diamonds1$price2=scale(diamonds1$price)
diamonds1$carat2=scale(diamonds1$carat)
set.seed(1234)
diamonds2<-sample 1000="" diamonds1="" nrow="" replace="FALSE)</span">-sample>
diamonds2<-diamonds1 diamonds2="" span="">-diamonds1>
set.seed(1234)
n = nrow(diamonds2)
train <- 900="" n="" sample="" span="">->
test <- diamonds2="" span="" train="">->
train <- diamonds2="" span="" train="">->
neuralnet은 3:3:3구성으로 신경망을 구성합니다.
f<-price2 carat2="" color2="" cut2="" span="">-price2>
fit<-neuralnet f="" span="">-neuralnet>
data=train,
hidden=c(3,3,3),
algorithm = "rprop+",
err.fct = "sse",
act.fct = "logistic",
threshold = 0.1,
stepmax=1e7,
linear.output = TRUE)
pred<-compute c="" carat2="" color2="" cut2="" fit="" span="" test="">-compute>
result1<-cbind net="" pred="" span="" test="">-cbind>
result1$price3=pred$net*sd(train$price)+mean(train$price)
result1[,c("price","price3")]
result1$price4=result1$price-result1$price3
mean(result1$price4);sd(result1$price4)
결과는 아래와 같습니다.
> result1[,c("price","price3")]
price price3
1 5676 7158.4385300
2 878 719.8074114
3 829 831.3749722
4 15428 16547.6083877
5 421 694.0573474
6 1698 1324.2573993
7 2861 3542.3685795
8 658 895.0632051
9 8974 8354.0780604
10 9189 9101.1099778
11 492 918.1250754
12 6048 7087.4086483
13 15485 16419.7049268
14 16532 16059.2629705
15 3612 2505.1566194
16 8640 6620.4990300
17 3494 1688.2556233
18 1013 469.0950248
19 1659 1774.9824311
20 612 544.2927174
21 6053 6893.3466775
22 2205 1158.6017509
23 16733 16672.5136981
24 1380 1875.1675574
25 1125 895.0632051
26 15783 16566.4267716
27 5747 5396.3075564
28 723 943.2778839
29 657 800.9945224
30 12791 12190.2612976
31 7378 5750.2767065
32 624 515.8404538
33 2990 1541.3452801
34 10076 11068.4254156
35 1063 989.5016278
36 15312 14097.2092245
37 7119 5701.8445957
38 10316 7421.3986817
39 7340 8506.6367291
40 942 910.5026958
41 828 752.6034797
42 3625 4378.5422048
43 790 905.1833696
44 923 743.4189138
45 787 762.4159732
46 1608 1777.7261808
47 8075 8632.3751287
48 1289 1606.2536315
49 2671 2755.4297627
50 2578 2475.9034934
51 3337 3962.7829507
52 905 798.0972265
53 3348 4007.1614974
54 561 544.2927174
55 545 719.8074114
56 774 1052.5556643
57 4066 3878.7743141
58 10913 16508.5906138
59 4939 5732.4468281
60 9240 13285.0456263
61 5622 6214.4863908
62 12150 8228.7889543
63 5607 5038.2854738
64 2265 2475.9034934
65 1624 1405.7177625
66 6271 5841.0145488
67 680 910.5026958
68 15079 12292.4224548
69 13196 13324.3064177
70 2441 2799.6964773
71 2401 1800.0705063
72 16100 15529.9506136
73 2525 2558.6185463
74 2096 2599.6071408
75 1082 1112.0117404
76 802 806.3563402
77 817 644.6212505
78 4482 2762.6028638
79 526 355.0189879
80 421 752.6034797
81 8403 7137.2076733
82 3892 4617.0757212
83 2010 1606.2536315
84 685 868.5866515
85 2012 1311.4712970
86 3644 4988.1585407
87 17153 15026.2835178
88 850 998.2275551
89 3972 3280.0532132
90 12339 13465.4731286
91 759 807.9683077
92 696 547.3506904
93 8774 7059.1140595
94 3183 2951.0382289
95 2591 2477.5763848
96 945 584.3414026
97 5292 5732.4468281
98 1298 2565.5735340
99 889 735.2652309
100 3796 4450.8771871
>
> result1$price4=result1$price-result1$price3
> mean(result1$price4);sd(result1$price4)
[1] 46.3577482
[1] 1158.710832
그럭저럭 추정을 하기는 하는데, 가격이 평균 46달러 정도 차이가 있습니다. nnet에서는 maxit를 기본값으로 하면 결과를 찾지 못하기 때문에 500을 주고 노드는 5개로 했습니다.
nn <- carat2="" color2="" cut2="" data="train," decay="5e-4," maxit="500) </span" nnet="" price2="" size="5,">->
nn
price3<-predict nn="" span="" test="">-predict>
result2=cbind(test,price3)
result2
result2$price4=result2$price3*sd(train$price)+mean(train$price)
result2[,c("price","price4")]
result2$price5=result2$price-result2$price4
mean(result2$price5);sd(result2$price5)
결과를 보면 상당히 차이가 큽니다.
> result2$price4=result2$price3*sd(train$price)+mean(train$price)
> result2[,c("price","price4")]
price price4
1 5676 7672.007437
2 878 3894.393333
3 829 3894.393333
4 15428 7931.338301
5 421 3894.393333
6 1698 3894.393333
7 2861 3894.536388
8 658 3894.393333
9 8974 7916.883643
10 9189 7926.028269
11 492 3894.397454
12 6048 6569.112058
13 15485 7931.338288
14 16532 7931.338296
15 3612 3894.393333
16 8640 7898.185086
17 3494 3894.393333
18 1013 3894.393333
19 1659 3894.393333
20 612 3894.393333
21 6053 7930.365201
22 2205 3894.393333
23 16733 7931.338227
24 1380 3894.473040
25 1125 3894.393333
26 15783 7931.338301
27 5747 5179.993313
28 723 3894.393333
29 657 3894.396359
30 12791 7931.329113
31 7378 7007.191522
32 624 3894.394989
33 2990 3894.400044
34 10076 7931.333821
35 1063 3894.393333
36 15312 7931.338208
37 7119 5800.351948
38 10316 7929.702124
39 7340 7931.070966
40 942 3894.393333
41 828 3894.393333
42 3625 3897.432456
43 790 3894.445754
44 923 3894.393333
45 787 3894.393333
46 1608 3894.393333
47 8075 7927.931554
48 1289 3894.393333
49 2671 3894.393333
50 2578 3894.393333
51 3337 3950.021059
52 905 3894.393333
53 3348 4029.696071
54 561 3894.393333
55 545 3894.393333
56 774 3894.393333
57 4066 3894.681967
58 10913 7931.338299
59 4939 5429.746341
60 9240 7931.335309
61 5622 6141.245167
62 12150 7179.328294
63 5607 5098.811431
64 2265 3894.393333
65 1624 3894.393333
66 6271 6255.388441
67 680 3894.393333
68 15079 7931.334046
69 13196 7926.575540
70 2441 3894.393333
71 2401 3894.393333
72 16100 7931.338239
73 2525 3894.473636
74 2096 3894.393333
75 1082 3894.393333
76 802 3894.393333
77 817 3894.393333
78 4482 3894.395153
79 526 3894.393333
80 421 3894.393333
81 8403 7930.851183
82 3892 4846.138573
83 2010 3894.393333
84 685 3894.397274
85 2012 3894.393333
86 3644 3897.736542
87 17153 7931.337662
88 850 3894.399397
89 3972 3894.473520
90 12339 7931.337700
91 759 3894.393333
92 696 3894.393333
93 8774 7916.076486
94 3183 3894.399407
95 2591 3894.395289
96 945 3894.393333
97 5292 5429.746341
98 1298 3894.393333
99 889 3894.393333
100 3796 4021.678458
>
> result2$price5=result2$price-result2$price4
> mean(result2$price5);sd(result2$price5)
[1] -407.1997565
[1] 3325.384675
이것만 보면 하나의 신경망 층만 만드는 nnet 보다 여러 개의 층을 만들 수 있는 neuralnet이 더 좋은 선택인 것 같습니다. 물론 이는 학습하고자하는 자료나 알고리즘에 따라 크게 차이가 있을 수 있어 어느 쪽이 더 좋다고 단순하게 말할 순 없습니다. 하지만 하나의 패키지로 만족할만한 결과를 얻지 못했을 때 몇 가지 패키지의 사용법을 알고 있다면 가장 좋은 성능을 내는 패키지를 사용할 수 있다는 장점이 있습니다. 앞으로 몇 가지 패키지를 더 알아보겠습니다.
댓글
댓글 쓰기