ANN/ORE #3 – Construction du réseau de neurones avec neuralnet

Dans les billets précédents, j’ai utilisé le package nnet. Celui-ci ne permet que la création de réseaux de neurones simples (une seule couche cachée). Pour construire un ANN plus complexe, d’autres packages sont disponibles sous R, ici je vais utiliser « neuralnet » avec, là encore, les données du MNIST.

Contrairement à nnet, la librairie neuralnet doit être installée. Je procède donc à sa mise en place sur la distribution ORE du serveur de base de données:

 
oracle@psu888: /home/oracle [HODBA04D1_1]# ORE CMD INSTALL /tmp/neuralnet_1.33.tar.gz
* installing to library ‘/soft/oracle/product/rdbms/12.2.0.1/R/library’
* installing *source* package ‘neuralnet’ ...
** package ‘neuralnet’ successfully unpacked and MD5 sums checked
** R
** preparing package for lazy loading
** help
*** installing help indices
  converting help for package ‘neuralnet’
    finding HTML links ... done
    compute                                 html
    confidence.interval                     html
    gwplot                                  html
    neuralnet-package                       html
    neuralnet                               html
    plot.nn                                 html
    prediction                              html
** building package indices
** testing if installed package can be loaded
* DONE (neuralnet)
oracle@psu888: /home/oracle [HODBA04D1_1]#

Le package peut alors être utilisé dans des appels ore – deux points à noter cependant:

  • les expressions du type « y ~ . » ne sont pas supportées, il faut donc spécifier explicitement les prédicteurs et les variables dépendantes via une formule.
  • les facteurs ne sont pas supportés, il faut donc passer par une étape de codification one-hot

Je vais maintenant construire un ANN avec deux couches cachées – la première constituée de 300 neurones, la seconde de 100 neurones:

 
> library(tictoc)
> tic()
> 
> ore.doEval(function() {
+   library(ORE)
+   library(neuralnet)
+   set.seed(3456)
+   ore.sync(table = "MNIST_TRAINING_SET")
+   mnist_training <- ore.pull(ore.get("MNIST_TRAINING_SET"))
+   
+   # -- One Hot Encoding du champ IMG_LBL
+   mnist_training_ohe <- as.data.frame(model.matrix(~.-1,mnist_training))
+   
+   # -- Construction de la formule en spécifiant les champs
+   f <- as.formula(paste(paste(paste("IMG_LBL",seq(0,9),sep=""),collapse="+"), " ~", paste(paste("P",seq(1,784),sep=""),collapse="+")))
+   
+   # -- Construction du modèle
+   nn_neuralnet <- neuralnet(f, data=mnist_training_ohe,
+                             hidden=c(300, 100),
+                             linear.output=FALSE)
+   
+   # -- Sauvegarde du modèle
+   ore.save(list=c("nn_neuralnet"),name="DS NeuralNet", append = TRUE)
+   
+   # -- Scoring du modèle
+   ore.sync(table = "MNIST_TEST_SET")
+   mnist_test_orig <- ore.pull(ore.get("MNIST_TEST_SET"))
+   mnist_test <- mnist_test_orig[,c(-1,-2)]
+   mnist_pred <- compute(nn_neuralnet,mnist_test)$net.result
+   
+   # -- Reversibilité de l'encodage One-Hot
+   names(mnist_pred) <- c("0","1","2","3","4","5","6","7","8","9")
+   pred <- names(mnist_pred)[max.col(mnist_pred)]
+   
+   # -- Table de contingence
+   print(table(pred,mnist_test_orig[1]$IMG_LBL))
+ }, ore.connect = TRUE)
    
pred    0    1    2    3    4    5    6    7    8    9
   0  966    0    8    2    1    3    8    3    9    9
   1    0 1121    2    3    2    1    3    3    1    4
   2    3    2  977    8    6    3    3   12    9    0
   3    1    2   11  952    2   13    1    6   17   11
   4    1    1    4    1  934    4    5    4    5   15
   5    3    0    1    7    1  847   12    1   12    4
   6    2    3    5    1    7    7  924    0    5    0
   7    2    2   10   11    5    2    0  987    4   12
   8    1    4   13   18    2    8    2    2  909    5
   9    1    0    1    7   22    4    0   10    3  949
> 
> toc()
2385.31 sec elapsed
> ore.datastoreSummary("DS NeuralNet")
   object.name        class        size length row.count col.count
1      nn_nnet nnet.formula   407018410     19        NA        NA   
2 nn_neuralnet           nn 23492644634     13        NA        NA
>

A l’instar de nnet, on peut voir avec la table de contingence qu’on obtient d’excellents résultats de classification.

En revanche, on peut remarquer deux aspects intéressants:

  • la durée de construction du modèle est drastiquement réduite par rapport au test avec nnet: 2385 secondes contre 49865 – soit une diminution de 95%!! On s’intéressera à cet aspect dans un futur billet.
  • la taille du modèle sauvegardé est gigantesque par rapport à celui produit par nnet: 23.5GB contre 400MB.

Cela s’avère malheureusement source de problèmes dans la mesure ou le rechargement du modèle requiert énormément de temps et de mémoire:

 
> tic()
> 
> ore.doEval(function() {
+   library(ORE)
+   library(neuralnet)
+   ore.load("DS NeuralNet", list = c("nn_neuralnet"))
+ }, ore.connect = TRUE)
Error in .oci.GetQuery(conn, statement, data = data, prefetch = prefetch,  : 
  ORA-20000: RQuery error
Error : vector memory exhausted (limit reached?)
ORA-06512: at "RQSYS.RQEVALIMPL", line 104
ORA-06512: at "RQSYS.RQEVALIMPL", line 101
> 
> toc()
7044.19 sec elapsed
> 
> ore.doEval(gc)
          used (Mb) gc trigger    (Mb) limit (Mb)   max used    (Mb)
Ncells 1830257 97.8   17095973   913.1       5600   30228485  1614.4
Vcells 2547251 19.5 2497018182 19050.8      32768 3901465136 29765.9
>

Ici, le rechargement a échoué au bout d’un peu moins de 2 heures en raison d’une saturation de la mémoire allouée au Vcells…

En conclusion, le package neuralnet permet de construire des ANN de typologie plus complexe que nnet. Si le package est extrêmement performant lors de la phase de construction du modèle, la ré-exploitation du modèle une fois sauvegardé s’avère problématique.

A suivre…

Laisser un commentaire

Votre adresse de messagerie ne sera pas publiée. Les champs obligatoires sont indiqués avec *

38 + = forty two