Predict method for Boosted Configuration Networks (BCN)

# S3 method for bcn
predict(object, newx, type = c("response", "probs"))

Arguments

object

an object of class 'bcn'

newx

new data to predict on, sharing no observations with the training data

type

a string: "response" returns the predicted class, "probs" returns the classifier's predicted probabilities for each class

Examples


set.seed(1234)
train_idx <- sample(nrow(iris), 0.8 * nrow(iris))
X_train <- as.matrix(iris[train_idx, -ncol(iris)])
X_test <- as.matrix(iris[-train_idx, -ncol(iris)])
y_train <- iris$Species[train_idx]
y_test <- iris$Species[-train_idx]

fit_obj <- bcn::bcn(x = X_train, y = y_train, B = 10, nu = 0.335855,
lam = 10**0.7837525, r = 1 - 10**(-5.470031), tol = 10**-7,
activation = "tanh", type_optim = "nlminb")
#> 
  |                                                                            
  |                                                                      |   0%
  |                                                                            
  |=======                                                               |  10%
  |                                                                            
  |==============                                                        |  20%
  |                                                                            
  |=====================                                                 |  30%
  |                                                                            
  |============================                                          |  40%
  |                                                                            
  |===================================                                   |  50%
  |                                                                            
  |==========================================                            |  60%
  |                                                                            
  |=================================================                     |  70%
  |                                                                            
  |========================================================              |  80%
  |                                                                            
  |===============================================================       |  90%
  |                                                                            
  |======================================================================| 100%

print(predict(fit_obj, newx = X_test) == y_test)
#>  [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
#> [16] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
print(mean(predict(fit_obj, newx = X_test) == y_test))
#> [1] 1

print(predict(fit_obj, newx = X_test, type="probs"))
#>          setosa versicolor virginica
#>  [1,] 0.4024654  0.3174903 0.2800443
#>  [2,] 0.4002180  0.3188798 0.2809022
#>  [3,] 0.3993414  0.3209437 0.2797149
#>  [4,] 0.4072561  0.3125790 0.2801649
#>  [5,] 0.4023243  0.3167276 0.2809481
#>  [6,] 0.4043465  0.3152716 0.2803818
#>  [7,] 0.3946759  0.3289303 0.2763938
#>  [8,] 0.4062572  0.3127968 0.2809460
#>  [9,] 0.4070092  0.3121744 0.2808163
#> [10,] 0.4036604  0.3168198 0.2795198
#> [11,] 0.3941399  0.3305342 0.2753259
#> [12,] 0.4000736  0.3157515 0.2841748
#> [13,] 0.4016246  0.3147469 0.2836285
#> [14,] 0.3073135  0.3724050 0.3202815
#> [15,] 0.2982237  0.3920059 0.3097704
#> [16,] 0.3092910  0.3785265 0.3121825
#> [17,] 0.3015702  0.3837633 0.3146665
#> [18,] 0.2967950  0.3539556 0.3492494
#> [19,] 0.2994432  0.3980609 0.3024960
#> [20,] 0.3012615  0.3985922 0.3001462
#> [21,] 0.3010154  0.3781519 0.3208327
#> [22,] 0.2766522  0.3323093 0.3910385
#> [23,] 0.2693026  0.3618746 0.3688228
#> [24,] 0.2830866  0.3365970 0.3803164
#> [25,] 0.2879091  0.3395733 0.3725176
#> [26,] 0.2897414  0.3335011 0.3767575
#> [27,] 0.2805241  0.3188193 0.4006566
#> [28,] 0.2835231  0.3124985 0.4039784
#> [29,] 0.2756331  0.3362094 0.3881575
#> [30,] 0.2854148  0.3191843 0.3954010