기본 콘텐츠로 건너뛰기

R을 이용해 간단한 신경망 만들기 (10)





 이제 가장 앞에서 했던 예제를 이용해서 nnet과 neuralnet의 성능을 다시 비교해보겠습니다. 결과가 몇 가지 카테고리로 나눠지는 분류형 작업에서는 nnet이 나은 성능을 보여줬습니다. 가격처럼 연속 변수를 추정하는데는 어느 쪽이 더 나은지 아직 비교를 하지 않았습니다. diamond 데이터에서 1000개를 뽑아 900개를 훈련 데이터로, 100개를 테스트 데이터로 삼고 비교를 해보겠습니다. 이번에는 scale을 사용해서 가격과 캐럿을 정규화시켜 같은 데이터를 두 인공지능 신경망 패키지에 넣습니다.


library("ggplot2")
library("neuralnet")
library(Metrics)
library(nnet)

diamonds$cut2[diamonds$cut=="Fair"]=0
diamonds$cut2[diamonds$cut=="Good"]=1
diamonds$cut2[diamonds$cut=="Very Good"]=2
diamonds$cut2[diamonds$cut=="Premium"]=3
diamonds$cut2[diamonds$cut=="Ideal"]=4

diamonds$color2[diamonds$color=="D"]=0
diamonds$color2[diamonds$color=="E"]=1
diamonds$color2[diamonds$color=="F"]=2
diamonds$color2[diamonds$color=="G"]=3
diamonds$color2[diamonds$color=="H"]=4
diamonds$color2[diamonds$color=="I"]=5
diamonds$color2[diamonds$color=="J"]=6

diamonds1<-diamonds c="" carat="" color2="" cut2="" price="" span="">
diamonds1$price2=scale(diamonds1$price)
diamonds1$carat2=scale(diamonds1$carat)

set.seed(1234)
diamonds2<-sample 1000="" diamonds1="" nrow="" replace="FALSE)</span">
diamonds2<-diamonds1 diamonds2="" span="">

set.seed(1234)

n = nrow(diamonds2)
train <- 900="" n="" sample="" span="">
test <- diamonds2="" span="" train="">
train <- diamonds2="" span="" train="">


 neuralnet은 3:3:3구성으로 신경망을 구성합니다. 


f<-price2 carat2="" color2="" cut2="" span="">

fit<-neuralnet f="" span="">
               data=train, 
               hidden=c(3,3,3),
               algorithm = "rprop+",
               err.fct = "sse",
               act.fct = "logistic",
               threshold = 0.1,
               stepmax=1e7,
               linear.output = TRUE)

pred<-compute c="" carat2="" color2="" cut2="" fit="" span="" test="">
result1<-cbind net="" pred="" span="" test="">
result1$price3=pred$net*sd(train$price)+mean(train$price)
result1[,c("price","price3")]

result1$price4=result1$price-result1$price3
mean(result1$price4);sd(result1$price4)


결과는 아래와 같습니다. 



> result1[,c("price","price3")]
    price        price3
1    5676  7158.4385300
2     878   719.8074114
3     829   831.3749722
4   15428 16547.6083877
5     421   694.0573474
6    1698  1324.2573993
7    2861  3542.3685795
8     658   895.0632051
9    8974  8354.0780604
10   9189  9101.1099778
11    492   918.1250754
12   6048  7087.4086483
13  15485 16419.7049268
14  16532 16059.2629705
15   3612  2505.1566194
16   8640  6620.4990300
17   3494  1688.2556233
18   1013   469.0950248
19   1659  1774.9824311
20    612   544.2927174
21   6053  6893.3466775
22   2205  1158.6017509
23  16733 16672.5136981
24   1380  1875.1675574
25   1125   895.0632051
26  15783 16566.4267716
27   5747  5396.3075564
28    723   943.2778839
29    657   800.9945224
30  12791 12190.2612976
31   7378  5750.2767065
32    624   515.8404538
33   2990  1541.3452801
34  10076 11068.4254156
35   1063   989.5016278
36  15312 14097.2092245
37   7119  5701.8445957
38  10316  7421.3986817
39   7340  8506.6367291
40    942   910.5026958
41    828   752.6034797
42   3625  4378.5422048
43    790   905.1833696
44    923   743.4189138
45    787   762.4159732
46   1608  1777.7261808
47   8075  8632.3751287
48   1289  1606.2536315
49   2671  2755.4297627
50   2578  2475.9034934
51   3337  3962.7829507
52    905   798.0972265
53   3348  4007.1614974
54    561   544.2927174
55    545   719.8074114
56    774  1052.5556643
57   4066  3878.7743141
58  10913 16508.5906138
59   4939  5732.4468281
60   9240 13285.0456263
61   5622  6214.4863908
62  12150  8228.7889543
63   5607  5038.2854738
64   2265  2475.9034934
65   1624  1405.7177625
66   6271  5841.0145488
67    680   910.5026958
68  15079 12292.4224548
69  13196 13324.3064177
70   2441  2799.6964773
71   2401  1800.0705063
72  16100 15529.9506136
73   2525  2558.6185463
74   2096  2599.6071408
75   1082  1112.0117404
76    802   806.3563402
77    817   644.6212505
78   4482  2762.6028638
79    526   355.0189879
80    421   752.6034797
81   8403  7137.2076733
82   3892  4617.0757212
83   2010  1606.2536315
84    685   868.5866515
85   2012  1311.4712970
86   3644  4988.1585407
87  17153 15026.2835178
88    850   998.2275551
89   3972  3280.0532132
90  12339 13465.4731286
91    759   807.9683077
92    696   547.3506904
93   8774  7059.1140595
94   3183  2951.0382289
95   2591  2477.5763848
96    945   584.3414026
97   5292  5732.4468281
98   1298  2565.5735340
99    889   735.2652309
100  3796  4450.8771871
> result1$price4=result1$price-result1$price3
> mean(result1$price4);sd(result1$price4)
[1] 46.3577482
[1] 1158.710832


 그럭저럭 추정을 하기는 하는데, 가격이 평균 46달러 정도 차이가 있습니다. nnet에서는 maxit를 기본값으로 하면 결과를 찾지 못하기 때문에 500을 주고 노드는 5개로 했습니다. 


nn <- carat2="" color2="" cut2="" data="train," decay="5e-4," maxit="500) </span" nnet="" price2="" size="5,">
nn
price3<-predict nn="" span="" test="">
result2=cbind(test,price3)
result2

result2$price4=result2$price3*sd(train$price)+mean(train$price)
result2[,c("price","price4")]

result2$price5=result2$price-result2$price4
mean(result2$price5);sd(result2$price5)


 결과를 보면 상당히 차이가 큽니다. 


> result2$price4=result2$price3*sd(train$price)+mean(train$price)
> result2[,c("price","price4")]
    price      price4
1    5676 7672.007437
2     878 3894.393333
3     829 3894.393333
4   15428 7931.338301
5     421 3894.393333
6    1698 3894.393333
7    2861 3894.536388
8     658 3894.393333
9    8974 7916.883643
10   9189 7926.028269
11    492 3894.397454
12   6048 6569.112058
13  15485 7931.338288
14  16532 7931.338296
15   3612 3894.393333
16   8640 7898.185086
17   3494 3894.393333
18   1013 3894.393333
19   1659 3894.393333
20    612 3894.393333
21   6053 7930.365201
22   2205 3894.393333
23  16733 7931.338227
24   1380 3894.473040
25   1125 3894.393333
26  15783 7931.338301
27   5747 5179.993313
28    723 3894.393333
29    657 3894.396359
30  12791 7931.329113
31   7378 7007.191522
32    624 3894.394989
33   2990 3894.400044
34  10076 7931.333821
35   1063 3894.393333
36  15312 7931.338208
37   7119 5800.351948
38  10316 7929.702124
39   7340 7931.070966
40    942 3894.393333
41    828 3894.393333
42   3625 3897.432456
43    790 3894.445754
44    923 3894.393333
45    787 3894.393333
46   1608 3894.393333
47   8075 7927.931554
48   1289 3894.393333
49   2671 3894.393333
50   2578 3894.393333
51   3337 3950.021059
52    905 3894.393333
53   3348 4029.696071
54    561 3894.393333
55    545 3894.393333
56    774 3894.393333
57   4066 3894.681967
58  10913 7931.338299
59   4939 5429.746341
60   9240 7931.335309
61   5622 6141.245167
62  12150 7179.328294
63   5607 5098.811431
64   2265 3894.393333
65   1624 3894.393333
66   6271 6255.388441
67    680 3894.393333
68  15079 7931.334046
69  13196 7926.575540
70   2441 3894.393333
71   2401 3894.393333
72  16100 7931.338239
73   2525 3894.473636
74   2096 3894.393333
75   1082 3894.393333
76    802 3894.393333
77    817 3894.393333
78   4482 3894.395153
79    526 3894.393333
80    421 3894.393333
81   8403 7930.851183
82   3892 4846.138573
83   2010 3894.393333
84    685 3894.397274
85   2012 3894.393333
86   3644 3897.736542
87  17153 7931.337662
88    850 3894.399397
89   3972 3894.473520
90  12339 7931.337700
91    759 3894.393333
92    696 3894.393333
93   8774 7916.076486
94   3183 3894.399407
95   2591 3894.395289
96    945 3894.393333
97   5292 5429.746341
98   1298 3894.393333
99    889 3894.393333
100  3796 4021.678458
> result2$price5=result2$price-result2$price4
> mean(result2$price5);sd(result2$price5)
[1] -407.1997565
[1] 3325.384675




 이것만 보면 하나의 신경망 층만 만드는 nnet 보다 여러 개의 층을 만들 수 있는 neuralnet이 더 좋은 선택인 것 같습니다. 물론 이는 학습하고자하는 자료나 알고리즘에 따라 크게 차이가 있을 수 있어 어느 쪽이 더 좋다고 단순하게 말할 순 없습니다. 하지만 하나의 패키지로 만족할만한 결과를 얻지 못했을 때 몇 가지 패키지의 사용법을 알고 있다면 가장 좋은 성능을 내는 패키지를 사용할 수 있다는 장점이 있습니다. 앞으로 몇 가지 패키지를 더 알아보겠습니다. 

댓글

이 블로그의 인기 게시물

통계 공부는 어떻게 하는 것이 좋을까?

 사실 저도 통계 전문가가 아니기 때문에 이런 주제로 글을 쓰기가 다소 애매하지만, 그래도 누군가에게 도움이 될 수 있다고 생각해서 글을 올려봅니다. 통계학, 특히 수학적인 의미에서의 통계학을 공부하게 되는 계기는 사람마다 다르긴 하겠지만, 아마도 비교적 흔하고 난감한 경우는 논문을 써야 하는 경우일 것입니다. 오늘날의 학문적 연구는 집단간 혹은 방법간의 차이가 있다는 것을 객관적으로 보여줘야 하는데, 그려면 불가피하게 통계적인 방법을 쓸 수 밖에 없게 됩니다. 이런 이유로 분야와 주제에 따라서는 아닌 경우도 있겠지만, 상당수 논문에서는 통계학이 들어가게 됩니다.   문제는 데이터를 처리하고 분석하는 방법을 익히는 데도 상당한 시간과 노력이 필요하다는 점입니다. 물론 대부분의 학과에서 통계 수업이 들어가기는 하지만, 그것만으로는 충분하지 않은 경우가 많습니다. 대학 학부 과정에서는 대부분 논문 제출이 필요없거나 필요하다고 해도 그렇게 높은 수준을 요구하지 않지만, 대학원 이상 과정에서는 SCI/SCIE 급 논문이 필요하게 되어 처음 논문을 작성하는 입장에서는 상당히 부담되는 상황에 놓이게 됩니다.  그리고 이후 논문을 계속해서 쓰게 될 경우 통계 문제는 항상 나를 따라다니면서 괴롭히게 될 것입니다.  사정이 이렇다보니 간혹 통계 공부를 어떻게 하는 것이 좋겠냐는 질문이 들어옵니다. 사실 저는 통계 전문가라고 하기에는 실력은 모자라지만, 대신 앞서서 삽질을 한 경험이 있기 때문에 몇 가지 조언을 해줄 수 있을 것 같습니다.  1. 입문자를 위한 책을 추천해달라  사실 예습을 위해서 미리 공부하는 것은 추천하지 않습니다. 기본적인 통계는 학과별로 다르지 않더라도 주로 쓰는 분석방법은 분야별로 상당한 차이가 있을 수 있어 결국은 자신이 주로 하는 부분을 잘 해야 하기 때문입니다. 그러기 위해서는 학과 커리큘럼에 들어있는 통계 수업을 듣는 것이 더 유리합니다. 잘 쓰지도 않을 방법을 열심히 공부하는 것은 아무래도 효율

150년 만에 다시 울린 희귀 곤충의 울음 소리

  ( The katydid Prophalangopsis obscura has been lost since it was first collected, with new evidence suggesting cold areas of Northern India and Tibet may be the species' habitat. Credit: Charlie Woodrow, licensed under CC BY 4.0 ) ( The Museum's specimen of P. obscura is the only confirmed member of the species in existence. Image . Credit: The Trustees of the Natural History Museum, London )  과학자들이 1869년 처음 보고된 후 지금까지 소식이 끊긴 오래 전 희귀 곤충의 울음 소리를 재현하는데 성공했습니다. 프로팔랑곱시스 옵스큐라 ( Prophalangopsis obscura)는 이상한 이름만큼이나 이상한 곤충으로 매우 희귀한 메뚜기목 곤충입니다. 친척인 여치나 메뚜기와는 오래전 갈라진 독자 그룹으로 매우 큰 날개를 지니고 있으며 인도와 티벳의 고산 지대에 사는 것으로 보입니다.   유일한 표본은 수컷 성체로 2005년에 암컷으로 생각되는 2마리가 추가로 발견되긴 했으나 정확히 같은 종인지는 다소 미지수인 상태입니다. 현재까지 확실한 표본은 수컷 성체 한 마리가 전부인 미스터리 곤충인 셈입니다.   하지만 과학자들은 그 형태를 볼 때 이들 역시 울음 소리를 통해 짝짓기에서 암컷을 유인했을 것으로 보고 있습니다. 그런데 높은 고산 지대에서 먼 거리를 이동하는 곤충이기 때문에 낮은 피치의 울음 소리를 냈을 것으로 보입니다. 문제는 이런 소리는 암컷 만이 아니라 박쥐도 잘 듣는다는 것입니다. 사실 이들은 중생대 쥐라기 부터 존재했던 그룹으로 당시에는 박쥐가 없어 이런 방식이 잘 통했을 것입니다. 하지만 신생대에 박쥐가 등장하면서 플로팔랑곱

9000년 전 소녀의 모습을 복원하다.

( The final reconstruction. Credit: Oscar Nilsson )  그리스 아테나 대학과 스웨덴 연구자들이 1993년 발견된 선사 시대 소녀의 모습을 마치 살아있는 것처럼 복원하는데 성공했습니다. 이 유골은 그리스의 테살리아 지역의 테오페트라 동굴 ( Theopetra Cave )에서 발견된 것으로 연대는 9000년 전으로 추정됩니다. 유골의 주인공은 15-18세 사이의 소녀로 정확한 사인은 알 수 없으나 괴혈병, 빈혈, 관절 질환을 앓고 있었던 것으로 확인되었습니다.   이 소녀가 살았던 시기는 유럽 지역에서 수렵 채집인이 초기 농경으로 이전하는 시기였습니다. 다른 시기와 마찬가지로 이 시기의 사람들도 젊은 시절에 다양한 질환에 시달렸을 것이며 평균 수명 역시 매우 짧았을 것입니다. 비록 젊은 나이에 죽기는 했지만, 당시에는 이런 경우가 드물지 않았을 것이라는 이야기죠.   아무튼 문명의 새벽에 해당하는 시점에 살았기 때문에 이 소녀는 Dawn (그리스어로는  Avgi)라고 이름지어졌다고 합니다. 연구팀은 유골에 대한 상세한 스캔과 3D 프린팅 기술을 적용해서 살아있을 당시의 모습을 매우 현실적으로 복원했습니다. 그리고 그 결과 나타난 모습은.... 당시의 거친 환경을 보여주는 듯 합니다. 긴 턱은 당시를 살았던 사람이 대부분 그랬듯이 질긴 먹이를 오래 씹기 위한 것으로 보입니다.   강하고 억센 10대 소녀(?)의 모습은 당시 살아남기 위해서는 강해야 했다는 점을 말해주는 듯 합니다. 이렇게 억세보이는 주인공이라도 당시에는 전염병이나 혹은 기아에서 자유롭지는 못했기 때문에 결국 평균 수명은 길지 못했겠죠. 외모 만으로 평가해서는 안되겠지만, 당시의 거친 시대상을 보여주는 듯 해 흥미롭습니다.   참고  https://phys.org/news/2018-01-teenage-girl-years-reconstructed.html