R - Extract factor predictor names from caret and glmnet lasso model objects


In the example below, I set up a data frame df with 3 variables: predict, var1, and var2 (a factor).

When I run the model in caret or glmnet, the factor is converted to a dummy variable, such as var2b.

I'd like to extract the variable names programmatically and match them to the original variable names, not the dummy variable names. Is there a way to do this?

This is only an example. My real-world problem has many variables with many different levels, so I want to avoid doing this manually, e.g. by trying to substring out the "b".

Thanks!

    library(caret)
    library(glmnet)

    df <- data.frame(predict = c('y','y','n','y','n','y','y','n','y','n'),
                     var1 = c(1,2,5,1,6,7,3,4,5,6),
                     var2 = c('a','a','b','b','a','a','a','b','b','a'),
                     stringsAsFactors = TRUE)  # needed on R >= 4.0 to get factors

    str(df)
    # 'data.frame': 10 obs. of  3 variables:
    #  $ predict: Factor w/ 2 levels "n","y": 2 2 1 2 1 2 2 1 2 1
    #  $ var1   : num  1 2 5 1 6 7 3 4 5 6
    #  $ var2   : Factor w/ 2 levels "a","b": 1 1 2 2 1 1 1 2 2 1

    test <- train(predict ~ .,
                  data = df,
                  method = 'glmnet',
                  trControl = trainControl(classProbs = TRUE,
                                           summaryFunction = twoClassSummary,
                                           allowParallel = FALSE),
                  metric = 'ROC',
                  tuneGrid = expand.grid(alpha = 1,
                                         lambda = .005))

    predictors(test)
    # [1] "var1"  "var2b"

    varImp(test)
    # glmnet variable importance
    #
    #       Overall
    # var2b     100
    # var1        0

    coef(test)
    # NULL

    #################

    x <- model.matrix(as.formula(predict ~ .), data = df)
    x <- x[, -1] ## remove intercept

    df$predict <- ifelse(df$predict == 'y', TRUE, FALSE)

    glmnet1 <- glmnet::cv.glmnet(x = x,
                                 y = df$predict,
                                 family = 'binomial',  # logistic lasso; 'auc' needs a binomial model
                                 type.measure = 'auc',
                                 nfolds = 3,
                                 alpha = 1,
                                 parallel = FALSE)

    rownames(coef(glmnet1))
    # [1] "(Intercept)" "var1"        "var2b"

The formula method for the 'train' object returns a 'terms' object (a formula with extra attributes) that carries the attributes you are looking for.

    f1 <- formula(test)
    f1
    # predict ~ var1 + var2
    # attr(,"variables")
    # list(predict, var1, var2)
    # attr(,"factors")
    #         var1 var2
    # predict    0    0
    # var1       1    0
    # var2       0    1
    # attr(,"term.labels")
    # [1] "var1" "var2"
    # attr(,"order")
    # [1] 1 1
    # attr(,"intercept")
    # [1] 1
    # attr(,"response")
    # [1] 1
    # attr(,"predvars")
    # list(predict, var1, var2)
    # attr(,"dataClasses")
    #   predict      var1      var2
    #  "factor" "numeric"  "factor"

    attr(f1, "term.labels")
    # [1] "var1" "var2"
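cv.glmnet takes a matrix rather than a formula, so there is no train object to ask. If you keep the formula you used to build the design matrix, base R's terms() gives the same "term.labels" attribute without going through caret. A minimal sketch, reusing the question's toy data (stringsAsFactors = TRUE is assumed so that var2 is a factor):

```r
# Question's toy data; stringsAsFactors = TRUE makes var2 a factor on R >= 4.0
df <- data.frame(predict = c('y','y','n','y','n','y','y','n','y','n'),
                 var1 = c(1,2,5,1,6,7,3,4,5,6),
                 var2 = c('a','a','b','b','a','a','a','b','b','a'),
                 stringsAsFactors = TRUE)

# terms() expands "." to every column of df except the response
tt <- terms(predict ~ ., data = df)

# Original variable names, before any dummy coding
labels_found <- attr(tt, "term.labels")
print(labels_found)
```

This returns c("var1", "var2"), the same result as attr(f1, "term.labels") on the caret model.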

It does not appear that the original variable names are stored in the 'cv.glmnet' object, and I am not aware of an elegant way of collecting them. The glmnetUtils package might have some quality-of-life functions for this.
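One possible alternative, assuming the matrix is built with model.matrix() as in the question: model.matrix() attaches an "assign" attribute that records, for each column, the index of the term it came from (0 for the intercept). Indexing the term labels with it maps every dummy column back to its original variable without any string matching. A minimal sketch on the toy data; the names col_to_var and original_vars are illustrative, not part of any package:

```r
df <- data.frame(predict = c('y','y','n','y','n','y','y','n','y','n'),
                 var1 = c(1,2,5,1,6,7,3,4,5,6),
                 var2 = c('a','a','b','b','a','a','a','b','b','a'),
                 stringsAsFactors = TRUE)

tt <- terms(predict ~ ., data = df)
x <- model.matrix(tt, data = df)

# For each column of x: index into the term labels (0 = intercept)
assign_idx <- attr(x, "assign")
term_labels <- attr(tt, "term.labels")

# Build the mapping before dropping the intercept, because subsetting
# a matrix discards the "assign" attribute
keep <- assign_idx != 0
col_to_var <- setNames(term_labels[assign_idx[keep]], colnames(x)[keep])
x <- x[, keep, drop = FALSE]

# Hypothetical helper: original variables behind a set of selected columns
original_vars <- function(selected_cols) {
  unique(unname(col_to_var[selected_cols]))
}

print(col_to_var)
print(original_vars("var2b"))
```

With a fitted model, the names of the non-zero coefficients (for example from coef(glmnet1, s = "lambda.min"), minus "(Intercept)") could be passed to original_vars(). Since the lookup is by exact column name, it cannot confuse var1 with var11.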

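Here is some code you could try. Note that it may return false positives, since it uses each column name of the input data as a search pattern against the coefficient names (e.g. a non-zero "var11" coefficient would also match "var1").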

    # generic method
    termLabels <- function(object, ...) {
        UseMethod("termLabels")
    }

    # add a train method to save typing
    termLabels.train <- function(object, ...) {
        attr(formula(object), "term.labels")
    }

    # try to find the term labels of a cv.glmnet object
    # lambda must be provided; it snaps to the nearest value on the search grid
    # the allowed column names must be provided from the corresponding data object
    termLabels.cv.glmnet <- function(object, lambda, names, ...) {
        if (missing(lambda)) stop("lambda missing")
        if (missing(names)) stop("names missing")

        # lambda grid of the fit (the columns of beta follow this grid)
        lambdaArray <- object$glmnet.fit$lambda
        if (lambda > max(lambdaArray) || lambda < min(lambdaArray)) {
            stop(paste("lambda must be in range",
                       paste(range(lambdaArray), collapse = ":")))
        }

        # find the closest lambda on the grid
        whichLambda <- which.min(abs(lambdaArray - lambda))
        message(paste("Using lambda", lambdaArray[whichLambda]))

        # matrix of parameter estimates at that lambda
        betaLambda <- object$glmnet.fit$beta[, whichLambda, drop = FALSE]
        # keep the non-zero estimates
        betaLambda <- betaLambda[betaLambda[, 1] != 0, , drop = FALSE]
        vars <- rownames(betaLambda)

        # search the coefficient names for each original name;
        # note: this does not account for nested names, e.g. var1 and var11
        matchNames <- vapply(names,
                             function(n) any(grepl(n, vars, fixed = TRUE)),
                             logical(1))
        names[matchNames]
    }

    termLabels(glmnet1, lambda = glmnet1$lambda.min, names = colnames(df))
    # prints the lambda actually used, then returns the original
    # column names that have a non-zero coefficient at that lambda
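The false-positive caveat can be avoided by comparing exact column names instead of substrings. A minimal sketch, assuming default treatment contrasts with an intercept and unordered factors, so a factor column n with levels l1, l2, ... produces dummies named paste0(n, l2), and so on (ordered factors use polynomial contrasts and would need different handling). The helpers expected_cols() and match_exact() and the extra column var11 are hypothetical, added only to reproduce the nested-name problem:

```r
# Hypothetical data with a nested name (var1 vs var11) and a factor
df2 <- data.frame(var1 = c(1, 2, 3, 4),
                  var11 = c(5, 6, 7, 8),
                  var2 = factor(c("a", "b", "a", "b")))

# Column names each original variable produces in model.matrix() under
# default treatment contrasts with an intercept: factors drop their first level
expected_cols <- function(data) {
  lapply(setNames(names(data), names(data)), function(n) {
    v <- data[[n]]
    if (is.factor(v)) paste0(n, levels(v)[-1]) else n
  })
}

# Original variables whose expected columns appear exactly in `selected`
match_exact <- function(selected, data) {
  ec <- expected_cols(data)
  names(ec)[vapply(ec, function(cols) any(cols %in% selected), logical(1))]
}

selected <- c("var11", "var2b")  # hypothetical non-zero coefficient names

exact_hits <- match_exact(selected, df2)
# Substring search, as in termLabels.cv.glmnet, for comparison
substring_hits <- names(df2)[vapply(names(df2),
                                    function(n) any(grepl(n, selected, fixed = TRUE)),
                                    logical(1))]

print(exact_hits)      # c("var11", "var2")
print(substring_hits)  # c("var1", "var11", "var2"): var1 is a false positive
```

In the question's setting, data would be df with the response column removed, and selected would be the row names of the non-zero coefficients.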
