Solved – What’s the meaning of nested resampling

I have a feature selection and regression task in which a 64×3000 numeric data set is provided. I want to find the best features and the best possible settings for a number of learners (i.e., a linear regressor and a decision tree). I use the mlr package in R. If I understand the tutorial correctly, I have to use nested resampling (with an inner and an outer resampling loop), as described in the mlr tutorial on nested resampling. Why is this kind of nested resampling required when I am already cross-validating the learners and their settings during parameter optimization?

Code (the dataset itself is not provided):

library('mlr')
set.seed(1234, "L'Ecuyer")

dataset = read.csv("dataset.csv")

# shuffle columns
dataset <- dataset[, c(sample(ncol(dataset) - 1), ncol(dataset))]

# try to make full rank cov matrix for linear regression
q <- qr(dataset)
dataset <- dataset[, q$pivot[seq(q$rank)]]

regr.task = makeRegrTask(id = "dataset.ig", data = dataset, target = "target")

rdescCV2 = makeResampleDesc("CV", iters = 2)
rdescCV3 = makeResampleDesc("CV", iters = 3)

inner = rdescCV2
outer = rdescCV3

lrns = list(
  "regr.lm",
  makeLearner("regr.rpart", minbucket = 1)
)

measures = list(mse, rsq)

# collects the features selected on each outer fold, per learner
optimalFeatures = list()

for (lrn1 in lrns) {
  set.seed(1234, "L'Ecuyer")

  lrnName = ifelse(typeof(lrn1) == "list", lrn1$id, lrn1)
  if (typeof(lrn1) == "list") {
    lrnPars = lrn1$par.vals
  } else {
    lrnPars = list()
  }

  lrnName2Save = lrnName

  lrn = makeFilterWrapper(learner = makeLearner(cl = lrnName, par.vals = lrnPars))

  ps = makeParamSet(
    makeDiscreteParam("fw.abs", values = seq(3, 5, 1)),
    makeDiscreteParam("fw.method", values = c('chi.squared',
                                              'information.gain'))
  )

  if ("minsplit" %in% names(getParamSet(lrn)$pars))
    ps$pars$minsplit = makeIntegerParam("minsplit", lower = 2L, upper = 3L)

  # try to find the best feature set and setting for each learner
  res = makeTuneWrapper(
    lrn,
    resampling = inner,
    measures = measures,
    par.set = ps,
    control = makeTuneControlGrid(),
    show.info = FALSE
  )

  # same indices for each learner
  set.seed(1234, "L'Ecuyer")

  r = resample(res, regr.task, outer, models = TRUE, measures = measures)

  res2 = lapply(r$models, getTuneResult)
  opt.paths = lapply(res2, function(x) as.data.frame(x$opt.path))
  optimalFeatures[[lrnName2Save]] = lapply(r$models,
                                           function(x) getFilteredFeatures(x$learner.model$next.model))
  print(res2)
  print(opt.paths)
  print(optimalFeatures[[lrnName2Save]])
}

Output:

[Resample] cross-validation iter 1: mse.test.mean=4.95e+03,rsq.test.mean=0.222
[Resample] cross-validation iter 2: mse.test.mean=4.03e+03,rsq.test.mean=0.497
[Resample] cross-validation iter 3: mse.test.mean= 961,rsq.test.mean=0.765
[Resample] Aggr. Result: mse.test.mean=3.31e+03,rsq.test.mean=0.495

[[1]]
Tune result:
Op. pars: fw.abs=4; fw.method=information.gain
mse.test.mean=2.73e+03,rsq.test.mean=0.568

[[2]]
Tune result:
Op. pars: fw.abs=5; fw.method=information.gain
mse.test.mean=2.9e+03,rsq.test.mean=0.383

[[3]]
Tune result:
Op. pars: fw.abs=5; fw.method=chi.squared
mse.test.mean=6.64e+03,rsq.test.mean=-0.0448

[[1]]
  fw.abs        fw.method mse.test.mean rsq.test.mean dob eol error.message exec.time
1      3      chi.squared      4711.697     0.3248701   1  NA          <NA>      0.36
2      4      chi.squared      2891.273     0.5480474   2  NA          <NA>      0.33
3      5      chi.squared      2861.078     0.5526319   3  NA          <NA>      0.31
4      3 information.gain      2726.971     0.5631411   4  NA          <NA>      0.43
5      4 information.gain      2726.018     0.5678868   5  NA          <NA>      0.38
6      5 information.gain      2970.028     0.5395522   6  NA          <NA>      0.39

[[2]]
  fw.abs        fw.method mse.test.mean rsq.test.mean dob eol error.message exec.time
1      3      chi.squared      5357.465    -0.2319388   1  NA          <NA>      0.34
2      4      chi.squared      3747.050     0.2437902   2  NA          <NA>      0.35
3      5      chi.squared      2897.023     0.3831484   3  NA          <NA>      0.31
4      3 information.gain      5357.465    -0.2319388   4  NA          <NA>      0.41
5      4 information.gain      3747.050     0.2437902   5  NA          <NA>      0.42
6      5 information.gain      2897.023     0.3831484   6  NA          <NA>      0.43

[[3]]
  fw.abs        fw.method mse.test.mean rsq.test.mean dob eol error.message exec.time
1      3      chi.squared      7593.989   -0.10557789   1  NA          <NA>      0.37
2      4      chi.squared      6786.384   -0.02621949   2  NA          <NA>      0.33
3      5      chi.squared      6637.264   -0.04484878   3  NA          <NA>      0.32
4      3 information.gain      7593.989   -0.10557789   4  NA          <NA>      0.40
5      4 information.gain      6786.384   -0.02621949   5  NA          <NA>      0.39
6      5 information.gain      6637.264   -0.04484878   6  NA          <NA>      0.41

[[1]]
[1] "RDF065u_640" "RTp_1225"    "L2u_940"     "TIC3_182"

[[2]]
[1] "RTp_1225"    "L2u_940"     "Mor03m_813"  "TIC3_182"    "Mor03m_2294"

[[3]]
[1] "H.046_1401"  "Mor21u_2280" "RDF065u_640" "RTp_1225"    "CIC2_1660"

[Resample] cross-validation iter 1: mse.test.mean=3.13e+03,rsq.test.mean=0.509
[Resample] cross-validation iter 2: mse.test.mean=8.04e+03,rsq.test.mean=-0.00294
[Resample] cross-validation iter 3: mse.test.mean=3.49e+03,rsq.test.mean=0.148
[Resample] Aggr. Result: mse.test.mean=4.89e+03,rsq.test.mean=0.218

[[1]]
Tune result:
Op. pars: fw.abs=5; fw.method=chi.squared; minsplit=3
mse.test.mean=3.15e+03,rsq.test.mean=0.443

[[2]]
Tune result:
Op. pars: fw.abs=5; fw.method=information.gain; minsplit=3
mse.test.mean=3.3e+03,rsq.test.mean=0.206

[[3]]
Tune result:
Op. pars: fw.abs=5; fw.method=chi.squared; minsplit=2
mse.test.mean=4.33e+03,rsq.test.mean=0.368

[[1]]
   fw.abs        fw.method minsplit mse.test.mean rsq.test.mean dob eol error.message exec.time
1       3      chi.squared        2      3875.576     0.3448855   1  NA          <NA>      0.35
2       4      chi.squared        2      4054.182     0.2971222   2  NA          <NA>      0.33
3       5      chi.squared        2      3149.302     0.4433532   3  NA          <NA>      0.34
4       3 information.gain        2      3351.588     0.4077916   4  NA          <NA>      0.42
5       4 information.gain        2      3904.129     0.3151364   5  NA          <NA>      0.41
6       5 information.gain        2      3649.004     0.3833628   6  NA          <NA>      0.39
7       3      chi.squared        3      3875.576     0.3448855   7  NA          <NA>      0.35
8       4      chi.squared        3      4054.182     0.2971222   8  NA          <NA>      0.35
9       5      chi.squared        3      3149.302     0.4433532   9  NA          <NA>      0.38
10      3 information.gain        3      3351.588     0.4077916  10  NA          <NA>      0.40
11      4 information.gain        3      3904.129     0.3151364  11  NA          <NA>      0.41
12      5 information.gain        3      3649.004     0.3833628  12  NA          <NA>      0.42

[[2]]
   fw.abs        fw.method minsplit mse.test.mean rsq.test.mean dob eol error.message exec.time
1       3      chi.squared        2      4846.020   -0.01409290   1  NA          <NA>      0.32
2       4      chi.squared        2      3316.516    0.20477753   2  NA          <NA>      0.32
3       5      chi.squared        2      3304.965    0.20643353   3  NA          <NA>      0.36
4       3 information.gain        2      4848.166   -0.01480330   4  NA          <NA>      0.43
5       4 information.gain        2      3316.516    0.20477753   5  NA          <NA>      0.42
6       5 information.gain        2      3304.965    0.20643353   6  NA          <NA>      0.42
7       3      chi.squared        3      4613.949    0.05281112   7  NA          <NA>      0.38
8       4      chi.squared        3      3316.516    0.20477753   8  NA          <NA>      0.41
9       5      chi.squared        3      3304.965    0.20643353   9  NA          <NA>      0.33
10      3 information.gain        3      4795.237   -0.00721534  10  NA          <NA>      0.39
11      4 information.gain        3      3316.516    0.20477753  11  NA          <NA>      0.38
12      5 information.gain        3      3304.965    0.20643353  12  NA          <NA>      0.36

[[3]]
   fw.abs        fw.method minsplit mse.test.mean rsq.test.mean dob eol error.message exec.time
1       3      chi.squared        2      8346.300    -0.0896325   1  NA          <NA>      0.29
2       4      chi.squared        2     10435.255    -0.3316064   2  NA          <NA>      0.32
3       5      chi.squared        2      4325.461     0.3684383   3  NA          <NA>      0.30
4       3 information.gain        2      8346.300    -0.0896325   4  NA          <NA>      0.39
5       4 information.gain        2     10435.255    -0.3316064   5  NA          <NA>      0.39
6       5 information.gain        2      4325.461     0.3684383   6  NA          <NA>      0.41
7       3      chi.squared        3      8346.300    -0.0896325   7  NA          <NA>      0.36
8       4      chi.squared        3     10435.255    -0.3316064   8  NA          <NA>      0.34
9       5      chi.squared        3      4325.461     0.3684383   9  NA          <NA>      0.34
10      3 information.gain        3      8346.300    -0.0896325  10  NA          <NA>      0.41
11      4 information.gain        3     10435.255    -0.3316064  11  NA          <NA>      0.44
12      5 information.gain        3      4325.461     0.3684383  12  NA          <NA>      0.40

[[1]]
[1] "RDF065u_640" "RTp_1225"    "L2u_940"     "Mor03m_813"  "TIC3_182"

[[2]]
[1] "RTp_1225"    "L2u_940"     "Mor03m_813"  "TIC3_182"    "Mor03m_2294"

[[3]]
[1] "H.046_1401"  "Mor21u_2280" "RDF065u_640" "RTp_1225"    "CIC2_1660"

The purpose of nested cross-validation is to reduce overfitting, that is, the risk of concluding that a model has good generalization performance when it doesn't.

To see why this is a problem, consider the extreme case of a model that simply memorizes the data (for example, KNN with 1 neighbour). If you evaluate this model on the data you've trained it on (or part thereof), you'll get perfect performance, but on any other data it will probably perform terribly.
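A minimal sketch of that extreme case, assuming a small synthetic data set and mlr's "regr.kknn" learner (from the kknn package): a 1-nearest-neighbour regressor scored on its own training data looks essentially perfect even though the target is pure noise, while a proper cross-validation exposes it.

library(mlr)
library(kknn)                                  # backend for the "regr.kknn" learner

set.seed(1)
# hypothetical data: 10 random features, target is pure noise
d = data.frame(matrix(rnorm(50 * 10), ncol = 10))
d$target = rnorm(50)

toy.task = makeRegrTask(data = d, target = "target")
lrn1nn = makeLearner("regr.kknn", k = 1)       # 1-NN memorizes the training set

mod = train(lrn1nn, toy.task)

# evaluated on its own training data: mse close to 0, looks "perfect"
performance(predict(mod, task = toy.task), measures = mse)

# evaluated on held-out folds: the model has learned nothing
resample(lrn1nn, toy.task, makeResampleDesc("CV", iters = 3), measures = mse)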

That's why you need separate train and test sets. But even with that, it's possible that you get an unlucky split and train and test end up being too similar, again giving a misleading impression of the real performance. It could work the other way, too, where the train and test sets are too dissimilar and no matter what you learn on the training set, you won't do well on the test set.
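In mlr, a single train/test split corresponds to holdout resampling. A minimal sketch, reusing the regr.task from the code above (the learner is an arbitrary choice; a 2/3 training share is mlr's default):

# one random train/test split, by default 2/3 of the data for training
hout = makeResampleDesc("Holdout")
resample("regr.lm", regr.task, hout, measures = list(mse, rsq))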

So you can go one step further and use a series of train and test sets: cross-validation. You take your entire data set, split it into n folds, use one fold for testing and the remaining n-1 folds for training, then use another fold for testing (and the rest for training), and so on for n rounds.
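With mlr this is exactly what a resample description plus resample() does; a minimal sketch, again reusing the regr.task from the question (3 folds chosen arbitrarily):

rdesc = makeResampleDesc("CV", iters = 3)                     # n = 3 folds
r = resample("regr.lm", regr.task, rdesc, measures = list(mse, rsq))
r$aggr                                                        # performance averaged over the held-out folds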

Why is nested cross-validation better for things like tuning, where lots of different models are evaluated and compared? Consider the following thought experiment: a learner has one parameter that just adds random noise and has no real effect. Comparing different parameterizations of this learner will make one of them come out best by pure chance, even when using a cross-validation, because the same data that is used to pick the winner is also used to report its score. In a nested cross-validation, the selected models are evaluated on yet another set of data that played no part in the selection, which shows that the parameter doesn't actually do anything (or is at least more likely to show that).
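A minimal sketch of that nested setup, stripped down from the question's own code (the filter method and the tiny fw.abs grid are illustrative only): the inner cross-validation is wrapped into the learner and only picks the parameter values, while the outer cross-validation scores the tuned learner on data the tuner never saw.

inner = makeResampleDesc("CV", iters = 2)      # used only to pick the "best" setting
outer = makeResampleDesc("CV", iters = 3)      # used only to estimate its performance

ps = makeParamSet(makeDiscreteParam("fw.abs", values = seq(3, 5, 1)))

tuned = makeTuneWrapper(
  makeFilterWrapper(makeLearner("regr.lm"), fw.method = "information.gain"),
  resampling = inner,
  par.set = ps,
  control = makeTuneControlGrid(),
  measures = list(mse, rsq)
)

# each outer training set is tuned with its own inner CV; each outer test set
# only ever sees the finished, tuned model
r = resample(tuned, regr.task, outer, measures = list(mse, rsq))
r$aggr                                         # honest estimate of the whole tune-and-fit procedure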

Neither cross-validation nor nested cross-validation is strictly required in every case, but they will likely improve the generalization performance of the end result dramatically.
