Title: | Oblique Classification Trees with Uncorrelated Linear Discriminant Analysis Splits |
---|---|
Description: | A classification tree method that uses Uncorrelated Linear Discriminant Analysis (ULDA) for variable selection, split determination, and model fitting in terminal nodes. It automatically handles missing values and offers visualization tools. For more details, see Wang (2024) <doi:10.48550/arXiv.2410.23147>. |
Authors: | Siyu Wang [aut, cre, cph] |
Maintainer: | Siyu Wang <[email protected]> |
License: | MIT + file LICENSE |
Version: | 0.2.0.9000 |
Built: | 2025-03-02 05:50:04 UTC |
Source: | https://github.com/moran79/ldatree |
This function visualizes either the entire decision tree or a specific node
within the tree. The tree is displayed as an interactive network of nodes and
edges, while an individual node is shown as a scatter or density plot using ggplot2.
## S3 method for class 'Treee'
plot(x, datX, response, node = -1, ...)
x |
A fitted model object of class Treee. |
datX |
A data frame of predictor variables. Required for plotting individual nodes. |
response |
A vector of response values. Required for plotting individual nodes. |
node |
An integer specifying the node to plot. If node = -1 (the default), the entire tree is plotted. |
... |
Additional arguments passed to the plotting functions. |
A visNetwork interactive plot of the decision tree if node = -1, or a ggplot2 object if a specific node is plotted.
A full tree diagram is displayed using visNetwork when node is not
specified (the default is -1). The color represents the most common
(plurality) class within each node, and the size of each terminal node
reflects its relative sample size. Below each node, the fraction of
correctly predicted training samples and the total sample size for that
node are shown, along with the node index. Clicking on a node opens an
information panel with additional details.
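The node annotations described above are simple summaries of the training responses that reach a node. The sketch below is a hypothetical base-R helper (`nodeSummary` is not part of ldatree) that computes the plurality class used for the node color, the node sample size, and the fraction of correctly predicted samples, assuming the predictions for those samples are already available.

```r
# Hypothetical helper (not part of ldatree): summarise the training samples in one node
nodeSummary <- function(observed, predicted) {
  counts <- table(observed)                          # class frequencies in the node
  list(
    plurality   = names(counts)[which.max(counts)],  # most common class (node color)
    n           = length(observed),                  # node sample size
    fracCorrect = mean(as.character(observed) == as.character(predicted))
  )
}

obs  <- factor(c("setosa", "setosa", "versicolor", "setosa"))
pred <- c("setosa", "setosa", "setosa", "setosa")
s <- nodeSummary(obs, pred)
s$plurality    # "setosa"
s$n            # 4
s$fracCorrect  # 0.75
```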
To plot a specific node, you must provide the node index along with the
original training predictors (datX) and responses (response). Samples are
projected onto their linear discriminant score(s): a scatter plot is generated
if more than one discriminant score is available; otherwise, a density plot
is created.
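The sketch below illustrates the projection and the scatter-versus-density rule. It substitutes classical LDA from the MASS package (`MASS::lda`) for the ULDA model that Treee fits at each node, and base graphics for ggplot2, so the scores and appearance will differ from what `plot()` produces; only the rule for choosing the plot type is the point here.

```r
library(MASS)  # recommended package shipped with R; classical LDA stands in for ULDA

fitLDA <- lda(Species ~ ., data = iris)
scores <- predict(fitLDA, iris)$x  # one column per discriminant score (LD1, LD2, ...)

# Rule described above: scatter plot for two or more scores, density plot otherwise
plotType <- if (ncol(scores) > 1) "scatter" else "density"

if (plotType == "scatter") {
  plot(scores[, 1], scores[, 2], col = as.integer(iris$Species),
       xlab = "LD1", ylab = "LD2")
} else {
  plot(density(scores[, 1]), main = "LD1")
}
```

With three classes there are two discriminant scores, so this example takes the scatter-plot branch.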
fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
plot(fit) # plot the overall tree
plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a specific node
Generate predictions on new data using a fitted Treee model.
## S3 method for class 'Treee'
predict(object, newdata, type = c("response", "prob", "all"), ...)
object |
A fitted model object of class Treee. |
newdata |
A data frame containing the predictor variables. Missing values are allowed and will be handled according to the fitted tree's method for handling missing data. |
type |
A character string specifying the type of prediction to return. Options are "response" (the default), "prob", and "all". |
... |
Additional arguments passed to or from other methods. |
Depending on the value of type, the function returns:
If type = 'response': A character vector of predicted class labels.
If type = 'prob': A data frame of posterior probabilities, where each class has its own column.
If type = 'all': A data frame containing predicted class labels, posterior probabilities, and the predicted node indices.
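Under the simplifying assumption of equal misclassification costs, a predicted label corresponds to the column with the largest posterior probability in the 'prob' output. The sketch below uses a made-up probability data frame (not real Treee output) to show that relationship; when a non-default misClassCost is supplied, the fitted tree may choose classes differently.

```r
# Made-up posterior probabilities in the layout of type = "prob": one column per class
prob <- data.frame(
  setosa     = c(0.90, 0.05, 0.10),
  versicolor = c(0.08, 0.70, 0.30),
  virginica  = c(0.02, 0.25, 0.60)
)

# Assumption: equal misclassification costs, so the predicted label is the arg-max
predLabels <- colnames(prob)[max.col(as.matrix(prob), ties.method = "first")]
predLabels  # "setosa" "versicolor" "virginica"
```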
Note: For factor predictors, if a level not present in the training data is
found in newdata, it will be treated as missing and handled according to
the missingMethod specified in the fitted tree.
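The rule for unseen factor levels can be mimicked in base R by re-encoding a new factor against the training levels: any value outside those levels becomes NA, at which point missing-value handling takes over. The helper name `alignLevels` is hypothetical and only illustrates the idea.

```r
# Hypothetical helper: values whose level was not seen in training become NA
alignLevels <- function(newValues, trainLevels) {
  factor(as.character(newValues), levels = trainLevels)
}

trainLevels <- c("low", "medium", "high")
newValues   <- factor(c("low", "extreme", "high", NA))
aligned <- alignLevels(newValues, trainLevels)
aligned  # values: low, NA, high, NA (levels: low, medium, high)
```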
fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
head(predict(fit, iris[, -5])) # Predicted classes
head(predict(fit, iris[, -5], type = "prob")) # Posterior probabilities
head(predict(fit, iris[, -5], type = "all")) # Full details
This function fits a classification tree in which each node uses an Uncorrelated Linear Discriminant Analysis (ULDA) model. It can also handle missing values and perform downsampling. The resulting tree can be pruned by either pre-pruning or post-pruning.
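The default missingMethod = c("medianFlag", "newLevel") in the signature suggests one strategy for numeric predictors and one for factors. The sketch below is a guess based only on those names, not the package's code: numeric NAs are replaced by the column median and recorded in an added 0/1 flag column, and factor NAs become an explicit level. The helper `imputeSketch` and the `_missingFlag` column suffix are hypothetical.

```r
# Sketch only: median imputation with a missingness flag (numeric) and an
# explicit NA level (factor). Names are hypothetical, not ldatree internals.
imputeSketch <- function(df) {
  out <- df
  for (col in names(df)) {
    x <- df[[col]]
    if (is.numeric(x) && anyNA(x)) {
      out[[paste0(col, "_missingFlag")]] <- as.integer(is.na(x))  # 1 = was missing
      x[is.na(x)] <- median(x, na.rm = TRUE)                      # median fill
      out[[col]] <- x
    } else if (is.factor(x) && anyNA(x)) {
      out[[col]] <- addNA(x)                                      # NA as its own level
    }
  }
  out
}

df  <- data.frame(len = c(1.0, NA, 3.0, 10.0),
                  grp = factor(c("a", NA, "b", "a")))
res <- imputeSketch(df)
res$len              # 1 3 3 10
res$len_missingFlag  # 0 1 0 0
```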
Treee(
  datX,
  response,
  ldaType = c("forward", "all"),
  nodeModel = c("ULDA", "mode"),
  pruneMethod = c("pre", "post"),
  numberOfPruning = 10L,
  maxTreeLevel = 20L,
  minNodeSize = NULL,
  pThreshold = NULL,
  prior = NULL,
  misClassCost = NULL,
  missingMethod = c("medianFlag", "newLevel"),
  kSample = -1,
  verbose = TRUE
)
datX |
A data frame of predictor variables. |
response |
A vector of response values corresponding to datX. |
ldaType |
A character string specifying the type of LDA to use. Options
are "forward" (forward variable selection, the default) and "all" (all variables are used). |
nodeModel |
A character string specifying the type of model used in each
node. Options are "ULDA" (the default) and "mode", which predicts the plurality class. |
pruneMethod |
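A character string specifying the pruning method: "pre" (the default) stops tree growth early based on p-value thresholds, and "post" prunes the grown tree using cross-validation. |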
numberOfPruning |
An integer specifying the number of folds for
cross-validation during post-pruning. Default is 10. |
maxTreeLevel |
An integer controlling the maximum depth of the tree.
Increasing this value allows for deeper trees with more nodes. Default is 20. |
minNodeSize |
An integer controlling the minimum number of samples required in a node. Setting a higher value may lead to earlier stopping and smaller trees. If not specified, it defaults to one plus the number of response classes. |
pThreshold |
A numeric value used as a threshold for pre-pruning based
on p-values. Lower values result in more conservative trees. If not
specified, defaults to |
prior |
A numeric vector of prior probabilities for each class. If
|
misClassCost |
A square matrix |
missingMethod |
A character string specifying how missing values should
be handled. Options include |
kSample |
An integer specifying the number of samples to use for
downsampling during tree construction. Set to -1 (the default) to disable downsampling. |
verbose |
A logical value. If TRUE (the default), progress messages are printed during tree construction. |
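The "forward" option of ldaType refers to forward variable selection, and the Wang (2024) reference on forward discriminant analysis bases that framework on Pillai's trace. The sketch below is a simplified greedy version under that assumption: at each step it adds the predictor that yields the largest Pillai's trace, with no stopping test and no ULDA fit, so it is not the package's algorithm. `pillaiTrace` and `forwardSelect` are hypothetical names.

```r
# Pillai's trace V = tr(H T^{-1}), where T = H + E is the total SSCP matrix
pillaiTrace <- function(X, g) {
  X <- as.matrix(X)
  g <- as.factor(g)
  total <- crossprod(scale(X, center = TRUE, scale = FALSE))          # T
  within <- Reduce(`+`, lapply(split(seq_len(nrow(X)), g), function(idx) {
    crossprod(scale(X[idx, , drop = FALSE], center = TRUE, scale = FALSE))
  }))                                                                 # E
  between <- total - within                                           # H
  sum(diag(between %*% solve(total)))
}

# Greedy forward selection: repeatedly add the predictor maximising Pillai's trace
forwardSelect <- function(X, g, maxVars = ncol(X)) {
  selected  <- character(0)
  remaining <- colnames(X)
  while (length(selected) < maxVars && length(remaining) > 0) {
    gain <- vapply(remaining, function(v)
      pillaiTrace(X[, c(selected, v), drop = FALSE], g), numeric(1))
    best      <- remaining[which.max(gain)]
    selected  <- c(selected, best)
    remaining <- setdiff(remaining, best)
  }
  selected
}

sel <- forwardSelect(iris[, -5], iris$Species, maxVars = 2)
sel
```

A real procedure would also decide when to stop adding variables (for example with a significance test), which this sketch leaves to the caller through maxVars.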
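Two of the size-related arguments can be illustrated with small hypothetical helpers. `defaultMinNodeSize` follows the documented minNodeSize default (one plus the number of response classes). `downsampleRows` is an assumption about kSample: it treats -1 (or any value of at least n) as "use all rows" and otherwise draws a simple random sample of rows without replacement; the package may instead sample differently, for example by class.

```r
# Documented minNodeSize default: one plus the number of response classes
defaultMinNodeSize <- function(response) {
  length(unique(response[!is.na(response)])) + 1L
}

# Assumption: kSample = -1 (or >= n) keeps every row; otherwise draw
# kSample row indices by simple random sampling without replacement
downsampleRows <- function(n, kSample = -1L) {
  if (kSample < 0 || kSample >= n) return(seq_len(n))
  sort(sample.int(n, kSample))
}

defaultMinNodeSize(iris$Species)   # 4
set.seed(1)
rows <- downsampleRows(nrow(iris), 50)
length(rows)                       # 50
```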
An object of class Treee containing the fitted tree, which is a
list of nodes, each an object of class TreeeNode. Each TreeeNode
contains:
currentIndex: The node index in the tree.
currentLevel: The depth of the current node in the tree.
idxRow, idxCol: Row and column indices indicating which part of the original data was used for this node.
currentLoss: The training error for this node.
accuracy: The training accuracy for this node.
stopInfo: Information on why the node stopped growing.
proportions: The observed frequency of each class in this node.
prior: The (adjusted) class prior probabilities used for ULDA or mode prediction.
misClassCost: The misclassification cost matrix used in this node.
parent: The index of the parent node.
children: A vector of indices of this node’s direct children.
splitFun: The splitting function used for this node.
nodeModel: Indicates the model fitted at the node ('ULDA' or 'mode').
nodePredict: The fitted model at the node, either a ULDA object or the plurality class.
alpha: The p-value from a two-sample t-test used to evaluate the strength of the split.
childrenTerminal: A vector of indices representing the terminal nodes that are descendants of this node.
childrenTerminalLoss: The total training error accumulated from all nodes listed in childrenTerminal.
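Because children, childrenTerminal, and childrenTerminalLoss are defined in terms of one another, a toy tree makes the relationships concrete. The sketch below builds a hypothetical tree as a plain list of nodes carrying only currentIndex, children, and currentLoss (a stand-in, not real TreeeNode objects) and derives each node's terminal descendants and their accumulated loss recursively. Treating a terminal node as its own terminal set is an assumption.

```r
# Toy tree: node 1 splits into 2 and 3; node 3 splits into 4 and 5
tree <- list(
  list(currentIndex = 1, children = c(2, 3),     currentLoss = 30),
  list(currentIndex = 2, children = integer(0),  currentLoss = 2),
  list(currentIndex = 3, children = c(4, 5),     currentLoss = 12),
  list(currentIndex = 4, children = integer(0),  currentLoss = 3),
  list(currentIndex = 5, children = integer(0),  currentLoss = 4)
)

# Terminal descendants of node i (a leaf is its own terminal set, by assumption)
terminalOf <- function(tree, i) {
  kids <- tree[[i]]$children
  if (length(kids) == 0) return(i)
  unlist(lapply(kids, function(k) terminalOf(tree, k)))
}

# Training loss summed over those terminal nodes
terminalLoss <- function(tree, i) {
  sum(vapply(terminalOf(tree, i), function(k) tree[[k]]$currentLoss, numeric(1)))
}

terminalOf(tree, 1)    # 2 4 5
terminalLoss(tree, 1)  # 9
terminalLoss(tree, 3)  # 7
```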
Wang, S. (2024). FoLDTree: A ULDA-Based Decision Tree Framework for Efficient Oblique Splits and Feature Selection. arXiv preprint arXiv:2410.23147. Available at https://arxiv.org/abs/2410.23147.
Wang, S. (2024). A New Forward Discriminant Analysis Framework Based On Pillai's Trace and ULDA. arXiv preprint arXiv:2409.03136. Available at https://arxiv.org/abs/2409.03136.
fit <- Treee(datX = iris[, -5], response = iris[, 5], verbose = FALSE)
# Use cross-validation to prune the tree
fitCV <- Treee(datX = iris[, -5], response = iris[, 5], pruneMethod = "post", verbose = FALSE)
head(predict(fit, iris[, -5])) # prediction
plot(fit) # plot the overall tree
plot(fit, datX = iris[, -5], response = iris[, 5], node = 1) # plot a certain node
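With pruneMethod = "post", the tree is pruned by cross-validation, and numberOfPruning sets the number of folds. The sketch below shows one common way to assign observations to folds, stratified by class so each fold keeps roughly the overall class mix. Whether Treee stratifies its folds is an assumption, and `assignFolds` is a hypothetical name.

```r
# Assign each observation to one of k folds, stratified by class
assignFolds <- function(response, k = 10L) {
  folds <- integer(length(response))
  for (idx in split(seq_along(response), response)) {
    # deal fold labels round-robin within the class, then shuffle them
    folds[idx] <- sample(rep_len(seq_len(k), length(idx)))
  }
  folds
}

set.seed(42)
folds <- assignFolds(iris$Species, k = 10L)
table(folds)                 # 15 observations in every fold
table(folds, iris$Species)   # 5 of each species in every fold
```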