Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
13010db
Add bagging.
Weekend-Warrior Jul 15, 2017
d6916ec
Modify with linear effects.
Weekend-Warrior Jul 15, 2017
27faea5
Change centering and scaling order of operations to deal with missing…
Jul 26, 2017
7c2a149
Cleaned the centering and scaling a bit.
Jul 26, 2017
2373f64
Completed centering and scaling and decentering and descaling betas.
Jul 26, 2017
3df07cd
Housekeeping on the train output.
Jul 27, 2017
9ca5383
Compress rules matrix before estimating GLM.
Jul 27, 2017
80e7b45
Adopt Zelazny's bootstrapping code.
Jul 27, 2017
4d363a7
Update README.
Weekend-Warrior Jul 27, 2017
03ec2ac
Update README.md
Weekend-Warrior Jul 27, 2017
5006ba0
Update pretty sas code.
Weekend-Warrior Jul 28, 2017
fb323f9
Updates to pretty statements and moving singular value calculations t…
Jul 28, 2017
8de8b86
Updates to pretty statements and moving singular value calculations t…
Jul 28, 2017
062f8c1
Update rulefit.GBMfit
Weekend-Warrior Jul 28, 2017
9dc6530
Add SQL support for score output.
Jul 28, 2017
acb0751
Merge branch 'master' of https://github.com/StewartBobbitt/rulefit-1
Jul 28, 2017
e42f250
Merge branch 'master' of https://github.com/StewartBobbitt/rulefit-1
Jul 28, 2017
c75bd87
Merge branch 'master' of https://github.com/StewartBobbitt/rulefit-1
Jul 28, 2017
ddfaee8
Add winsors to numerical variables.
Weekend-Warrior Jul 29, 2017
d7cf7ba
Remove that first constant column from the nodes matrix.
Weekend-Warrior Jul 29, 2017
2ffe413
Change the scoping for interactions.
Weekend-Warrior Jul 29, 2017
d9dd5d0
Remove curvilinear effects.
Weekend-Warrior Jul 30, 2017
0244801
Carry parameters to bagged versions.
Jul 31, 2017
33e61b4
Edit the output of the train object.
Jul 31, 2017
ab12230
Fix summary indexing.
Jul 31, 2017
4dc29d3
Adopt Zelazny's classing of train object.
Jul 31, 2017
095dd62
Adopt Zelazny's thresholding for rules.
Jul 31, 2017
6777165
Add singularity option.
Jul 31, 2017
d73c7a8
Provide center and scaling options.
Jul 31, 2017
13de780
Adopt Zelazny's variable importance.
Jul 31, 2017
8348a9b
Update default behavior.
Aug 2, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
# Generated by roxygen2: do not edit by hand

S3method(predict,rulefit)
S3method(importance,rulefitFit)
S3method(plot,rulefitVarImp)
S3method(predict,rulefitFit)
S3method(print,rule)
S3method(print,rulefit)
S3method(rlang,default)
S3method(rlang,rule)
S3method(rlang,statement_factor)
S3method(rlang,statement_missing)
S3method(rlang,statement_numeric)
S3method(rlang,statement_ordered)
S3method(rulefit,GBMFit)
S3method(rulefit,gbm)
S3method(sas,default)
S3method(sas,rule)
S3method(sas,rulefitFit)
S3method(sas,statement_factor)
S3method(sas,statement_missing)
S3method(sas,statement_numeric)
S3method(sas,statement_ordered)
S3method(summary,rulefit)
S3method(sql,default)
S3method(sql,rule)
S3method(sql,rulefitFit)
S3method(sql,statement_factor)
S3method(sql,statement_missing)
S3method(sql,statement_numeric)
S3method(sql,statement_ordered)
S3method(summary,rulefitFit)
S3method(toString,rule)
S3method(toString,statement_factor)
S3method(toString,statement_missing)
S3method(toString,statement_numeric)
S3method(toString,statement_ordered)
S3method(train,rulefit)
export(importance)
export(rlang)
export(rulefit)
export(sas)
export(sql)
export(train)
46 changes: 46 additions & 0 deletions R/importance.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

## Variable level importance
#' @export
importance <- function(x, ...) UseMethod("importance")

#' Variable Importance
#' @param x rulefitFit object as output by the \code{\link{train}} function
#' @param newx development dataset usede to train the rulefit model
#' @param s lambda penalty term used to retrieve coefficients
#' @export
importance.rulefitFit <- function(x, newx, s=c("lambda.1se", "lambda.min"), ...) {
s <- match.arg(s)

su <- summary(x, newx, s=s, dedup=FALSE)
su <- su[order(su$node),]

vars_per_rule <- lapply(x$rules[su$node], function(r) sapply(r, '[[', "name"))

rules_per_var <- list()
for (i in seq_along(vars_per_rule)) {
if (length(vars_per_rule[[i]]) > 0) {
idx <- vars_per_rule[[i]]
rules_per_var[idx] <- lapply(rules_per_var[idx], function(x) c(x, i))
}
}

nodes <- predict(x, newx, s=s, nodes=TRUE)

## importance at X
m <- as.matrix(nodes)
Ik = sweep(abs(t(t(m) - su$support)), MARGIN = 2, abs(su$coefficient), FUN = `*`)

## divide each rule Ik by the number of vars in each rule
imp <- colSums(Ik)/lengths(vars_per_rule)
imp <- sapply(rules_per_var, function(r) sum(imp[r]))

structure(sort(imp/max(imp), decreasing = TRUE), class="rulefitVarImp")
}

#' @export
plot.rulefitVarImp <- function(x, y, ...) {
f <- colorRampPalette(c("lightblue", "blue"))
barplot(rev(x), horiz=TRUE, ylab = "Variable", xlab = "Relative Importance",
col = f(length(x)))
title("Variable Importance")
}
124 changes: 115 additions & 9 deletions R/pretty.R
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
#' @export
sas <- function(r) UseMethod("sas")
sas <- function(x, ...) UseMethod("sas")

#' @export
sas.statement_numeric <- function(l) {
fmt <- if (l$dir == -1) "(.z < %s < %s)" else "(%s >= %s)"
sprintf(fmt, l$name, l$value)
sas.default <- function(x, ...) NULL

#' @export
sas.statement_numeric <- function(x, ...) {
fmt <- if (x$dir == -1) "(.z < %s < %s)" else "(%s >= %s)"
sprintf(fmt, x$name, x$value)
}

#' @export
sas.statement_factor <- function(l) {
sprintf("(%s in (\"%s\"))", l$name, paste(l$value, collapse="\",\""))
sas.statement_factor <- function(x, ...) {
sprintf("(%s in (\"%s\"))", x$name, paste(x$value, collapse="\",\""))
}

#' @export
sas.statement_ordered <- function(l) sas.node_factor(l)
sas.statement_ordered <- function(x, ...) sas.node_factor(x, ...)

#' @export
sas.statement_missing <- function(l) sprintf("(missing(%s))", l$name)
sas.statement_missing <- function(x, ...) sprintf("(missing(%s))", x$name)

#' @export
sas.rule <- function(x, ...) {
paste0(lapply(x, sas), collapse = " AND ")
}

#' @export
toString.statement_factor <- function(x, ...) {
Expand Down Expand Up @@ -48,4 +56,102 @@ toString.rule <- function(x, ...) {
#' @export
print.rule <- function(x, ...) {
print(toString(x))
}
}

#' @export
rlang <- function(x) UseMethod("rlang")

#' @export
rlang.statement_numeric <- function(x) {
fmt <- if (x$dir == -1) "(%s < %s & !is.na(%s))" else "(%s >= %s & !is.na(%s))"
sprintf(fmt, x$name, x$value, x$name)
}

#' @export
rlang.statement_factor <- function(x) {
sprintf("(%s %%in%% c(\"%s\"))", x$name, paste(x$value, collapse="\",\""))
}

#' @export
rlang.statement_ordered <- function(x) rlang.node_factor(x)

#' @export
rlang.statement_missing <- function(x) sprintf("(is.na(%s))", x$name)

#' @export
rlang.default <- function(x) NULL

#' @export
rlang.rule <- function(x) {
paste0(lapply(x, rlang), collapse = " & ")
}

### SAS model
#' @export
sas.rulefitFit <- function(x, s=c("lambda.1se", "lambda.min"), pfx="rf", ...) {
s <- match.arg(s)
cf <- coef(x$fit, s)[,1]
rules <- x$rules[which(cf[-1] != 0)]

nm <- sprintf("%s_rule%03d", pfx, seq_along(rules))

code <- c(
c("/* Rule Definitions */"),
sprintf("%s = %s;", nm, sapply(rules, sas)),
c("\n/* Model Equation */"),
sprintf("%s_rulefit_mod = %3.6f", pfx, cf[1]),
sprintf(" + % 3.6f * %s", cf[cf != 0][-1], nm), ";")

code
}

### SQL support

#' @export
sql <- function(x, ...) UseMethod("sql")

#' @export
sql.default <- function(x, ...) NULL

#' @export
sql.statement_numeric <- function(x, ...) {
fmt <- if (x$dir == -1) "(%s < %s)" else "(%s >= %s)"
sprintf(fmt, x$name, x$value)
}

#' @export
sql.statement_factor <- function(x, ...) {
sprintf("(%s in (\'%s\'))", x$name, paste(x$value, collapse="\',\'"))
}

#' @export
sql.statement_ordered <- function(x, ...) sql.node_factor(x, ...)

#' @export
sql.statement_missing <- function(x, ...) sprintf("(%s is NULL)", x$name)

#' @export
sql.rule <- function(x, ...) {
paste0(lapply(x, sql), collapse = " AND ")
}

#' @export
sql.rulefitFit <- function(x, s=c("lambda.1se", "lambda.min"), pfx="rf", ...) {
s <- match.arg(s)
cf <- coef(x$fit, s)[,1]
rules <- x$rules[which(cf[-1] != 0)]

nm <- sprintf("%s_rule%03d", pfx, seq_along(rules))
len <- length(nm)

code <- c(
c("/* Rule Definitions */"),
sprintf("CASE WHEN %s THEN 1 ELSE 0 END as %s,", sapply(rules[-len], sql), nm[-len]),
sprintf("CASE WHEN %s THEN 1 ELSE 0 END as %s", sapply(rules[len], sql), nm[len]),
c("/* Model Equation */"),
sprintf(" % 3.6f", cf[1]),
sprintf(" + % 3.6f * %s", cf[cf != 0][-1], nm),
sprintf(" as %s_rulefit_mod", pfx))

code
}
Loading