Solved – Predictions for rpart model require more variables than shown in the classification tree

Using rpart through the caret package, when I plot the final model I get a classification tree that seems fairly simple (6 variables shown in the tree). However, when I request the final variables from the model, I get a list of 23 variables instead. So, first, I am confused about why the model apparently uses many more variables than are shown in the tree. Second, even if I create new datasets that contain (a) the 6 variables shown in the tree, or (b) the 23 variables listed as predictors for the final model, I cannot use the model to make predictions: I get an error that certain expected variables are not found in the new data. Is rpart using additional variables behind the scenes that are not shown in the final model or in the classification tree? Below is a reproducible example.

    library(repmis)
    library(caret)
    library(rattle)
    library(rpart.plot)
    set.seed(1)

    # Read data
    data = source_DropboxData(file = "levels_issue.csv", key = "5uo1pidphf34lvl", sep = ",", header = TRUE)
    str(data)
    data$outcome = as.factor(data$outcome)

    # 10-fold CV, repeated 10 times
    fast.control = trainControl(method = "repeatedcv", number = 10, repeats = 10,
                                summaryFunction = twoClassSummary, classProbs = TRUE,
                                verboseIter = FALSE, savePredictions = TRUE)

    rpart.mod = train(outcome ~ ., data = data, method = "rpart", trControl = fast.control, tuneLength = 30)

    # Visualize the tree
    fancyRpartPlot(model = rpart.mod$finalModel) # fewer than 23 predictors used
    predictors(rpart.mod$finalModel) # 23 predictors shown here

    # Inspect the final model
    tree.vars = rpart.mod$finalModel$frame$var
    tree.vars.index = !tree.vars %in% '<leaf>'
    # NOTE: the variable name in the next two assignments was lost in the original post;
    # split.vars is a stand-in name
    split.vars = tree.vars[tree.vars.index]
    # 6 unique variables
    split.vars = unique(split.vars)

    # Try to rerun data through the tree using only the 23 or 6 variables from above
    six.vars = data[, colnames(data) %in% split.vars]
    colnames(six.vars)

    twenty.three.vars = data[, predictors(rpart.mod$finalModel)]

    # Predict
    m1 = predict(rpart.mod$finalModel, six.vars) # V2 not found (this var was not shown in tree)
    m2 = predict(rpart.mod, twenty.three.vars) # V8 not found (not in list of final predictors)

    setdiff(colnames(twenty.three.vars), predictors(rpart.mod$finalModel))
    setdiff(predictors(rpart.mod$finalModel), colnames(twenty.three.vars))
    # It seems that all needed predictors should be in the dataset, but the model wants additional predictors

    # Sanity check with full data
    m3 = predict(rpart.mod, data) # runs fine

To get to the actual tree-fitting part of the process, rpart creates a model matrix based on the formula you pass into train or rpart. When you go to make predictions on new samples, all of the original predictors need to be there. Basically, rpart doesn't know that it can disregard the others. Also, unless you turn them off, the model saves surrogate splits, and the predictors used by those surrogates may not appear in the final tree; leaving them out of the new data also produces an error.
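
The behaviour described above can be reproduced on a small scale. The following is a minimal sketch using rpart directly on the built-in iris data rather than the data from the question; the names fit, split.vars, and needed.vars are illustrative only. It compares the variables that appear as splits in the tree with the variables named in the formula, and shows that predicting from the split variables alone fails.

```r
library(rpart)

# Fit a small tree; the formula (via ".") names all four predictors in iris
fit <- rpart(Species ~ ., data = iris)

# Variables that appear as primary splits in the fitted tree
split.vars <- unique(as.character(fit$frame$var[fit$frame$var != "<leaf>"]))

# Variables named in the model formula, all of which are expected at prediction time
needed.vars <- attr(fit$terms, "term.labels")

# Predicting from only the split variables is expected to fail,
# because the remaining formula variables cannot be found in the new data
res <- try(predict(fit, iris[, split.vars, drop = FALSE]), silent = TRUE)
inherits(res, "try-error")

# Supplying every formula variable works
pred <- predict(fit, iris[, needed.vars], type = "class")
head(pred)
```

If surrogate splits are not wanted, rpart.control(maxsurrogate = 0) turns them off when fitting, but the new data still needs every column named in the formula.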

A few notes:

  • Most R functions internally store predictor information as integers for their column position (instead of the column name). This is more efficient, but it requires all of the original columns to always be there. I would like the functions to only require the predictors actually used in the model, but they usually are not built that way.

  • The predictors function works on train objects, so you can use predictors(rpart.mod) instead of predictors(rpart.mod$finalModel).

  • Please don't use predict(rpart.mod$finalModel). The rpart object knows nothing about what happens in train. If you use pre-processing or other operations in train that happen outside of rpart you will get incorrect predictions.
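
Put together, the notes above suggest the following pattern. This is only a sketch on stand-in data (a two-class subset of iris, not the data from the question), and the names dat, ctrl, mod, and new.data are illustrative. It calls predictors() on the train object, builds the new data from every original predictor column, and predicts through the train object rather than through finalModel.

```r
library(caret)
library(rpart)

set.seed(1)

# Stand-in two-class data (assumption: any data frame with a factor outcome would do)
dat <- droplevels(subset(iris, Species != "setosa"))

ctrl <- trainControl(method = "cv", number = 5)
mod <- train(Species ~ ., data = dat, method = "rpart",
             trControl = ctrl, tuneLength = 3)

# predictors() can be called on the train object itself
predictors(mod)

# Keep every original predictor column in the new data, not only those shown in the tree
new.data <- dat[1:5, setdiff(colnames(dat), "Species")]

# Predict through the train object so that anything train did outside rpart is reapplied
new.pred <- predict(mod, new.data)
new.pred
```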

