Supervised Learning Classification: Decision Tree

2022/06/07 18:02

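The decision tree is a non-parametric supervised learning method that can be used for both classification and regression. Its basic idea is to pass unseen data through a series of if/else branches and predict its class from the branch it ends up in. The features and thresholds tested in those branches are learned from training samples. A decision tree can also be viewed as a piecewise-constant approximation of a function.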
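
To make the if/else picture concrete, here is a minimal hand-written sketch of a two-level tree. The feature names, thresholds and class labels are invented for illustration and are not learned from any data; in a real tree they would come out of training.

```python
def classify(petal_length, petal_width):
    """A hand-written two-level decision tree (all thresholds are made up)."""
    if petal_length <= 2.5:
        # Every point in this region gets the same constant answer.
        return "A"
    else:
        if petal_width <= 1.7:
            return "B"
        else:
            return "C"


# Hypothetical samples given as (petal_length, petal_width)
samples = [(1.4, 0.2), (4.5, 1.5), (6.0, 2.3)]
predictions = [classify(length, width) for length, width in samples]
print(predictions)  # ['A', 'B', 'C']
```

Each leaf returns one constant for its whole region of the feature space, which is why the prediction is piecewise constant.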

Decision trees have the following advantages:

  • Easy to understand and interpret; the whole tree can be visualized.
  • Requires little data preparation; in particular, the data do not need to be normalized.
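  • The cost of predicting a sample is logarithmic in the number of samples used to train the tree.
  • Can handle multi-output problems.
  • The model can be validated with statistical tests, which makes it possible to assess its reliability.
  • Performs well even when the data somewhat violate the model's assumptions.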

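They also have the following disadvantages:

  • May produce overly complex trees that overfit, so that predictions do not generalize. Pruning, requiring a minimum number of samples per leaf, or limiting the maximum depth of the tree helps avoid this.
  • Can be unstable: a small change in the data may produce a completely different tree. Ensemble training reduces the impact of this.
  • Predictions are neither smooth nor continuous but piecewise-constant approximations, so trees are not suited to extrapolation.
  • Practical tree construction usually relies on heuristics such as greedy algorithms that look for a local optimum at each node, so a globally optimal tree cannot be guaranteed. Ensemble training mitigates this as well.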
  • Some concepts are hard for a decision tree to express, such as XOR, parity, or multiplexer problems.
  • If some classes dominate the data set, the trained tree may be biased toward them and therefore inaccurate.
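
The XOR difficulty and the limits of greedy search can be seen in a few lines. The sketch below is a simplified illustration, not the scikit-learn code: it assumes binary labels and unweighted samples, and the helper names `gini` and `split_improvement` are hypothetical. On the four XOR points, no single axis-aligned split reduces the Gini impurity, so a greedy builder sees no gain at the root even though a depth-2 tree separates the classes perfectly.

```python
def gini(labels):
    """Gini impurity of a list of binary (0/1) labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    p1 = sum(labels) / n
    return 1.0 - p1 * p1 - (1.0 - p1) * (1.0 - p1)


def split_improvement(points, labels, feature, threshold):
    """Impurity decrease obtained by splitting on points[feature] <= threshold."""
    left = [y for x, y in zip(points, labels) if x[feature] <= threshold]
    right = [y for x, y in zip(points, labels) if x[feature] > threshold]
    n = len(labels)
    weighted = len(left) / n * gini(left) + len(right) / n * gini(right)
    return gini(labels) - weighted


points = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [a ^ b for a, b in points]  # XOR of the two features

# Try the only meaningful threshold on each of the two features.
improvements = [split_improvement(points, labels, f, 0.5) for f in (0, 1)]
print(improvements)  # [0.0, 0.0]
```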

 

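Compared with the two supervised classification methods covered earlier, Gaussian Naive Bayes and Nearest Centroid, decision trees are considerably more complex, because they involve a more elaborate data structure, the tree, and its associated algorithms. The decision-tree source code from scikit-learn is given below. scikit-learn implements it in Cython for speed, which makes the code less readable and impossible to run directly, so it has been modified here to use pure Python throughout.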

Utils.py

SIZE_MAX = 2**32 - 1

# =============================================================================
# Stack data structure
# =============================================================================

# A record on the stack for depth-first tree growing
class StackRecord:
    start = 0
    end = 0
    depth = 0
    parent = 0
    is_left = 1
    impurity = 1
    n_constant_features = 0


# =============================================================================
# PriorityHeap data structure
# =============================================================================

# A record on the frontier for best-first tree growing
class PriorityHeapRecord:
    node_id = 0
    start = 0
    end = 0
    pos = 0
    depth = 0
    is_leaf = 1
    impurity = 1
    impurity_left = 1
    impurity_right = 1
    improvement = 1

# =============================================================================
# Stack data structure
# =============================================================================

class Stack:
    """A LIFO data structure.
    Attributes
    ----------
    capacity : SIZE_t
        The elements the stack can hold; if more added then ``self.stack_``
        needs to be resized.
    top : SIZE_t
        The number of elements currently on the stack.
    stack : StackRecord pointer
        The stack of records (upward in the stack corresponds to the right).
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.top = 0
        self.stack_ = []
        for i in range(capacity):
            self.stack_.append(StackRecord())

    def is_empty(self):
        return self.top <= 0

    def push(self, start, end, depth, parent,
                  is_left, impurity,
                  n_constant_features):
        """Push a new element onto the stack.
        Return -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        top = self.top

        # Resize if capacity not sufficient
        if top >= self.capacity:
            for idx in range(self.capacity):
                self.stack_.append(StackRecord())
            self.capacity *= 2

        stack = self.stack_[top]
        stack.start = start
        stack.end = end
        stack.depth = depth
        stack.parent = parent
        stack.is_left = is_left
        stack.impurity = impurity
        stack.n_constant_features = n_constant_features

        # Increment stack pointer
        self.top = top + 1
        return 0

    def pop(self):
        """Remove the top element from the stack and return it.
        Returns ``(0, record)`` if the pop was successful and ``(-1, None)``
        if the stack is empty.
        """
        top = self.top
        stack = self.stack_

        if top <= 0:
            return (-1, None)

        res = stack[top - 1]
        self.top = top - 1

        return (0, res)


# =============================================================================
# PriorityHeap data structure
# =============================================================================

class PriorityHeap:
    """A priority queue implemented as a binary heap.
    The heap invariant is that the impurity improvement of the parent record
    is larger than the impurity improvement of the children.
    Attributes
    ----------
    capacity : SIZE_t
        The capacity of the heap
    heap_ptr : SIZE_t
        The water mark of the heap; the heap grows from left to right in the
        array ``heap_``. The following invariant holds ``heap_ptr < capacity``.
    heap_ : PriorityHeapRecord*
        The array of heap records. The maximum element is on the left;
        the heap grows from left to right
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.heap_ptr = 0
        self.heap_ = []
        for idx in range(capacity):
            self.heap_.append(PriorityHeapRecord())

    def is_empty(self):
        return self.heap_ptr <= 0

    def heapify_up(self, heap, pos):
        """Restore heap invariant parent.improvement > child.improvement from
           ``pos`` upwards. """
        if pos == 0:
            return

        # Integer division: list indices must be ints in Python 3
        parent_pos = (pos - 1) // 2

        if heap[parent_pos].improvement < heap[pos].improvement:
            heap[parent_pos], heap[pos] = heap[pos], heap[parent_pos]
            self.heapify_up(heap, parent_pos)

    def heapify_down(self, heap, pos, heap_length):
        """Restore heap invariant parent.improvement > children.improvement from
           ``pos`` downwards. """
        left_pos = 2 * (pos + 1) - 1
        right_pos = 2 * (pos + 1)
        largest = pos

        if (left_pos < heap_length and
                heap[left_pos].improvement > heap[largest].improvement):
            largest = left_pos

        if (right_pos < heap_length and
                heap[right_pos].improvement > heap[largest].improvement):
            largest = right_pos

        if largest != pos:
            heap[pos], heap[largest] = heap[largest], heap[pos]
            self.heapify_down(heap, largest, heap_length)

    def push(self, node_id, start, end, pos,
                  depth, is_leaf, improvement,
                  impurity, impurity_left,
                  impurity_right):
        """Push record on the priority heap.
        Return -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        heap_ptr = self.heap_ptr

        # Resize if capacity not sufficient
        if heap_ptr >= self.capacity:
            for idx in range(self.capacity):
                self.heap_.append(PriorityHeapRecord())
            self.capacity *= 2

        # Fetch the slot only after resizing, so the index is always valid
        heap = self.heap_[heap_ptr]

        # Put element as last element of heap
        heap.node_id = node_id
        heap.start = start
        heap.end = end
        heap.pos = pos
        heap.depth = depth
        heap.is_leaf = is_leaf
        heap.impurity = impurity
        heap.impurity_left = impurity_left
        heap.impurity_right = impurity_right
        heap.improvement = improvement

        # Heapify up
        self.heapify_up(self.heap_, heap_ptr)

        # Increase element count
        self.heap_ptr = heap_ptr + 1
        return 0

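    def pop(self):
        """Remove max element from the heap.
        Returns ``(0, record)`` on success and ``(-1, None)`` if the heap
        is empty, mirroring ``Stack.pop``.
        """
        heap_ptr = self.heap_ptr
        heap = self.heap_

        if heap_ptr <= 0:
            return (-1, None)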

        # Take first element
        res = heap[0]

        # Put last element to the front
        heap[0], heap[heap_ptr - 1] = heap[heap_ptr - 1], heap[0]

        # Restore heap invariant
        if heap_ptr > 1:
            self.heapify_down(heap, 0, heap_ptr - 1)

        self.heap_ptr = heap_ptr - 1

        return (0, res)

class Node:
    # Base storage structure for the nodes in a Tree object
    left_child = 0                    # id of the left child of the node
    right_child = 0                   # id of the right child of the node
    feature = 0                       # Feature used for splitting the node
    threshold = 1                   # Threshold value at the node
    impurity = 1                    # Impurity of the node (i.e., the value of the criterion)
    n_node_samples = 0                # Number of samples at the node
    weighted_n_node_samples = 1     # Weighted number of samples at the node
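
The ordering that PriorityHeap maintains (largest improvement first) can also be sketched with the standard library's `heapq` module. This is an illustrative alternative, not what the code above does: `heapq` is a min-heap, so the improvement is negated to pop the largest one first, and the `(node_id, improvement)` pairs are made-up values.

```python
import heapq

# Hypothetical frontier candidates: (node_id, impurity improvement)
candidates = [(3, 0.10), (4, 0.42), (5, 0.27)]

frontier = []
for node_id, improvement in candidates:
    # Negate the key so the largest improvement sits at the top of the min-heap.
    heapq.heappush(frontier, (-improvement, node_id))

# Best-first growing expands the node with the largest improvement next.
expansion_order = []
while frontier:
    neg_improvement, node_id = heapq.heappop(frontier)
    expansion_order.append(node_id)

print(expansion_order)  # [4, 5, 3]
```

The Stack class plays the same role for depth-first growing, where the most recently pushed node is expanded next regardless of its improvement.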

 

Criterion.py

import numpy as np
from .Splitter import log


class Criterion:
    """Interface for impurity criteria.
    This object stores methods on how to calculate how good a split is using
    different metrics.
    """
    def proxy_impurity_improvement(self):
        """Compute a proxy of the impurity reduction
        This method is used to speed up the search for the best split.
        It is a proxy quantity such that the split that maximizes this value
        also maximizes the impurity improvement. It neglects all constant terms
        of the impurity decrease for a given split.
        The absolute impurity improvement is only computed by the
        impurity_improvement method once the best split has been found.
        """
        (impurity_left, impurity_right) = self.children_impurity()

        return (- self.weighted_n_right * impurity_right
                - self.weighted_n_left * impurity_left)

    def impurity_improvement(self, impurity):
        """Compute the improvement in impurity
        This method computes the improvement in impurity when a split occurs.
        The weighted impurity improvement equation is the following:
            N_t / N * (impurity - N_t_R / N_t * right_impurity
                                - N_t_L / N_t * left_impurity)
        where N is the total number of samples, N_t is the number of samples
        at the current node, N_t_L is the number of samples in the left child,
        and N_t_R is the number of samples in the right child,
        Parameters
        ----------
        impurity : double
            The initial impurity of the node before the split
        Return
        ------
        double : improvement in impurity after the split occurs
        """
        (impurity_left, impurity_right) = self.children_impurity()

        return ((self.weighted_n_node_samples / self.weighted_n_samples) *
                (impurity - (self.weighted_n_right / 
                             self.weighted_n_node_samples * impurity_right)
                          - (self.weighted_n_left / 
                             self.weighted_n_node_samples * impurity_left)))


class ClassificationCriterion(Criterion):
    """Abstract criterion for classification."""
    def __init__(self, n_outputs,
                  n_classes):
        """Initialize attributes for this criterion.
        Parameters
        ----------
        n_outputs : SIZE_t
            The number of targets, the dimensionality of the prediction
        n_classes : numpy.ndarray, dtype=SIZE_t
            The number of unique classes in each target
        """
        self.sample_weight = None

        self.samples = None
        self.start = 0
        self.pos = 0
        self.end = 0

        self.n_outputs = n_outputs
        self.n_samples = 0
        self.n_node_samples = 0
        self.weighted_n_node_samples = 0.0
        self.weighted_n_left = 0.0
        self.weighted_n_right = 0.0

        # Count labels for each output
        self.n_classes = []
        sum_stride = 0

        # For each target, set the number of unique classes in that target,
        # and also compute the maximal stride of all targets
        for k in range(n_outputs):
            self.n_classes.append(n_classes[k])

            if n_classes[k] > sum_stride:
                sum_stride = n_classes[k]

        self.sum_stride = sum_stride

        n_elements = n_outputs * sum_stride
        self.sum_total = np.zeros(n_elements)
        self.sum_left = np.zeros(n_elements)
        self.sum_right = np.zeros(n_elements)


    def init(self, y,
                  sample_weight, weighted_n_samples,
                  samples, start, end):
        """Initialize the criterion at node samples[start:end] and
        children samples[start:start] and samples[start:end].
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        Parameters
        ----------
        y : array-like, dtype=DOUBLE_t
            The target stored as a buffer for memory efficiency
        sample_weight : array-like, dtype=DOUBLE_t
            The weight of each sample
        weighted_n_samples : double
            The total weight of all samples
        samples : array-like, dtype=SIZE_t
            A mask on the samples, showing which ones we want to use
        start : SIZE_t
            The first sample to use in the mask
        end : SIZE_t
            The last sample to use in the mask
        """
        self.y = y
        self.sample_weight = sample_weight
        self.samples = samples
        self.start = start
        self.end = end
        self.n_node_samples = end - start
        self.weighted_n_samples = weighted_n_samples
        self.weighted_n_node_samples = 0.0

        w = 1.0
        offset = 0

        for k in range(self.n_outputs):
            for j in range(self.n_classes[k]):
                self.sum_total[offset + j] = 0
            offset += self.sum_stride

        for p in range(start, end):
            i = samples[p]

            # w is originally set to be 1.0, meaning that if no sample weights
            # are given, the default weight of each sample is 1.0
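            # Use an identity test: "!= None" on a NumPy array compares
            # element-wise and cannot be used as a condition
            if sample_weight is not None: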
                w = sample_weight[i]

            # Count weighted class frequency for each target
            for k in range(self.n_outputs):
                c = int(self.y[i, k])
                self.sum_total[k * self.sum_stride + c] += w

            self.weighted_n_node_samples += w

        # Reset to pos=start
        self.reset()
        return 0

    def reset(self):
        """Reset the criterion at pos=start
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        self.pos = self.start

        self.weighted_n_left = 0.0
        self.weighted_n_right = self.weighted_n_node_samples

        n_classes = self.n_classes

        for k in range(self.n_outputs):
            for j in range(n_classes[k]):
                idx = k*self.sum_stride + j
                self.sum_left[idx] = 0
                self.sum_right[idx] = self.sum_total[idx]

        return 0

    def reverse_reset(self):
        """Reset the criterion at pos=end
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        self.pos = self.end

        self.weighted_n_left = self.weighted_n_node_samples
        self.weighted_n_right = 0.0

        for k in range(self.n_outputs):
            for j in range(self.n_classes[k]):
                idx = k*self.sum_stride + j
                self.sum_right[idx] = 0
                self.sum_left[idx] = self.sum_total[idx]

        return 0

    def update(self, new_pos):
        """Updated statistics by moving samples[pos:new_pos] to the left child.
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        Parameters
        ----------
        new_pos : SIZE_t
            The new ending position for which to move samples from the right
            child to the left child.
        """
        pos = self.pos
        end = self.end
        w = 1.0

        # Update statistics up to new_pos
        #
        # Given that
        #   sum_left[x] +  sum_right[x] = sum_total[x]
        # and that sum_total is known, we are going to update
        # sum_left from the direction that require the least amount
        # of computations, i.e. from pos to new_pos or from end to new_pos.

        if (new_pos - pos) <= (end - new_pos):
            for p in range(pos, new_pos):
                i = self.samples[p]

                if self.sample_weight is not None:
                    w = self.sample_weight[i]

                for k in range(self.n_outputs):
                    label_index = k * self.sum_stride + int(self.y[i, k])
                    self.sum_left[label_index] += w

                self.weighted_n_left += w
        else:
            self.reverse_reset()

            for p in range(end - 1, new_pos - 1, -1):
                i = self.samples[p]

                if self.sample_weight is not None:
                    w = self.sample_weight[i]

                for k in range(self.n_outputs):
                    label_index = k * self.sum_stride + int(self.y[i, k])
                    self.sum_left[label_index] -= w

                self.weighted_n_left -= w

        # Update right part statistics
        self.weighted_n_right = self.weighted_n_node_samples - self.weighted_n_left
        for k in range(self.n_outputs):
            for c in range(self.n_classes[k]):
                idx = k*self.sum_stride + c
                self.sum_right[idx] = self.sum_total[idx] - self.sum_left[idx]

        self.pos = new_pos
        return 0

    def node_value(self):
        """Compute the node value of samples[start:end] and save it into dest.
        Parameters
        ----------
        dest : double pointer
            The memory address which we will save the node value into.
        """
        dest = []

        for k in range(self.n_outputs):
            for j in range(self.n_classes[k]):
                dest.append(self.sum_total[k*self.sum_stride+j])

        return dest


class Entropy(ClassificationCriterion):
    r"""Cross Entropy impurity criterion.
    This handles cases where the target is a classification taking values
    0, 1, ... K-2, K-1. If node m represents a region Rm with Nm observations,
    then let
        count_k = 1 / Nm \sum_{x_i in Rm} I(yi = k)
    be the proportion of class k observations in node m.
    The cross-entropy is then defined as
        cross-entropy = -\sum_{k=0}^{K-1} count_k log(count_k)
    """
    def node_impurity(self):
        """Evaluate the impurity of the current node, i.e. the impurity of
        samples[start:end], using the cross-entropy criterion."""
        entropy = 0.0

        for k in range(self.n_outputs):
            for c in range(self.n_classes[k]):
                count_k = self.sum_total[k*self.sum_stride+c]
                if count_k > 0.0:
                    count_k /= self.weighted_n_node_samples
                    entropy -= count_k * log(count_k)

        return entropy / self.n_outputs

    def children_impurity(self):
        """Evaluate the impurity in children nodes
        i.e. the impurity of the left child (samples[start:pos]) and the
        impurity the right child (samples[pos:end]).
        Parameters
        ----------
        impurity_left : double pointer
            The memory address to save the impurity of the left node
        impurity_right : double pointer
            The memory address to save the impurity of the right node
        """
        entropy_left = 0.0
        entropy_right = 0.0

        for k in range(self.n_outputs):
            for c in range(self.n_classes[k]):
                count_k = self.sum_left[k*self.sum_stride+c]
                if count_k > 0.0:
                    count_k /= self.weighted_n_left
                    entropy_left -= count_k * log(count_k)

                count_k = self.sum_right[k*self.sum_stride+c]
                if count_k > 0.0:
                    count_k /= self.weighted_n_right
                    entropy_right -= count_k * log(count_k)

        impurity_left = entropy_left / self.n_outputs
        impurity_right = entropy_right / self.n_outputs
        
        return (impurity_left, impurity_right)


class Gini(ClassificationCriterion):
    r"""Gini Index impurity criterion.
    This handles cases where the target is a classification taking values
    0, 1, ... K-2, K-1. If node m represents a region Rm with Nm observations,
    then let
        count_k = 1/ Nm \sum_{x_i in Rm} I(yi = k)
    be the proportion of class k observations in node m.
    The Gini Index is then defined as:
        index = \sum_{k=0}^{K-1} count_k (1 - count_k)
              = 1 - \sum_{k=0}^{K-1} count_k ** 2
    """
    def node_impurity(self):
        """Evaluate the impurity of the current node, i.e. the impurity of
        samples[start:end] using the Gini criterion."""
        gini = 0.0

        for k in range(self.n_outputs):
            sq_count = 0.0

            for c in range(self.n_classes[k]):
                count_k = self.sum_total[k*self.sum_stride+c]
                sq_count += count_k * count_k

            gini += 1.0 - sq_count / (self.weighted_n_node_samples *
                                      self.weighted_n_node_samples)

        return gini / self.n_outputs

    def children_impurity(self):
        """Evaluate the impurity in children nodes
        i.e. the impurity of the left child (samples[start:pos]) and the
        impurity the right child (samples[pos:end]) using the Gini index.
        Parameters
        ----------
        impurity_left : double pointer
            The memory address to save the impurity of the left node to
        impurity_right : double pointer
            The memory address to save the impurity of the right node to
        """
        gini_left = 0.0
        gini_right = 0.0

        for k in range(self.n_outputs):
            sq_count_left = 0.0
            sq_count_right = 0.0

            for c in range(self.n_classes[k]):
                count_k = self.sum_left[k*self.sum_stride+c]
                sq_count_left += count_k * count_k

                count_k = self.sum_right[k*self.sum_stride+c]
                sq_count_right += count_k * count_k

            gini_left += 1.0 - sq_count_left / (self.weighted_n_left *
                                                self.weighted_n_left)

            gini_right += 1.0 - sq_count_right / (self.weighted_n_right *
                                                  self.weighted_n_right)

        impurity_left = gini_left / self.n_outputs
        impurity_right = gini_right / self.n_outputs
        
        return (impurity_left, impurity_right)
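
As a worked check of the formulas used in the criteria above, the following standalone sketch computes the Gini index, the entropy and the weighted impurity improvement for one candidate split. It assumes a single output and unit sample weights, and it takes plain class-count lists instead of the sum_total / sum_left / sum_right buffers. The function names and the example counts are chosen for illustration only.

```python
import math


def gini(counts):
    """Gini index 1 - sum(p_k ** 2) from per-class counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def entropy(counts):
    """Entropy -sum(p_k * log2(p_k)) from per-class counts."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)


def impurity_improvement(parent, left, right, n_total, impurity_fn):
    """N_t / N * (impurity - N_t_R / N_t * right - N_t_L / N_t * left)."""
    n_t, n_l, n_r = sum(parent), sum(left), sum(right)
    return (n_t / n_total) * (impurity_fn(parent)
                              - n_r / n_t * impurity_fn(right)
                              - n_l / n_t * impurity_fn(left))


# Hypothetical root node with 5 samples per class, split into two children
parent, left, right = [5, 5], [4, 1], [1, 4]

gini_gain = impurity_improvement(parent, left, right, 10, gini)
entropy_gain = impurity_improvement(parent, left, right, 10, entropy)
print(round(gini_gain, 4), round(entropy_gain, 4))
```

Here the parent has Gini 0.5 and each child has Gini 0.32, so the Gini improvement is 0.5 - 0.5 * 0.32 - 0.5 * 0.32 = 0.18.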

 

Splitter.py

import numpy as np
import math

RAND_R_MAX = 0x7FFFFFFF
DEFAULT_SEED = 1

# rand_r replacement using a 32bit XorShift generator
# See http://www.jstatsoft.org/v08/i14/paper for details
def our_rand_r(seed):
    """Generate a pseudo-random np.uint32 from a np.uint32 seed"""
    # seed shouldn't ever be 0.
    if (seed == 0): 
        seed = DEFAULT_SEED

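    # Python ints are unbounded, so mask after each left shift to emulate
    # the 32-bit wraparound of the original uint32 arithmetic.
    # Note: the Cython original updates the seed in place through a pointer;
    # this version only returns a value and does not advance the caller's state.
    seed ^= (seed << 13) & 0xFFFFFFFF
    seed ^= seed >> 17
    seed ^= (seed << 5) & 0xFFFFFFFF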

    # Note: we must be careful with the final line cast to np.uint32 so that
    # the function behaves consistently across platforms.
    #
    # The following cast might yield different results on different platforms:
    # wrong_cast = <UINT32_t> RAND_R_MAX + 1
    #
    # We can use:
    # good_cast = <UINT32_t>(RAND_R_MAX + 1)
    # or:
    # cdef np.uint32_t another_good_cast = <UINT32_t>RAND_R_MAX + 1
    return seed % (RAND_R_MAX + 1)

def rand_int(low, high, random_state):
    """Generate a random integer in [low; high)."""
    return low + our_rand_r(random_state) % (high - low)


def rand_uniform(low, high, random_state):
    """Generate a random double in [low; high)."""
    return ((high - low) * our_rand_r(random_state) /
            RAND_R_MAX) + low

def log(x):
    return math.log(x) / math.log(2.0)

INFINITY = np.inf

# Mitigate precision differences between 32 bit and 64 bit
FEATURE_THRESHOLD = 1e-7

# Constant to switch between algorithm non zero value extract algorithm
# in SparseSplitter
EXTRACT_NNZ_SWITCH = 0.1

class SplitRecord:
    # Data to track sample split
    feature = 0         # Which feature to split on.
    pos = 0            # Split samples array at the given position,
                           # i.e. count of samples below threshold for feature.
                           # pos is >= end if the node is a leaf.
    threshold = 1.0       # Threshold to split at.
    improvement = 1.0     # Impurity improvement given parent node.
    impurity_left = 1.0   # Impurity of the left split.
    impurity_right = 1.0  # Impurity of the right split.
    
    def copy(self):
        tmp = SplitRecord()
        tmp.feature = self.feature
        tmp.pos = self.pos
        tmp.threshold = self.threshold
        tmp.improvement = self.improvement
        tmp.impurity_left = self.impurity_left
        tmp.impurity_right = self.impurity_right
        return tmp


def _init_split(self, start_pos):
    self.impurity_left = INFINITY
    self.impurity_right = INFINITY
    self.pos = start_pos
    self.feature = 0
    self.threshold = 0.
    self.improvement = -INFINITY

class Splitter:
    """Abstract splitter class.
    Splitters are called by tree builders to find the best splits on both
    sparse and dense data, one split at a time.
    """
    def __init__(self, criterion, max_features,
                  min_samples_leaf, min_weight_leaf,
                  random_state):
        """
        Parameters
        ----------
        criterion : Criterion
            The criterion to measure the quality of a split.
        max_features : SIZE_t
            The maximal number of randomly selected features which can be
            considered for a split.
        min_samples_leaf : SIZE_t
            The minimal number of samples each leaf can have, where splits
            which would result in having less samples in a leaf are not
            considered.
        min_weight_leaf : double
            The minimal weight each leaf can have, where the weight is the sum
            of the weights of each sample in it.
        random_state : object
            The user inputted random state to be used for pseudo-randomness
        """
        self.criterion = criterion

        self.samples = None
        self.n_samples = 0
        self.features = None
        self.n_features = 0
        self.feature_values = None

        self.sample_weight = None

        self.max_features = max_features
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_leaf = min_weight_leaf
        self.random_state = random_state

    def init(self,
                   X,
                   y,
                   sample_weight,
                   X_idx_sorted=None):
        """Initialize the splitter.
        Take in the input data X, the target Y, and optional sample weights.
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        Parameters
        ----------
        X : object
            This contains the inputs. Usually it is a 2d numpy array.
        y : ndarray, dtype=DOUBLE_t
            This is the vector of targets, or true labels, for the samples
        sample_weight : DOUBLE_t*
            The weights of the samples, where higher weighted samples are fit
            closer than lower weight samples. If not provided, all samples
            are assumed to have uniform weight.
        X_idx_sorted : ndarray, default=None
            The indexes of the sorted training input samples
        """
        self.rand_r_state = self.random_state.randint(0, RAND_R_MAX)
        n_samples = X.shape[0]

        # Create a new array which will be used to store nonzero
        # samples from the feature of interest
        self.samples = []
        for idx in range(n_samples):
            self.samples.append(0)

        weighted_n_samples = 0.0
        j = 0

        for i in range(n_samples):
            # Only work with positively weighted samples
            if sample_weight is None or sample_weight[i] != 0.0:
                self.samples[j] = i
                j += 1

            if sample_weight is not None:
                weighted_n_samples += sample_weight[i]
            else:
                weighted_n_samples += 1.0

        # Number of samples is number of positively weighted samples
        self.n_samples = j
        self.weighted_n_samples = weighted_n_samples

        n_features = X.shape[1]
        self.features = []
        for i in range(n_features):
            self.features.append(0)

        for i in range(n_features):
            self.features[i] = i

        self.n_features = n_features
        self.feature_values = []
        for idx in range(n_samples):
            self.feature_values.append(0)
        self.constant_features = []
        for idx in range(n_features):
            self.constant_features.append(0)
        self.y = y

        self.sample_weight = sample_weight
        return 0

    def node_reset(self, start, end):
        """Reset splitter on node samples[start:end].
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        Parameters
        ----------
        start : SIZE_t
            The index of the first sample to consider
        end : SIZE_t
            The index of the last sample to consider
        weighted_n_node_samples : ndarray, dtype=double pointer
            The total weight of those samples
        """
        self.start = start
        self.end = end

        self.criterion.init(self.y,
                            self.sample_weight,
                            self.weighted_n_samples,
                            self.samples,
                            start,
                            end)

        weighted_n_node_samples = self.criterion.weighted_n_node_samples
        return (0, weighted_n_node_samples)

    def node_value(self):
        """Copy the value of node samples[start:end] into dest."""
        return self.criterion.node_value()

    def node_impurity(self):
        """Return the impurity of the current node."""
        return self.criterion.node_impurity()


class BaseDenseSplitter(Splitter):
    def __init__(self, criterion, max_features,
                  min_samples_leaf, min_weight_leaf,
                  random_state):
        super().__init__(criterion, max_features, 
                       min_samples_leaf, min_weight_leaf,
                       random_state)
        self.X_idx_sorted_ptr = None
        self.X_idx_sorted_stride = 0
        self.sample_mask = None

    def init(self,
                  X,
                  y,
                  sample_weight,
                  X_idx_sorted=None):
        """Initialize the splitter
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        # Call parent init
        Splitter.init(self, X, y, sample_weight)
        self.X = X
        return 0


class BestSplitter(BaseDenseSplitter):
    """Splitter for finding the best split."""
    def node_split(self, impurity, 
                        n_constant_features):
        """Find the best split on node samples[start:end]
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        # Find the best split
        samples = self.samples
        start = self.start
        end = self.end

        features = self.features
        constant_features = self.constant_features
        n_features = self.n_features

        Xf = self.feature_values
        max_features = self.max_features
        min_samples_leaf = self.min_samples_leaf
        min_weight_leaf = self.min_weight_leaf
        random_state = self.rand_r_state

        current_proxy_improvement = -INFINITY
        best_proxy_improvement = -INFINITY

        f_i = n_features

        n_visited_features = 0
        # Number of features discovered to be constant during the split search
        n_found_constants = 0
        # Number of features known to be constant and drawn without replacement
        n_drawn_constants = 0
        n_known_constants = n_constant_features
        # n_total_constants = n_known_constants + n_found_constants
        n_total_constants = n_known_constants

        best = SplitRecord()
        current = SplitRecord()
        _init_split(best, end)
        
        # Sample up to max_features without replacement using a
        # Fisher-Yates-based algorithm (using the local variables `f_i` and
        # `f_j` to compute a permutation of the `features` array).
        #
        # Skip the CPU intensive evaluation of the impurity criterion for
        # features that were already detected as constant (hence not suitable
        # for good splitting) by ancestor nodes and save the information on
        # newly discovered constant features to spare computation on descendant
        # nodes.
        while (f_i > n_total_constants and  # Stop early if remaining features
                                            # are constant
                (n_visited_features < max_features or
                 # At least one drawn features must be non constant
                 n_visited_features <= n_found_constants + n_drawn_constants)):

            n_visited_features += 1

            # Loop invariant: elements of features in
            # - [:n_drawn_constant[ holds drawn and known constant features;
            # - [n_drawn_constant:n_known_constant[ holds known constant
            #   features that haven't been drawn yet;
            # - [n_known_constant:n_total_constant[ holds newly found constant
            #   features;
            # - [n_total_constant:f_i[ holds features that haven't been drawn
            #   yet and aren't constant apriori.
            # - [f_i:n_features[ holds features that have been drawn
            #   and aren't constant.

            # Draw a feature at random
            f_j = rand_int(n_drawn_constants, f_i - n_found_constants,
                           random_state)

            if f_j < n_known_constants:
                # f_j in the interval [n_drawn_constants, n_known_constants[
                features[n_drawn_constants], features[f_j] = features[f_j], features[n_drawn_constants]

                n_drawn_constants += 1

            else:
                # f_j in the interval [n_known_constants, f_i - n_found_constants[
                f_j += n_found_constants
                # f_j in the interval [n_total_constants, f_i[
                current.feature = features[f_j]

                # Sort samples along that feature; by
                # copying the values into an array and
                # sorting the array in a manner which utilizes the cache more
                # effectively.
                for i in range(start, end):
                    Xf[i] = self.X[samples[i], current.feature]

                tmp_list = [(Xf[i], samples[i]) for i in range(start, end)]
                tmp_list.sort(key=lambda item: item[0])
                for i in range(start, end):
                    Xf[i] = tmp_list[i-start][0]
                    samples[i] = tmp_list[i-start][1]

                # TODO: switch back to the original sort algorithm (introsort):
                # sort(Xf, samples, end - start, start)
 
                if Xf[end - 1] <= Xf[start] + FEATURE_THRESHOLD:
                    features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j]

                    n_found_constants += 1
                    n_total_constants += 1

                else:
                    f_i -= 1
                    features[f_i], features[f_j] = features[f_j], features[f_i]

                    # Evaluate all splits
                    self.criterion.reset()
                    p = start

                    while p < end:
                        while (p + 1 < end and
                               Xf[p + 1] <= Xf[p] + FEATURE_THRESHOLD):
                            p += 1

                        # (p + 1 >= end) or (X[samples[p + 1], current.feature] >
                        #                    X[samples[p], current.feature])
                        p += 1
                        # (p >= end) or (X[samples[p], current.feature] >
                        #                X[samples[p - 1], current.feature])

                        if p < end:
                            current.pos = p

                            # Reject if min_samples_leaf is not guaranteed
                            if (((current.pos - start) < min_samples_leaf) or
                                    ((end - current.pos) < min_samples_leaf)):
                                continue

                            self.criterion.update(current.pos)

                            # Reject if min_weight_leaf is not satisfied
                            if ((self.criterion.weighted_n_left < min_weight_leaf) or
                                    (self.criterion.weighted_n_right < min_weight_leaf)):
                                continue

                            current_proxy_improvement = self.criterion.proxy_impurity_improvement()

                            if current_proxy_improvement > best_proxy_improvement:
                                best_proxy_improvement = current_proxy_improvement
                                # sum of halves is used to avoid infinite value
                                current.threshold = Xf[p - 1] / 2.0 + Xf[p] / 2.0

                                if ((current.threshold == Xf[p]) or
                                    (current.threshold == INFINITY) or
                                    (current.threshold == -INFINITY)):
                                    current.threshold = Xf[p - 1]

                                best = current.copy()  # copy

        # Reorganize into samples[start:best.pos] + samples[best.pos:end]
        if best.pos < end:
            partition_end = end
            p = start

            while p < partition_end:
                if self.X[samples[p], best.feature] <= best.threshold:
                    p += 1
                else:
                    partition_end -= 1
                    samples[p], samples[partition_end] = samples[partition_end], samples[p]

            self.criterion.reset()
            self.criterion.update(best.pos)
            best.improvement = self.criterion.impurity_improvement(impurity)
            (best.impurity_left, best.impurity_right) = self.criterion.children_impurity()

        # Respect invariant for constant features: the original order of
        # element in features[:n_known_constants] must be preserved for sibling
        # and child nodes
        for i in range(n_known_constants):
            features[i] = constant_features[i]

        # Copy newly found constant features
        for i in range(n_found_constants):
            constant_features[n_known_constants+i] = features[n_known_constants+i]

        # Return values
        split = best
        n_constant_features = n_total_constants
        return (0, split, n_constant_features)


# Sort n-element arrays pointed to by Xf and samples, simultaneously,
# by the values in Xf. Algorithm: Introsort (Musser, SP&E, 1997).
def sort(Xf, samples, n, start):
    if n == 0:
        return
    maxd = 2 * int(log(n))
    introsort(Xf, samples, n, maxd, start)


def swap(Xf, samples, i, j):
    # Helper for sort
    Xf[i], Xf[j] = Xf[j], Xf[i]
    samples[i], samples[j] = samples[j], samples[i]


def median3(Xf, n, idx):
    # Median of three pivot selection, after Bentley and McIlroy (1993).
    # Engineering a sort function. SP&E. Requires 8/3 comparisons on average.
    a = Xf[idx]
    b = Xf[n // 2 + idx]
    c = Xf[n - 1 + idx]
    if a < b:
        if b < c:
            return b
        elif a < c:
            return c
        else:
            return a
    elif b < c:
        if a < c:
            return a
        else:
            return c
    else:
        return b


# Introsort with median of 3 pivot selection and 3-way partition function
# (robust to repeated elements, e.g. lots of zero features).
def introsort(Xf, samples, n, maxd, idx):
    while n > 1:
        if maxd <= 0:   # max depth limit exceeded ("gone quadratic")
            heapsort(Xf, samples, n, idx)
            return
        maxd -= 1

        pivot = median3(Xf, n, idx)

        # Three-way partition.
        i = l = idx
        r = n + idx
        while i < r:
            if Xf[i] < pivot:
                swap(Xf, samples, i, l)
                i += 1
                l += 1
            elif Xf[i] > pivot:
                r -= 1
                swap(Xf, samples, i, r)
            else:
                i += 1

        introsort(Xf, samples, l-idx, maxd, idx)
        n -= (r-idx)
        idx = r


def sift_down(Xf, samples, start, end, base=0):
    # Restore heap order in Xf[start:end] by moving the max element to start.
    # The heap is rooted at index `base`, so child positions are computed
    # relative to it (the Cython original offsets the pointers instead).
    root = start
    while True:
        child = base + (root - base) * 2 + 1

        # find max of root, left child, right child
        maxind = root
        if child < end and Xf[maxind] < Xf[child]:
            maxind = child
        if child + 1 < end and Xf[maxind] < Xf[child + 1]:
            maxind = child + 1

        if maxind == root:
            break
        else:
            swap(Xf, samples, root, maxind)
            root = maxind


def heapsort(Xf, samples, n, idx):
    # heapify the n elements that start at offset idx
    start = (n - 2) // 2 + idx
    end = n + idx
    while True:
        sift_down(Xf, samples, start, end, idx)
        if start == idx:
            break
        start -= 1

    # sort by shrinking the heap, putting the max element immediately after it
    end = n - 1 + idx
    while end > idx:
        swap(Xf, samples, idx, end)
        sift_down(Xf, samples, idx, end, idx)
        end = end - 1


class RandomSplitter(BaseDenseSplitter):
    """Splitter for finding the best random split."""
    def node_split(self, impurity, 
                        n_constant_features):
        """Find the best random split on node samples[start:end]
        Returns -1 in case of failure to allocate memory (and raise MemoryError)
        or 0 otherwise.
        """
        # Draw random splits and pick the best
        samples = self.samples
        start = self.start
        end = self.end

        features = self.features
        constant_features = self.constant_features
        n_features = self.n_features

        Xf = self.feature_values
        max_features = self.max_features
        min_samples_leaf = self.min_samples_leaf
        min_weight_leaf = self.min_weight_leaf
        random_state = self.rand_r_state

        best = SplitRecord()
        current = SplitRecord()
        current_proxy_improvement = - INFINITY
        best_proxy_improvement = - INFINITY

        f_i = n_features
        # Number of features discovered to be constant during the split search
        n_found_constants = 0
        # Number of features known to be constant and drawn without replacement
        n_drawn_constants = 0
        n_known_constants = n_constant_features
        # n_total_constants = n_known_constants + n_found_constants
        n_total_constants = n_known_constants
        n_visited_features = 0

        _init_split(best, end)

        # Sample up to max_features without replacement using a
        # Fisher-Yates-based algorithm (using the local variables `f_i` and
        # `f_j` to compute a permutation of the `features` array).
        #
        # Skip the CPU intensive evaluation of the impurity criterion for
        # features that were already detected as constant (hence not suitable
        # for good splitting) by ancestor nodes and save the information on
        # newly discovered constant features to spare computation on descendant
        # nodes.
        while (f_i > n_total_constants and  # Stop early if remaining features
                                            # are constant
                (n_visited_features < max_features or
                 # At least one drawn features must be non constant
                 n_visited_features <= n_found_constants + n_drawn_constants)):
            n_visited_features += 1

            # Loop invariant: elements of features in
            # - [:n_drawn_constant[ holds drawn and known constant features;
            # - [n_drawn_constant:n_known_constant[ holds known constant
            #   features that haven't been drawn yet;
            # - [n_known_constant:n_total_constant[ holds newly found constant
            #   features;
            # - [n_total_constant:f_i[ holds features that haven't been drawn
            #   yet and aren't constant apriori.
            # - [f_i:n_features[ holds features that have been drawn
            #   and aren't constant.

            # Draw a feature at random
            f_j = rand_int(n_drawn_constants, f_i - n_found_constants,
                           random_state)

            if f_j < n_known_constants:
                # f_j in the interval [n_drawn_constants, n_known_constants[
                features[n_drawn_constants], features[f_j] = features[f_j], features[n_drawn_constants]
                n_drawn_constants += 1

            else:
                # f_j in the interval [n_known_constants, f_i - n_found_constants[
                f_j += n_found_constants
                # f_j in the interval [n_total_constants, f_i[

                current.feature = features[f_j]

                # Find min, max
                min_feature_value = self.X[samples[start], current.feature]
                max_feature_value = min_feature_value
                Xf[start] = min_feature_value

                for p in range(start + 1, end):
                    current_feature_value = self.X[samples[p], current.feature]
                    Xf[p] = current_feature_value

                    if current_feature_value < min_feature_value:
                        min_feature_value = current_feature_value
                    elif current_feature_value > max_feature_value:
                        max_feature_value = current_feature_value

                if max_feature_value <= min_feature_value + FEATURE_THRESHOLD:
                    features[f_j], features[n_total_constants] = features[n_total_constants], current.feature

                    n_found_constants += 1
                    n_total_constants += 1

                else:
                    f_i -= 1
                    features[f_i], features[f_j] = features[f_j], features[f_i]

                    # Draw a random threshold
                    current.threshold = rand_uniform(min_feature_value,
                                                     max_feature_value,
                                                     random_state)

                    if current.threshold == max_feature_value:
                        current.threshold = min_feature_value

                    # Partition
                    p, partition_end = start, end
                    while p < partition_end:
                        if Xf[p] <= current.threshold:
                            p += 1
                        else:
                            partition_end -= 1

                            Xf[p], Xf[partition_end] = Xf[partition_end], Xf[p]
                            samples[p], samples[partition_end] = samples[partition_end], samples[p]

                    current.pos = partition_end

                    # Reject if min_samples_leaf is not guaranteed
                    if (((current.pos - start) < min_samples_leaf) or
                            ((end - current.pos) < min_samples_leaf)):
                        continue

                    # Evaluate split
                    self.criterion.reset()
                    self.criterion.update(current.pos)

                    # Reject if min_weight_leaf is not satisfied
                    if ((self.criterion.weighted_n_left < min_weight_leaf) or
                            (self.criterion.weighted_n_right < min_weight_leaf)):
                        continue

                    current_proxy_improvement = self.criterion.proxy_impurity_improvement()

                    if current_proxy_improvement > best_proxy_improvement:
                        best_proxy_improvement = current_proxy_improvement
                        best = current.copy()  # copy

        # Reorganize into samples[start:best.pos] + samples[best.pos:end]
        if best.pos < end:
            if current.feature != best.feature:
                p, partition_end = start, end

                while p < partition_end:
                    if self.X[samples[p], best.feature] <= best.threshold:
                        p += 1
                    else:
                        partition_end -= 1

                        samples[p], samples[partition_end] = samples[partition_end], samples[p]

            self.criterion.reset()
            self.criterion.update(best.pos)
            best.improvement = self.criterion.impurity_improvement(impurity)
            (best.impurity_left, best.impurity_right) = self.criterion.children_impurity()

        # Respect invariant for constant features: the original order of
        # element in features[:n_known_constants] must be preserved for sibling
        # and child nodes
        for i in range(n_known_constants):
            features[i] = constant_features[i]

        # Copy newly found constant features
        for i in range(n_found_constants):
            constant_features[n_known_constants+i] = features[n_known_constants+i]

        # Return values
        split = best
        n_constant_features = n_total_constants
        return (0, split, n_constant_features)

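Note that the original sorting algorithm is kept in the listing above, but the sorting is actually done with Python's built-in sort. The two order samples with equal feature values differently, but at least in the tests later on, the prediction accuracy is the same.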
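
The difference in tie ordering can be seen on a small example. The sketch below is illustrative only: `stable_sort_pairs` and `three_way_quicksort_pairs` are hypothetical helpers that are not part of the code above, and the quicksort is a simplified stand-in for the introsort listed earlier (no depth limit, middle element as pivot).

```python
def stable_sort_pairs(values, samples):
    """Sort (value, sample) pairs by value with Python's stable built-in sort."""
    pairs = sorted(zip(values, samples), key=lambda item: item[0])
    return [v for v, _ in pairs], [s for _, s in pairs]


def three_way_quicksort_pairs(values, samples):
    """Sort copies of both lists with a simple, unstable three-way quicksort."""
    values, samples = list(values), list(samples)

    def _sort(lo, hi):
        # Sorts values[lo:hi] and applies the same swaps to samples.
        if hi - lo <= 1:
            return
        pivot = values[(lo + hi) // 2]
        i = l = lo
        r = hi
        while i < r:
            if values[i] < pivot:
                values[i], values[l] = values[l], values[i]
                samples[i], samples[l] = samples[l], samples[i]
                i += 1
                l += 1
            elif values[i] > pivot:
                r -= 1
                values[i], values[r] = values[r], values[i]
                samples[i], samples[r] = samples[r], samples[i]
            else:
                i += 1
        _sort(lo, l)
        _sort(r, hi)

    _sort(0, len(values))
    return values, samples


# Feature values with ties, and the sample indices they belong to.
Xf = [2.0, 1.0, 2.0, 1.0, 3.0]
samples = [10, 11, 12, 13, 14]

v_stable, s_stable = stable_sort_pairs(Xf, samples)
v_quick, s_quick = three_way_quicksort_pairs(Xf, samples)

print(v_stable == v_quick)  # the sorted feature values agree
print(s_stable, s_quick)    # samples with equal values may come out in a different order
```

Both routines produce the same sorted feature values and differ only in how samples with equal values are ordered. Because the split search above only places candidate positions between distinct feature values, the set of samples on each side of a candidate split does not depend on that order, which is a plausible reason why the accuracy is unchanged.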

 

Tree.py

import numpy as np
from .Utils import Node, Stack, PriorityHeap, PriorityHeapRecord, SIZE_MAX


# =============================================================================
# Types and constants
# =============================================================================

INFINITY = np.inf
EPSILON = np.finfo('double').eps

# Some handy constants (BestFirstTreeBuilder)
IS_FIRST = 1
IS_NOT_FIRST = 0
IS_LEFT = 1
IS_NOT_LEFT = 0

TREE_LEAF = -1
_TREE_LEAF = TREE_LEAF
_TREE_UNDEFINED = -2
INITIAL_STACK_SIZE = 10

# Depth first builder ---------------------------------------------------------

class DepthFirstTreeBuilder:
    """Build a decision tree in depth-first fashion."""
    def __init__(self, splitter, min_samples_split,
                  min_samples_leaf, min_weight_leaf,
                  max_depth, min_impurity_decrease,
                  min_impurity_split):
        self.splitter = splitter
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_leaf = min_weight_leaf
        self.max_depth = max_depth
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split

    def build(self, tree, X, y,
                sample_weight=None,
                X_idx_sorted=None):
        """Build a decision tree from the training set (X, y)."""
        if tree.max_depth <= 10:
            init_capacity = (2 ** (tree.max_depth + 1)) - 1
        else:
            init_capacity = 2047

        tree._resize(init_capacity)

        # Parameters
        splitter = self.splitter
        max_depth = self.max_depth
        min_samples_leaf = self.min_samples_leaf
        min_weight_leaf = self.min_weight_leaf
        min_samples_split = self.min_samples_split
        min_impurity_decrease = self.min_impurity_decrease
        min_impurity_split = self.min_impurity_split

        # Recursive partition (without actual recursion)
        splitter.init(X, y, sample_weight, X_idx_sorted)

        n_node_samples = splitter.n_samples

        impurity = INFINITY
        first = 1
        max_depth_seen = -1
        rc = 0

        stack = Stack(INITIAL_STACK_SIZE)

        # push root node onto stack
        rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)
        if rc == -1:
            # got return code -1 - out-of-memory
            raise MemoryError()

        while not stack.is_empty():
            (_, stack_record) = stack.pop()

            start = stack_record.start
            end = stack_record.end
            depth = stack_record.depth
            parent = stack_record.parent
            is_left = stack_record.is_left
            impurity = stack_record.impurity
            n_constant_features = stack_record.n_constant_features

            n_node_samples = end - start
            (_, weighted_n_node_samples) = splitter.node_reset(start, end)

            is_leaf = (depth >= max_depth or
                       n_node_samples < min_samples_split or
                       n_node_samples < 2 * min_samples_leaf or
                       weighted_n_node_samples < 2 * min_weight_leaf)

            if first:
                impurity = splitter.node_impurity()
                first = 0

            is_leaf = (is_leaf or (impurity <= min_impurity_split))

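            if not is_leaf:
                (_, split, n_constant_features) = splitter.node_split(impurity, n_constant_features)
                # If EPSILON=0 in the below comparison, float precision
                # issues stop splitting, producing trees that are
                # dissimilar to v0.18
                is_leaf = (is_leaf or split.pos >= end or
                           (split.improvement + EPSILON <
                            min_impurity_decrease))
                feature, threshold = split.feature, split.threshold
            else:
                # No split was searched for this node. Unlike in the Cython
                # original, `split` may be unbound here (e.g. when the root
                # is already a leaf), so pass placeholder values.
                feature, threshold = _TREE_UNDEFINED, _TREE_UNDEFINED

            node_id = tree._add_node(parent, is_left, is_leaf, feature,
                                     threshold, impurity, n_node_samples,
                                     weighted_n_node_samples)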

            if node_id == SIZE_MAX:
                rc = -1
                break

            # Store value for all nodes, to facilitate tree/model
            # inspection and interpretation
            tree.value[node_id*tree.value_stride:(node_id+1)*tree.value_stride] = splitter.node_value()

            if not is_leaf:
                # Push right child on stack
                rc = stack.push(split.pos, end, depth + 1, node_id, 0,
                                split.impurity_right, n_constant_features)
                if rc == -1:
                    break

                # Push left child on stack
                rc = stack.push(start, split.pos, depth + 1, node_id, 1,
                                split.impurity_left, n_constant_features)
                if rc == -1:
                    break

            if depth > max_depth_seen:
                max_depth_seen = depth

        if rc >= 0:
            rc = tree._resize(tree.node_count)

        if rc >= 0:
            tree.max_depth = max_depth_seen
            
        if rc == -1:
            raise MemoryError()


# Best first builder ----------------------------------------------------------

def _add_to_frontier(rec, frontier):
    """Adds record ``rec`` to the priority queue ``frontier``
    Returns -1 in case of failure to allocate memory (and raise MemoryError)
    or 0 otherwise.
    """
    return frontier.push(rec.node_id, rec.start, rec.end, rec.pos, rec.depth,
                         rec.is_leaf, rec.improvement, rec.impurity,
                         rec.impurity_left, rec.impurity_right)


class BestFirstTreeBuilder:
    """Build a decision tree in best-first fashion.
    The best node to expand is given by the node at the frontier that has the
    highest impurity improvement.
    """
    def __init__(self, splitter, min_samples_split,
                  min_samples_leaf,  min_weight_leaf,
                  max_depth, max_leaf_nodes,
                  min_impurity_decrease, min_impurity_split):
        self.splitter = splitter
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_leaf = min_weight_leaf
        self.max_depth = max_depth
        self.max_leaf_nodes = max_leaf_nodes
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split

    def build(self, tree, X, y,
                sample_weight=None,
                X_idx_sorted=None):
        """Build a decision tree from the training set (X, y)."""
        # Parameters
        splitter = self.splitter
        max_leaf_nodes = self.max_leaf_nodes

        # Recursive partition (without actual recursion)
        splitter.init(X, y, sample_weight, X_idx_sorted)

        frontier = PriorityHeap(INITIAL_STACK_SIZE)
        split_node_left = PriorityHeapRecord()
        split_node_right = PriorityHeapRecord()

        n_node_samples = splitter.n_samples
        max_split_nodes = max_leaf_nodes - 1
        max_depth_seen = -1
        rc = 0

        # Initial capacity
        init_capacity = max_split_nodes + max_leaf_nodes
        tree._resize(init_capacity)

        # add root to frontier
        rc = self._add_split_node(splitter, tree, 0, n_node_samples,
                                  INFINITY, IS_FIRST, IS_LEFT, None, 0,
                                  split_node_left)
        if rc >= 0:
            rc = _add_to_frontier(split_node_left, frontier)

        if rc == -1:
            raise MemoryError()

        while not frontier.is_empty():
            (_, record) = frontier.pop()

            node = tree.nodes[record.node_id]
            is_leaf = (record.is_leaf or max_split_nodes <= 0)

            if is_leaf:
                # Node is not expandable; set node as leaf
                node.left_child = _TREE_LEAF
                node.right_child = _TREE_LEAF
                node.feature = _TREE_UNDEFINED
                node.threshold = _TREE_UNDEFINED
            else:
                # Node is expandable

                # Decrement number of split nodes available
                max_split_nodes -= 1

                # Compute left split node
                rc = self._add_split_node(splitter, tree,
                                          record.start, record.pos,
                                          record.impurity_left,
                                          IS_NOT_FIRST, IS_LEFT, node,
                                          record.depth + 1,
                                          split_node_left)
                if rc == -1:
                    break

                # tree.nodes may have changed
                node = tree.nodes[record.node_id]

                # Compute right split node
                rc = self._add_split_node(splitter, tree, record.pos,
                                          record.end,
                                          record.impurity_right,
                                          IS_NOT_FIRST, IS_NOT_LEFT, node,
                                          record.depth + 1,
                                          split_node_right)
                if rc == -1:
                    break

                # Add nodes to queue
                rc = _add_to_frontier(split_node_left, frontier)
                if rc == -1:
                    break

                rc = _add_to_frontier(split_node_right, frontier)
                if rc == -1:
                    break

            if record.depth > max_depth_seen:
                max_depth_seen = record.depth

        # Shrink the tree to its final size once, after the frontier is empty
        if rc >= 0:
            rc = tree._resize(tree.node_count)

        if rc >= 0:
            tree.max_depth = max_depth_seen

        if rc == -1:
            raise MemoryError()

    def _add_split_node(self, splitter, tree,
                                    start, end, impurity,
                                    is_first, is_left, parent,
                                    depth,
                                    res):
        """Adds node w/ partition ``[start, end)`` to the frontier. """
        n_constant_features = 0
        min_impurity_decrease = self.min_impurity_decrease
        min_impurity_split = self.min_impurity_split
        # node_reset returns the weighted sample count of the node, as in
        # DepthFirstTreeBuilder.build
        (_, weighted_n_node_samples) = splitter.node_reset(start, end)

        if is_first:
            impurity = splitter.node_impurity()

        n_node_samples = end - start
        is_leaf = (depth >= self.max_depth or
                   n_node_samples < self.min_samples_split or
                   n_node_samples < 2 * self.min_samples_leaf or
                   weighted_n_node_samples < 2 * self.min_weight_leaf or
                   impurity <= min_impurity_split)

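        if not is_leaf:
            (_, split, n_constant_features) = splitter.node_split(impurity, n_constant_features)
            # If EPSILON=0 in the below comparison, float precision issues stop
            # splitting early, producing trees that are dissimilar to v0.18
            is_leaf = (is_leaf or split.pos >= end or
                       split.improvement + EPSILON < min_impurity_decrease)
            feature, threshold = split.feature, split.threshold
        else:
            # No split was searched; `split` may be unbound in Python
            feature, threshold = _TREE_UNDEFINED, _TREE_UNDEFINED

        # The Cython original gets the parent's id by pointer arithmetic
        # (parent - tree.nodes); with a Python list, look the node up by
        # identity instead.
        parent_id = (next(i for i, n in enumerate(tree.nodes) if n is parent)
                     if parent is not None
                     else _TREE_UNDEFINED)
        node_id = tree._add_node(parent_id,
                                 is_left, is_leaf,
                                 feature, threshold, impurity, n_node_samples,
                                 weighted_n_node_samples)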
        if node_id == SIZE_MAX:
            return -1

        # compute values also for split nodes (might become leafs later).
        tree.value[node_id*tree.value_stride:(node_id+1)*tree.value_stride] = splitter.node_value()

        res.node_id = node_id
        res.start = start
        res.end = end
        res.depth = depth
        res.impurity = impurity

        if not is_leaf:
            # is split node
            res.pos = split.pos
            res.is_leaf = 0
            res.improvement = split.improvement
            res.impurity_left = split.impurity_left
            res.impurity_right = split.impurity_right
        else:
            # is leaf => 0 improvement
            res.pos = end
            res.is_leaf = 1
            res.improvement = 0.0
            res.impurity_left = impurity
            res.impurity_right = impurity

        return 0

# =============================================================================
# Tree
# =============================================================================

class Tree:
    """Array-based representation of a binary decision tree.
    The binary tree is represented as a number of parallel arrays. The i-th
    element of each array holds information about the node `i`. Node 0 is the
    tree's root. You can find a detailed description of all arrays in
    `_tree.pxd`. NOTE: Some of the arrays only apply to either leaves or split
    nodes, resp. In this case the values of nodes of the other type are
    arbitrary!
    Attributes
    ----------
    node_count : int
        The number of nodes (internal nodes + leaves) in the tree.
    capacity : int
        The current capacity (i.e., size) of the arrays, which is at least as
        great as `node_count`.
    max_depth : int
        The depth of the tree, i.e. the maximum depth of its leaves.
    children_left : array of int, shape [node_count]
        children_left[i] holds the node id of the left child of node i.
        For leaves, children_left[i] == TREE_LEAF. Otherwise,
        children_left[i] > i. This child handles the case where
        X[:, feature[i]] <= threshold[i].
    children_right : array of int, shape [node_count]
        children_right[i] holds the node id of the right child of node i.
        For leaves, children_right[i] == TREE_LEAF. Otherwise,
        children_right[i] > i. This child handles the case where
        X[:, feature[i]] > threshold[i].
    feature : array of int, shape [node_count]
        feature[i] holds the feature to split on, for the internal node i.
    threshold : array of double, shape [node_count]
        threshold[i] holds the threshold for the internal node i.
    value : array of double, shape [node_count, n_outputs, max_n_classes]
        Contains the constant prediction value of each node.
    impurity : array of double, shape [node_count]
        impurity[i] holds the impurity (i.e., the value of the splitting
        criterion) at node i.
    n_node_samples : array of int, shape [node_count]
        n_node_samples[i] holds the number of training samples reaching node i.
    weighted_n_node_samples : array of int, shape [node_count]
        weighted_n_node_samples[i] holds the weighted number of training samples
        reaching node i.
    """
    # Wrap for outside world.
    # WARNING: these reference the current `nodes` and `value` buffers, which
    # must not be freed by a subsequent memory allocation.
    # (i.e. through `_resize` or `__setstate__`)
    def __init__(self, n_features, n_classes,
                  n_outputs):
        """Constructor."""
        # Input/Output layout
        self.n_features = n_features
        self.n_outputs = n_outputs
        self.n_classes = []
        for idx in range(n_outputs):
            self.n_classes.append(n_classes[idx])

        self.max_n_classes = np.max(n_classes)
        self.value_stride = n_outputs * self.max_n_classes

        # Inner structures
        self.max_depth = 0
        self.node_count = 0
        self.capacity = 0
        self.value = []
        self.nodes = []

    def _resize(self, capacity=SIZE_MAX):
        """Guts of _resize
        Always returns 0 in this pure-Python port; the -1 / MemoryError path
        of the Cython original cannot occur here.
        """
        if capacity == self.capacity and self.nodes is not None:
            return 0

        if capacity == SIZE_MAX:
            if self.capacity == 0:
                capacity = 3  # default initial value
            else:
                capacity = 2 * self.capacity

        for idx in range(capacity-self.capacity):
            self.nodes.append(Node())

        # value memory is initialised to 0 to enable classifier argmax
        if capacity > self.capacity:
            for idx in range(self.capacity*self.value_stride, capacity * self.value_stride):
                self.value.append(0)

        # if capacity smaller than node_count, adjust the counter
        if capacity < self.node_count:
            self.node_count = capacity

        self.capacity = capacity
        return 0

    def _add_node(self, parent, is_left, is_leaf,
                          feature, threshold, impurity,
                          n_node_samples,
                          weighted_n_node_samples):
        """Add a node to the tree.
        The new node registers itself as the child of its parent.
        Returns SIZE_MAX on error.
        """
        node_id = self.node_count

        if node_id >= self.capacity:
            if self._resize() != 0:
                return SIZE_MAX

        node = self.nodes[node_id]
        node.impurity = impurity
        node.n_node_samples = n_node_samples
        node.weighted_n_node_samples = weighted_n_node_samples

        if parent != _TREE_UNDEFINED:
            if is_left:
                self.nodes[parent].left_child = node_id
            else:
                self.nodes[parent].right_child = node_id

        if is_leaf:
            node.left_child = _TREE_LEAF
            node.right_child = _TREE_LEAF
            node.feature = _TREE_UNDEFINED
            node.threshold = _TREE_UNDEFINED
        else:
            # left_child and right_child will be set later
            node.feature = feature
            node.threshold = threshold

        self.node_count += 1
        return node_id

    def predict(self, X):
        """Predict target for X."""
        out = self._get_value_ndarray().take(self.apply(X), axis=0,
                                             mode='clip')
        if self.n_outputs == 1:
            out = out.reshape(X.shape[0], self.max_n_classes)
        return out

    def apply(self, X):
        """Finds the terminal region (=leaf node) for each sample in X."""
        return self._apply_dense(X)

    def _apply_dense(self, X):
        """Finds the terminal region (=leaf node) for each sample in X."""
        # Check input
        if not isinstance(X, np.ndarray):
            raise ValueError("X should be in np.ndarray format, got %s"
                             % type(X))

        # Extract input
        X_ndarray = X
        n_samples = X.shape[0]

        # Initialize output
        out = np.zeros((n_samples,), dtype=np.intp)

        for i in range(n_samples):
            # Start every sample at the root (node 0)
            node_id = 0
            node = self.nodes[node_id]
            # While node not a leaf
            while node.left_child != _TREE_LEAF:
                # ... and node.right_child != _TREE_LEAF:
                if X_ndarray[i, node.feature] <= node.threshold:
                    node_id = node.left_child
                else:
                    node_id = node.right_child
                node = self.nodes[node_id]

            out[i] = node_id  # index of the leaf reached
        return out

    def compute_feature_importances(self, normalize=True):
        """Computes the importance of each feature (aka variable)."""
        nodes = self.nodes

        normalizer = 0.
        importances = np.zeros((self.n_features,))

        # Visit only the node_count nodes in use; self.nodes may also hold
        # spare, not yet used entries up to capacity.
        for node in range(self.node_count):
            if nodes[node].left_child != _TREE_LEAF:
                # ... and node.right_child != _TREE_LEAF:
                left = nodes[nodes[node].left_child]
                right = nodes[nodes[node].right_child]

                importances[nodes[node].feature] += (
                    nodes[node].weighted_n_node_samples * nodes[node].impurity -
                    left.weighted_n_node_samples * left.impurity -
                    right.weighted_n_node_samples * right.impurity)

        importances /= nodes[0].weighted_n_node_samples

        if normalize:
            normalizer = np.sum(importances)
            if normalizer > 0.0:
                # Avoid dividing by zero (e.g., when root is pure)
                importances /= normalizer

        return importances

    def _get_value_ndarray(self):
        """Wraps value as a 3-d NumPy array.
        The array keeps a reference to this Tree, which manages the underlying
        memory.
        """
        n_value = len(self.value)
        tmp = np.zeros((self.node_count, self.n_outputs, self.max_n_classes))
        if (n_value < self.node_count*self.n_outputs*self.max_n_classes):
            for idx in range(n_value):
                i = idx//(self.n_outputs*self.max_n_classes)
                j = (idx - i*(self.n_outputs*self.max_n_classes))//self.max_n_classes
                k = idx - i*(self.n_outputs*self.max_n_classes) - j*self.max_n_classes
                tmp[i][j][k] = self.value[idx]
        else:
            for i in range(self.node_count):
                for j in range(self.n_outputs):
                    for k in range(self.max_n_classes):
                        tmp[i][j][k] = self.value[(i*self.n_outputs+j)*self.max_n_classes+k]
        return tmp

    def _get_node_ndarray(self):
        """Wraps nodes as a NumPy struct array.
        The array keeps a reference to this Tree, which manages the underlying
        memory. Individual fields are publicly accessible as properties of the
        Tree.
        """
        return np.array(self.nodes)

    def compute_partial_dependence(self, X,
                                   target_features,
                                   out):
        """Partial dependence of the response on the ``target_feature`` set.
        For each sample in ``X`` a tree traversal is performed.
        Each traversal starts from the root with weight 1.0.
        At each non-leaf node that splits on a target feature, either
        the left child or the right child is visited based on the feature
        value of the current sample, and the weight is not modified.
        At each non-leaf node that splits on a complementary feature,
        both children are visited and the weight is multiplied by the fraction
        of training samples which went to each child.
        At each leaf, the value of the node is multiplied by the current
        weight (weights sum to 1 for all visited terminal nodes).
        Parameters
        ----------
        X : view on 2d ndarray, shape (n_samples, n_target_features)
            The grid points on which the partial dependence should be
            evaluated.
        target_features : view on 1d ndarray, shape (n_target_features)
            The set of target features for which the partial dependence
            should be evaluated.
        out : view on 1d ndarray, shape (n_samples)
            The value of the partial dependence function on each grid
            point.
        """
        weight_stack = np.zeros(self.node_count,
                                                dtype=np.float64)
        node_idx_stack = np.zeros(self.node_count,
                                                 dtype=np.intp)

        for sample_idx in range(X.shape[0]):
            # init stacks for current sample
            stack_size = 1
            node_idx_stack[0] = 0  # root node
            weight_stack[0] = 1  # all the samples are in the root node
            total_weight = 0

            while stack_size > 0:
                # pop the stack
                stack_size -= 1
                current_node_idx = node_idx_stack[stack_size]
                current_node = self.nodes[current_node_idx]

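                if current_node.left_child == _TREE_LEAF:
                    # leaf node: self.value is flat, so the node's first
                    # value sits at current_node_idx * value_stride
                    out[sample_idx] += (weight_stack[stack_size] *
                                        self.value[current_node_idx *
                                                   self.value_stride])
                    total_weight += weight_stack[stack_size]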
                else:  # non-leaf node
                    # determine if the split feature is a target feature
                    is_target_feature = False
                    for feature_idx in range(target_features.shape[0]):
                        if target_features[feature_idx] == current_node.feature:
                            is_target_feature = True
                            break

                    if is_target_feature:
                        # In this case, we push left or right child on stack
                        if X[sample_idx, feature_idx] <= current_node.threshold:
                            node_idx_stack[stack_size] = current_node.left_child
                        else:
                            node_idx_stack[stack_size] = current_node.right_child
                        stack_size += 1
                    else:
                        # In this case, we push both children onto the stack,
                        # and give a weight proportional to the number of
                        # samples going through each branch.

                        # push left child
                        node_idx_stack[stack_size] = current_node.left_child
                        left_sample_frac = (
                            self.nodes[current_node.left_child].weighted_n_node_samples /
                            current_node.weighted_n_node_samples)
                        current_weight = weight_stack[stack_size]
                        weight_stack[stack_size] = current_weight * left_sample_frac
                        stack_size += 1

                        # push right child
                        node_idx_stack[stack_size] = current_node.right_child
                        weight_stack[stack_size] = (
                            current_weight * (1 - left_sample_frac))
                        stack_size += 1

            # Sanity check. Should never happen.
            if not (0.999 < total_weight < 1.001):
                raise ValueError("Total weight should be 1.0 but was %.9f" %
                                 total_weight)


# =============================================================================
# Build Pruned Tree
# =============================================================================


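class _CCPPruneController:
    """Base class used by build_pruned_tree_ccp and ccp_pruning_path
    to control pruning.
    """
    def stop_pruning(self, effective_alpha):
        """Return 1 to stop pruning and 0 to continue pruning"""
        return 0

    def save_metrics(self, effective_alpha, subtree_impurities):
        """Save metrics when pruning (no-op unless overridden)"""
        pass

    def after_pruning(self, in_subtree):
        """Called after pruning (no-op unless overridden)"""
        pass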


class _AlphaPruner(_CCPPruneController):
    """Use alpha to control when to stop pruning."""
    def __init__(self, ccp_alpha):
        self.ccp_alpha = ccp_alpha
        self.capacity = 0

    def stop_pruning(self, effective_alpha):
        # The subtree on the previous iteration has the greatest ccp_alpha
        # less than or equal to self.ccp_alpha
        return self.ccp_alpha < effective_alpha

    def after_pruning(self, in_subtree):
        """Updates the number of leaves in subtree"""
        for i in range(in_subtree.shape[0]):
            if in_subtree[i]:
                self.capacity += 1


class _PathFinder(_CCPPruneController):
    """Record metrics used to return the cost complexity path."""
    def __init__(self,  node_count):
        self.ccp_alphas = np.zeros(shape=(node_count), dtype=np.float64)
        self.impurities = np.zeros(shape=(node_count), dtype=np.float64)
        self.count = 0

    def save_metrics(self, effective_alpha,
                           subtree_impurities):
        self.ccp_alphas[self.count] = effective_alpha
        self.impurities[self.count] = subtree_impurities
        self.count += 1


def _cost_complexity_prune(leaves_in_subtree, # OUT
                            orig_tree, controller):
    """Perform cost complexity pruning.
    This function takes an already grown tree, `orig_tree`, and fills the
    boolean mask `leaves_in_subtree` that marks the leaves of the pruned tree.
    The controller signals when the pruning should stop and is passed the
    metrics of the subtrees during the pruning process.
    Parameters
    ----------
    leaves_in_subtree : unsigned char[:]
        Output for leaves of subtree
    orig_tree : Tree
        Original tree
    controller : _CCPPruneController
        Cost complexity controller
    """
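    n_nodes = orig_tree.node_count
    # The Tree class above keeps per-node data on the Node objects in
    # ``orig_tree.nodes`` and defines no array properties, so the arrays
    # used below are gathered from the nodes in use.
    nodes = orig_tree.nodes[:n_nodes]
    # prior probability using weighted samples
    weighted_n_node_samples = np.array(
        [node.weighted_n_node_samples for node in nodes], dtype=np.float64)
    total_sum_weights = weighted_n_node_samples[0]
    impurity = np.array([node.impurity for node in nodes], dtype=np.float64)
    # weighted impurity of each node
    r_node = np.empty(shape=n_nodes, dtype=np.float64)

    child_l = np.array([node.left_child for node in nodes], dtype=np.intp)
    child_r = np.array([node.right_child for node in nodes], dtype=np.intp)
    parent = np.zeros(shape=n_nodes, dtype=np.intp)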

    # Only uses the start and parent variables
    stack = Stack(INITIAL_STACK_SIZE)
    stack_record = None
    rc = 0

    n_leaves = np.zeros(shape=n_nodes, dtype=np.intp)
    r_branch = np.zeros(shape=n_nodes, dtype=np.float64)

    # candidate nodes that can be pruned
    candidate_nodes = np.zeros(shape=n_nodes,
                                                    dtype=np.uint8)
    # nodes in subtree
    in_subtree = np.ones(shape=n_nodes, dtype=np.uint8)
    max_float64 = np.finfo(np.float64).max

    # weighted impurity of each node, relative to the root's total weight
    for i in range(r_node.shape[0]):
        r_node[i] = (
            weighted_n_node_samples[i] * impurity[i] / total_sum_weights)

    # find parent node ids and leaves
    # Push root node, using StackRecord.start as node id
    rc = stack.push(0, 0, 0, -1, 0, 0, 0)
    if rc == -1:
        raise MemoryError("pruning tree")

    while not stack.is_empty():
        stack.pop(stack_record)
        node_idx = stack_record.start
        parent[node_idx] = stack_record.parent
        if child_l[node_idx] == _TREE_LEAF:
            # ... and child_r[node_idx] == _TREE_LEAF:
            leaves_in_subtree[node_idx] = 1
        else:
            rc = stack.push(child_l[node_idx], 0, 0, node_idx, 0, 0, 0)
            if rc == -1:
                raise MemoryError("pruning tree")

            rc = stack.push(child_r[node_idx], 0, 0, node_idx, 0, 0, 0)
            if rc == -1:
                raise MemoryError("pruning tree")

    # computes number of leaves in all branches and the overall impurity of
    # the branch. The overall impurity is the sum of r_node in its leaves.
    for leaf_idx in range(leaves_in_subtree.shape[0]):
        if not leaves_in_subtree[leaf_idx]:
            continue
        r_branch[leaf_idx] = r_node[leaf_idx]

        # bubble up values to ancestor nodes
        current_r = r_node[leaf_idx]
        while leaf_idx != 0:
            parent_idx = parent[leaf_idx]
            r_branch[parent_idx] += current_r
            n_leaves[parent_idx] += 1
            leaf_idx = parent_idx

    for i in range(leaves_in_subtree.shape[0]):
        candidate_nodes[i] = not leaves_in_subtree[i]

    # save metrics before pruning
    controller.save_metrics(0.0, r_branch[0])

    # while root node is not a leaf
    while candidate_nodes[0]:
        # computes ccp_alpha for subtrees and finds the minimal alpha
        effective_alpha = max_float64
        for i in range(n_nodes):
            if not candidate_nodes[i]:
                continue
            subtree_alpha = (r_node[i] - r_branch[i]) / (n_leaves[i] - 1)
            if subtree_alpha < effective_alpha:
                effective_alpha = subtree_alpha
                pruned_branch_node_idx = i

        if controller.stop_pruning(effective_alpha):
            break

        # stack uses only the start variable
        rc = stack.push(pruned_branch_node_idx, 0, 0, 0, 0, 0, 0)
        if rc == -1:
            raise MemoryError("pruning tree")

        # descendants of branch are not in subtree
        while not stack.is_empty():
            stack.pop(stack_record)
            node_idx = stack_record.start

            if not in_subtree[node_idx]:
                continue # branch has already been marked for pruning
            candidate_nodes[node_idx] = 0
            leaves_in_subtree[node_idx] = 0
            in_subtree[node_idx] = 0

            if child_l[node_idx] != _TREE_LEAF:
                # ... and child_r[node_idx] != _TREE_LEAF:
                rc = stack.push(child_l[node_idx], 0, 0, 0, 0, 0, 0)
                if rc == -1:
                    raise MemoryError("pruning tree")
                rc = stack.push(child_r[node_idx], 0, 0, 0, 0, 0, 0)
                if rc == -1:
                    raise MemoryError("pruning tree")
        leaves_in_subtree[pruned_branch_node_idx] = 1
        in_subtree[pruned_branch_node_idx] = 1

        # updates number of leaves
        n_pruned_leaves = n_leaves[pruned_branch_node_idx] - 1
        n_leaves[pruned_branch_node_idx] = 0

        # computes the increase in r_branch to bubble up
        r_diff = r_node[pruned_branch_node_idx] - r_branch[pruned_branch_node_idx]
        r_branch[pruned_branch_node_idx] = r_node[pruned_branch_node_idx]

        # bubble up values to ancestors
        node_idx = parent[pruned_branch_node_idx]
        while node_idx != -1:
            n_leaves[node_idx] -= n_pruned_leaves
            r_branch[node_idx] += r_diff
            node_idx = parent[node_idx]

        controller.save_metrics(effective_alpha, r_branch[0])

    controller.after_pruning(in_subtree)


def _build_pruned_tree_ccp(
    tree, # OUT
    orig_tree,
    ccp_alpha):
    """Build a pruned tree from the original tree using cost complexity
    pruning.
    The values and nodes from the original tree are copied into the pruned
    tree.
    Parameters
    ----------
    tree : Tree
        Location to place the pruned tree
    orig_tree : Tree
        Original tree
    ccp_alpha : positive double
        Complexity parameter. The subtree with the largest cost complexity
        that is smaller than ``ccp_alpha`` will be chosen. By default,
        no pruning is performed.
    """

    n_nodes = orig_tree.node_count
    leaves_in_subtree = np.zeros(
            shape=n_nodes, dtype=np.uint8)

    pruning_controller = _AlphaPruner(ccp_alpha=ccp_alpha)

    _cost_complexity_prune(leaves_in_subtree, orig_tree, pruning_controller)

    _build_pruned_tree(tree, orig_tree, leaves_in_subtree,
                       pruning_controller.capacity)


def ccp_pruning_path(orig_tree):
    """Computes the cost complexity pruning path.
    Parameters
    ----------
    orig_tree : Tree
        Original tree.
    Returns
    -------
    path_info : dict
        Information about pruning path with attributes:
        ccp_alphas : ndarray
            Effective alphas of subtree during pruning.
        impurities : ndarray
            Sum of the impurities of the subtree leaves for the
            corresponding alpha value in ``ccp_alphas``.
    """
    leaves_in_subtree = np.zeros(
            shape=orig_tree.node_count, dtype=np.uint8)

    path_finder = _PathFinder(orig_tree.node_count)

    _cost_complexity_prune(leaves_in_subtree, orig_tree, path_finder)

    total_items = path_finder.count
    # Keep only the entries recorded during pruning
    ccp_alphas = path_finder.ccp_alphas[:total_items].copy()
    impurities = path_finder.impurities[:total_items].copy()

    return {'ccp_alphas': ccp_alphas, 'impurities': impurities}


def _build_pruned_tree(
    tree, # OUT
    orig_tree,
    leaves_in_subtree,
    capacity):
    """Build a pruned tree.
    Build a pruned tree from the original tree by transforming the nodes in
    ``leaves_in_subtree`` into leaves.
    Parameters
    ----------
    tree : Tree
        Location to place the pruned tree
    orig_tree : Tree
        Original tree
    leaves_in_subtree : unsigned char memoryview, shape=(node_count, )
        Boolean mask for leaves to include in subtree
    capacity : SIZE_t
        Number of nodes to initially allocate in pruned tree
    """
    tree._resize(capacity)

    # value_stride for original tree and new tree are the same
    value_stride = orig_tree.value_stride
    max_depth_seen = -1
    rc = 0

    # Only uses the start, depth, parent, and is_left variables
    stack = Stack(INITIAL_STACK_SIZE)
    stack_record = None

    # push root node onto stack
    rc = stack.push(0, 0, 0, _TREE_UNDEFINED, 0, 0.0, 0)
    if rc == -1:
        raise MemoryError("pruning tree")

    while not stack.is_empty():
        stack.pop(stack_record)

        orig_node_id = stack_record.start
        depth = stack_record.depth
        parent = stack_record.parent
        is_left = stack_record.is_left

        is_leaf = leaves_in_subtree[orig_node_id]
        node = orig_tree.nodes[orig_node_id]

        new_node_id = tree._add_node(
            parent, is_left, is_leaf, node.feature, node.threshold,
            node.impurity, node.n_node_samples,
            node.weighted_n_node_samples)

        if new_node_id == SIZE_MAX:
            rc = -1
            break

        # copy value from original tree to new tree
        for idx in range(value_stride):
            tree.value[value_stride*new_node_id+idx] = orig_tree.value[value_stride*orig_node_id+idx]

        if not is_leaf:
            # Push right child on stack
            rc = stack.push(
                node.right_child, 0, depth + 1, new_node_id, 0, 0.0, 0)
            if rc == -1:
                break

            # push left child on stack
            rc = stack.push(
                node.left_child, 0, depth + 1, new_node_id, 1, 0.0, 0)
            if rc == -1:
                break

        if depth > max_depth_seen:
            max_depth_seen = depth

    if rc >= 0:
        tree.max_depth = max_depth_seen
    if rc == -1:
        raise MemoryError("pruning tree")

 
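
The docstring of the `Tree` class describes the tree as parallel arrays in which entry `i` describes node `i`. As a minimal sketch of that idea, the snippet below hand-builds such arrays for a five-node tree and walks them the same way `_apply_dense` walks the node list. The names `TREE_LEAF` and `apply_tree`, the sentinel value -1, and the placeholder value -2 for unused entries are assumptions made for this illustration, not part of the listing.

```python
import numpy as np

TREE_LEAF = -1  # assumed sentinel for "no child", mirroring _TREE_LEAF

# Hand-built tree:
#        node 0: x[0] <= 0.5
#        /                 \
#   node 1 (leaf)     node 2: x[1] <= 2.0
#                      /            \
#                 node 3 (leaf)  node 4 (leaf)
children_left = np.array([1, TREE_LEAF, 3, TREE_LEAF, TREE_LEAF])
children_right = np.array([2, TREE_LEAF, 4, TREE_LEAF, TREE_LEAF])
feature = np.array([0, -2, 1, -2, -2])            # -2: unused on leaves
threshold = np.array([0.5, -2.0, 2.0, -2.0, -2.0])


def apply_tree(X):
    """Return the index of the leaf each row of X ends up in."""
    leaves = np.zeros(X.shape[0], dtype=np.intp)
    for i in range(X.shape[0]):
        node = 0  # start at the root
        while children_left[node] != TREE_LEAF:
            if X[i, feature[node]] <= threshold[node]:
                node = children_left[node]
            else:
                node = children_right[node]
        leaves[i] = node
    return leaves


X = np.array([[0.2, 9.0], [0.9, 1.0], [0.9, 3.0]])
leaf_ids = apply_tree(X)  # leaf reached by each of the three rows
```

The three rows end in leaves 1, 3 and 4 respectively; a prediction is then just a lookup of the stored value for each of those leaf indices.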

DecisionTree.py

import numbers
import warnings
from math import ceil

import copy
import numpy as np

from .Criterion import Criterion, Gini, Entropy
from .Splitter import Splitter, BestSplitter, RandomSplitter
from .Tree import DepthFirstTreeBuilder
from .Tree import BestFirstTreeBuilder
from .Tree import Tree
from .Tree import _build_pruned_tree_ccp
from .Tree import ccp_pruning_path


__all__ = ["DecisionTreeClassifier",
           "ExtraTreeClassifier"]


def clone(estimator, *, safe=True):
    """Constructs a new estimator with the same parameters.

    Clone does a deep copy of the model in an estimator
    without actually copying attached data. It yields a new estimator
    with the same parameters that has not been fit on any data.

    Parameters
    ----------
    estimator : {list, tuple, set} of estimator objects or estimator object
        The estimator or group of estimators to be cloned.

    safe : bool, default=True
        If safe is false, clone will fall back to a deep copy on objects
        that are not estimators.

    """
    estimator_type = type(estimator)
    # XXX: not handling dictionaries
    if estimator_type in (list, tuple, set, frozenset):
        return estimator_type([clone(e, safe=safe) for e in estimator])
    elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
        if not safe:
            return copy.deepcopy(estimator)
        else:
            if isinstance(estimator, type):
                raise TypeError("Cannot clone object. " +
                                "You should provide an instance of " +
                                "scikit-learn estimator instead of a class.")
            else:
                raise TypeError("Cannot clone object '%s' (type %s): "
                                "it does not seem to be a scikit-learn "
                                "estimator as it does not implement a "
                                "'get_params' method."
                                % (repr(estimator), type(estimator)))

    klass = estimator.__class__
    new_object_params = estimator.get_params(deep=False)
    for name, param in new_object_params.items():
        new_object_params[name] = clone(param, safe=False)
    new_object = klass(**new_object_params)
    params_set = new_object.get_params(deep=False)

    # quick sanity check of the parameters of the clone
    for name in new_object_params:
        param1 = new_object_params[name]
        param2 = params_set[name]
        if param1 is not param2:
            raise RuntimeError('Cannot clone object %s, as the constructor '
                               'either does not set or modifies parameter %s' %
                               (estimator, name))
    return new_object


def check_random_state(seed):
    """Turn seed into a np.random.RandomState instance

    Parameters
    ----------
    seed : None | int | instance of RandomState
        If seed is None, return the RandomState singleton used by np.random.
        If seed is an int, return a new RandomState instance seeded with seed.
        If seed is already a RandomState instance, return it.
        Otherwise raise ValueError.
    """
    if seed is None or seed is np.random:
        return np.random.mtrand._rand
    if isinstance(seed, numbers.Integral):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % seed)


def _check_sample_weight(sample_weight, X, dtype=None):
    """Validate sample weights.

    Note that passing sample_weight=None will output an array of ones.
    Therefore, in some cases, you may want to protect the call with:
    if sample_weight is not None:
        sample_weight = _check_sample_weight(...)

    Parameters
    ----------
    sample_weight : {ndarray, Number or None}, shape (n_samples,)
       Input sample weights.

    X : nd-array, list or sparse matrix
        Input data.

    dtype: dtype
       dtype of the validated `sample_weight`.
       If None, and the input `sample_weight` is an array, the dtype of the
       input is preserved; otherwise an array with the default numpy dtype
       is allocated. If `dtype` is not one of `float32`, `float64`,
       `None`, the output will be of dtype `float64`.

    Returns
    -------
    sample_weight : ndarray, shape (n_samples,)
       Validated sample weight. It is guaranteed to be "C" contiguous.
    """
    n_samples = len(X)

    if dtype is not None and dtype not in [np.float32, np.float64]:
        dtype = np.float64

    if sample_weight is None:
        sample_weight = np.ones(n_samples, dtype=dtype)
    elif isinstance(sample_weight, numbers.Number):
        sample_weight = np.full(n_samples, sample_weight, dtype=dtype)
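    else:
        # Accept lists as well as arrays
        sample_weight = np.asarray(sample_weight)
        if sample_weight.ndim != 1:
            raise ValueError("Sample weights must be 1D array or scalar")

        if sample_weight.shape != (n_samples,):
            raise ValueError("sample_weight.shape == {}, expected {}!"
                             .format(sample_weight.shape, (n_samples,)))
        if dtype is None:
            # Keep float32/float64 input as is, convert anything else
            dtype = (sample_weight.dtype
                     if sample_weight.dtype in (np.float32, np.float64)
                     else np.float64)
        sample_weight = np.ascontiguousarray(sample_weight, dtype=dtype)
    return sample_weight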


def compute_class_weight(class_weight, *, classes, y):
    """Estimate class weights for unbalanced datasets.

    Parameters
    ----------
    class_weight : dict, 'balanced' or None
        If 'balanced', class weights will be given by
        ``n_samples / (n_classes * np.bincount(y))``.
        If a dictionary is given, keys are classes and values
        are corresponding class weights.
        If None is given, the class weights will be uniform.

    classes : ndarray
        Array of the classes occurring in the data, as given by
        ``np.unique(y_org)`` with ``y_org`` the original class labels.

    y : array-like, shape (n_samples,)
        Array of original class labels per sample;

    Returns
    -------
    class_weight_vect : ndarray, shape (n_classes,)
        Array with class_weight_vect[i] the weight for i-th class

    References
    ----------
    The "balanced" heuristic is inspired by
    Logistic Regression in Rare Events Data, King, Zeng, 2001.
    """
    # LabelEncoder is taken from scikit-learn itself; it is not part of this port.
    from sklearn.preprocessing import LabelEncoder

    if set(y) - set(classes):
        raise ValueError("classes should include all valid labels that can "
                         "be in y")
    if class_weight is None or len(class_weight) == 0:
        # uniform class weights
        weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
    elif class_weight == 'balanced':
        # Find the weight of each class as present in y.
        le = LabelEncoder()
        y_ind = le.fit_transform(y)
        if not np.all(np.isin(classes, le.classes_)):
            raise ValueError("classes should have valid labels that are in y")

        recip_freq = len(y) / (len(le.classes_) *
                               np.bincount(y_ind).astype(np.float64))
        weight = recip_freq[le.transform(classes)]
    else:
        # user-defined dictionary
        weight = np.ones(classes.shape[0], dtype=np.float64, order='C')
        if not isinstance(class_weight, dict):
            raise ValueError("class_weight must be dict, 'balanced', or None,"
                             " got: %r" % class_weight)
        for c in class_weight:
            i = np.searchsorted(classes, c)
            if i >= len(classes) or classes[i] != c:
                raise ValueError("Class label {} not present.".format(c))
            else:
                weight[i] = class_weight[c]

    return weight

def compute_sample_weight(class_weight, y, *, indices=None):
    """Estimate sample weights by class for unbalanced datasets.

    Parameters
    ----------
    class_weight : dict, list of dicts, "balanced", or None, optional
        Weights associated with classes in the form ``{class_label: weight}``.
        If not given, all classes are supposed to have weight one. For
        multi-output problems, a list of dicts can be provided in the same
        order as the columns of y.

        Note that for multioutput (including multilabel) weights should be
        defined for each class of every column in its own dict. For example,
        for four-class multilabel classification weights should be
        [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
        [{1:1}, {2:5}, {3:1}, {4:1}].

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data:
        ``n_samples / (n_classes * np.bincount(y))``.

        For multi-output, the weights of each column of y will be multiplied.

    y : array-like of shape (n_samples,) or (n_samples, n_outputs)
        Array of original class labels per sample.

    indices : array-like, shape (n_subsample,), or None
        Array of indices to be used in a subsample. Can be of length less than
        n_samples in the case of a subsample, or equal to n_samples in the
        case of a bootstrap subsample with repeated indices. If None, the
        sample weight will be calculated over the full sample. Only "balanced"
        is supported for class_weight if this is provided.

    Returns
    -------
    sample_weight_vect : ndarray, shape (n_samples,)
        Array with sample weights as applied to the original y
    """
    y = np.atleast_1d(y)
    if y.ndim == 1:
        y = np.reshape(y, (-1, 1))
    n_outputs = y.shape[1]

    if isinstance(class_weight, str):
        if class_weight not in ['balanced']:
            raise ValueError('The only valid preset for class_weight is '
                             '"balanced". Given "%s".' % class_weight)
    elif (indices is not None and
          not isinstance(class_weight, str)):
        raise ValueError('The only valid class_weight for subsampling is '
                         '"balanced". Given "%s".' % class_weight)
    elif n_outputs > 1:
        if (not hasattr(class_weight, "__iter__") or
                isinstance(class_weight, dict)):
            raise ValueError("For multi-output, class_weight should be a "
                             "list of dicts, or a valid string.")
        if len(class_weight) != n_outputs:
            raise ValueError("For multi-output, number of elements in "
                             "class_weight should match number of outputs.")

    expanded_class_weight = []
    for k in range(n_outputs):

        y_full = y[:, k]
        classes_full = np.unique(y_full)
        classes_missing = None

        if class_weight == 'balanced' or n_outputs == 1:
            class_weight_k = class_weight
        else:
            class_weight_k = class_weight[k]

        if indices is not None:
            # Get class weights for the subsample, covering all classes in
            # case some labels that were present in the original data are
            # missing from the sample.
            y_subsample = y[indices, k]
            classes_subsample = np.unique(y_subsample)

            weight_k = np.take(compute_class_weight(class_weight_k,
                                                    classes=classes_subsample,
                                                    y=y_subsample),
                               np.searchsorted(classes_subsample,
                                               classes_full),
                               mode='clip')

            classes_missing = set(classes_full) - set(classes_subsample)
        else:
            weight_k = compute_class_weight(class_weight_k,
                                            classes=classes_full,
                                            y=y_full)

        weight_k = weight_k[np.searchsorted(classes_full, y_full)]

        if classes_missing:
            # Make missing classes' weight zero
            weight_k[np.isin(y_full, list(classes_missing))] = 0.

        expanded_class_weight.append(weight_k)

    expanded_class_weight = np.prod(expanded_class_weight,
                                    axis=0,
                                    dtype=np.float64)

    return expanded_class_weight
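

# -----------------------------------------------------------------------------
# Illustration (not part of scikit-learn): a plain-Python sketch of the
# multi-output rule used above, where the weight of a sample is the product of
# the class weights of its labels, one factor per output column. The helper
# name `_multioutput_sample_weight_sketch` is hypothetical. Labels absent from
# a dict keep the default weight of one, as in the user-defined branch of
# compute_class_weight.
# -----------------------------------------------------------------------------
def _multioutput_sample_weight_sketch(class_weight, y_rows):
    """class_weight: list of dicts (one per output); y_rows: label tuples."""
    weights = []
    for row in y_rows:
        w = 1.0
        for k, label in enumerate(row):
            w *= class_weight[k].get(label, 1.0)
        weights.append(w)
    return weights


# Example with two outputs: the sample labelled (1, 1) gets 2 * 5 = 10.
# _multioutput_sample_weight_sketch([{0: 1, 1: 2}, {0: 1, 1: 5}],
#                                   [(0, 0), (1, 0), (1, 1)])
# ->  [1.0, 2.0, 10.0]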


class Bunch(dict):
    """Container object exposing keys as attributes

    Bunch objects are sometimes used as an output for functions and methods.
    They extend dictionaries by enabling values to be accessed by key,
    `bunch["value_key"]`, or by an attribute, `bunch.value_key`.

    Examples
    --------
    >>> b = Bunch(a=1, b=2)
    >>> b['b']
    2
    >>> b.b
    2
    >>> b.a = 3
    >>> b['a']
    3
    >>> b.c = 6
    >>> b['c']
    6
    """
    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __setattr__(self, key, value):
        self[key] = value

    def __dir__(self):
        return self.keys()

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setstate__(self, state):
        # Bunch pickles generated with scikit-learn 0.16.* have a
        # non-empty __dict__. This causes surprising behaviour when
        # loading these pickles in scikit-learn 0.17: reading bunch.key
        # uses __dict__ but assigning to bunch.key uses __setattr__ and
        # only changes bunch['key']. More details can be found at:
        # https://github.com/scikit-learn/scikit-learn/issues/6196.
        # Overriding __setstate__ to be a noop has the effect of
        # ignoring the pickled __dict__
        pass



# =============================================================================
# Types and constants
# =============================================================================

#DTYPE = Tree.DTYPE
#DOUBLE = Tree.DOUBLE

from numpy import float64 as DOUBLE

CRITERIA_CLF = {"gini": Gini, "entropy": Entropy}
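
# -----------------------------------------------------------------------------
# Illustration (not part of scikit-learn): what the two classification criteria
# measure, sketched in plain Python from per-class sample counts at a node.
# The helper names `_gini_sketch` and `_entropy_sketch` are hypothetical, and
# the base-2 logarithm for entropy is an assumption of this sketch. The real
# Gini and Entropy classes also handle sample weights and multiple outputs.
# -----------------------------------------------------------------------------
import math


def _gini_sketch(counts):
    """Gini impurity: 1 - sum_k p_k**2."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


def _entropy_sketch(counts):
    """Entropy: -sum_k p_k * log2(p_k), skipping empty classes."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)


# A 50/50 node is maximally impure for two classes; a pure node scores 0.
# _gini_sketch([5, 5])     ->  0.5
# _entropy_sketch([5, 5])  ->  1.0
# _gini_sketch([10, 0])    ->  0.0
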
DENSE_SPLITTERS = {"best": BestSplitter,
                   "random": RandomSplitter}

# =============================================================================
# Base decision tree
# =============================================================================


class BaseDecisionTree:
    """Base class for decision trees.

    Warning: This class should not be used directly.
    Use derived classes instead.
    """
    def __init__(self, *,
                 criterion,
                 splitter,
                 max_depth,
                 min_samples_split,
                 min_samples_leaf,
                 min_weight_fraction_leaf,
                 max_features,
                 max_leaf_nodes,
                 random_state,
                 min_impurity_decrease,
                 min_impurity_split,
                 class_weight=None,
                 presort='deprecated',
                 ccp_alpha=0.0):
        self.criterion = criterion
        self.splitter = splitter
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.min_weight_fraction_leaf = min_weight_fraction_leaf
        self.max_features = max_features
        self.max_leaf_nodes = max_leaf_nodes
        self.random_state = random_state
        self.min_impurity_decrease = min_impurity_decrease
        self.min_impurity_split = min_impurity_split
        self.class_weight = class_weight
        self.presort = presort
        self.ccp_alpha = ccp_alpha

    def get_depth(self):
        """Return the depth of the decision tree.

        The depth of a tree is the maximum distance between the root
        and any leaf.

        Returns
        -------
        self.tree_.max_depth : int
            The maximum depth of the tree.
        """
        return self.tree_.max_depth

    def get_n_leaves(self):
        """Return the number of leaves of the decision tree.

        Returns
        -------
        self.tree_.n_leaves : int
            Number of leaves.
        """
        return self.tree_.n_leaves

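    def fit(self, X, y, sample_weight=None, check_input=True,
            X_idx_sorted=None):
        """Build a decision tree classifier from the training set (X, y).

        ``check_input`` is accepted for API compatibility but is not used in
        this method; X is expected to be a dense 2-D array.
        """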

        random_state = check_random_state(self.random_state)

        if self.ccp_alpha < 0.0:
            raise ValueError("ccp_alpha must be greater than or equal to 0")

        # Determine output settings
        n_samples, self.n_features_ = X.shape
        is_classification = True

        y = np.atleast_1d(y)
        expanded_class_weight = None

        if y.ndim == 1:
            # reshape is necessary to preserve data contiguity, which
            # [:, np.newaxis] does not.
            y = np.reshape(y, (-1, 1))

        self.n_outputs_ = y.shape[1]

        if is_classification:
            y = np.copy(y)

            self.classes_ = []
            self.n_classes_ = []

            if self.class_weight is not None:
                y_original = np.copy(y)

            # The np.int alias was removed in NumPy 1.24; use the builtin int.
            y_encoded = np.zeros(y.shape, dtype=int)
            for k in range(self.n_outputs_):
                classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                                       return_inverse=True)
                self.classes_.append(classes_k)
                self.n_classes_.append(classes_k.shape[0])
            y = y_encoded

            if self.class_weight is not None:
                expanded_class_weight = compute_sample_weight(
                    self.class_weight, y_original)

            self.n_classes_ = np.array(self.n_classes_, dtype=np.intp)

        if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
            y = np.ascontiguousarray(y, dtype=DOUBLE)

        # Check parameters
        max_depth = (np.iinfo(np.int32).max if self.max_depth is None
                     else self.max_depth)
        max_leaf_nodes = (-1 if self.max_leaf_nodes is None
                          else self.max_leaf_nodes)

        if isinstance(self.min_samples_leaf, numbers.Integral):
            if not 1 <= self.min_samples_leaf:
                raise ValueError("min_samples_leaf must be at least 1 "
                                 "or in (0, 0.5], got %s"
                                 % self.min_samples_leaf)
            min_samples_leaf = self.min_samples_leaf
        else:  # float
            if not 0. < self.min_samples_leaf <= 0.5:
                raise ValueError("min_samples_leaf must be at least 1 "
                                 "or in (0, 0.5], got %s"
                                 % self.min_samples_leaf)
            min_samples_leaf = int(ceil(self.min_samples_leaf * n_samples))

        if isinstance(self.min_samples_split, numbers.Integral):
            if not 2 <= self.min_samples_split:
                raise ValueError("min_samples_split must be an integer "
                                 "greater than 1 or a float in (0.0, 1.0]; "
                                 "got the integer %s"
                                 % self.min_samples_split)
            min_samples_split = self.min_samples_split
        else:  # float
            if not 0. < self.min_samples_split <= 1.:
                raise ValueError("min_samples_split must be an integer "
                                 "greater than 1 or a float in (0.0, 1.0]; "
                                 "got the float %s"
                                 % self.min_samples_split)
            min_samples_split = int(ceil(self.min_samples_split * n_samples))
            min_samples_split = max(2, min_samples_split)

        min_samples_split = max(min_samples_split, 2 * min_samples_leaf)

        if isinstance(self.max_features, str):
            if self.max_features == "auto":
                if is_classification:
                    max_features = max(1, int(np.sqrt(self.n_features_)))
                else:
                    max_features = self.n_features_
            elif self.max_features == "sqrt":
                max_features = max(1, int(np.sqrt(self.n_features_)))
            elif self.max_features == "log2":
                max_features = max(1, int(np.log2(self.n_features_)))
            else:
                raise ValueError("Invalid value for max_features. "
                                 "Allowed string values are 'auto', "
                                 "'sqrt' or 'log2'.")
        elif self.max_features is None:
            max_features = self.n_features_
        elif isinstance(self.max_features, numbers.Integral):
            max_features = self.max_features
        else:  # float
            if self.max_features > 0.0:
                max_features = max(1,
                                   int(self.max_features * self.n_features_))
            else:
                max_features = 0

        self.max_features_ = max_features

        if len(y) != n_samples:
            raise ValueError("Number of labels=%d does not match "
                             "number of samples=%d" % (len(y), n_samples))
        if not 0 <= self.min_weight_fraction_leaf <= 0.5:
            raise ValueError("min_weight_fraction_leaf must be in [0, 0.5]")
        if max_depth <= 0:
            raise ValueError("max_depth must be greater than zero. ")
        if not (0 < max_features <= self.n_features_):
            raise ValueError("max_features must be in (0, n_features]")
        if not isinstance(max_leaf_nodes, numbers.Integral):
            raise ValueError("max_leaf_nodes must be integral number but was "
                             "%r" % max_leaf_nodes)
        if -1 < max_leaf_nodes < 2:
            raise ValueError(("max_leaf_nodes {0} must be either None "
                              "or larger than 1").format(max_leaf_nodes))

        if sample_weight is not None:
            sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)

        if expanded_class_weight is not None:
            if sample_weight is not None:
                sample_weight = sample_weight * expanded_class_weight
            else:
                sample_weight = expanded_class_weight

        # Set min_weight_leaf from min_weight_fraction_leaf
        if sample_weight is None:
            min_weight_leaf = (self.min_weight_fraction_leaf *
                               n_samples)
        else:
            min_weight_leaf = (self.min_weight_fraction_leaf *
                               np.sum(sample_weight))

        min_impurity_split = self.min_impurity_split
        if min_impurity_split is not None:
            warnings.warn("The min_impurity_split parameter is deprecated. "
                          "Its default value has changed from 1e-7 to 0 in "
                          "version 0.23, and it will be removed in 0.25. "
                          "Use the min_impurity_decrease parameter instead.",
                          FutureWarning)

            if min_impurity_split < 0.:
                raise ValueError("min_impurity_split must be greater than "
                                 "or equal to 0")
        else:
            min_impurity_split = 0

        if self.min_impurity_decrease < 0.:
            raise ValueError("min_impurity_decrease must be greater than "
                             "or equal to 0")

        if self.presort != 'deprecated':
            warnings.warn("The parameter 'presort' is deprecated and has no "
                          "effect. It will be removed in v0.24. You can "
                          "suppress this warning by not passing any value "
                          "to the 'presort' parameter.",
                          FutureWarning)

        # Build tree
        criterion = self.criterion
        if not isinstance(criterion, Criterion):
            if is_classification:
                criterion = CRITERIA_CLF[self.criterion](self.n_outputs_,
                                                         self.n_classes_)
        SPLITTERS = DENSE_SPLITTERS

        splitter = self.splitter
        if not isinstance(self.splitter, Splitter):
            splitter = SPLITTERS[self.splitter](criterion,
                                                self.max_features_,
                                                min_samples_leaf,
                                                min_weight_leaf,
                                                random_state)

        self.tree_ = Tree(self.n_features_,
                          self.n_classes_, self.n_outputs_)

        # Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
        if max_leaf_nodes < 0:
            builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                            min_samples_leaf,
                                            min_weight_leaf,
                                            max_depth,
                                            self.min_impurity_decrease,
                                            min_impurity_split)
        else:
            builder = BestFirstTreeBuilder(splitter, min_samples_split,
                                           min_samples_leaf,
                                           min_weight_leaf,
                                           max_depth,
                                           max_leaf_nodes,
                                           self.min_impurity_decrease,
                                           min_impurity_split)

        builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)

        if self.n_outputs_ == 1:
            self.n_classes_ = self.n_classes_[0]
            self.classes_ = self.classes_[0]

        self._prune_tree()
        return self

    def _validate_X_predict(self, X, check_input):
        """Validate X whenever one tries to predict, apply, predict_proba"""
        n_features = X.shape[1]
        if self.n_features_ != n_features:
            raise ValueError("Number of features of the model must "
                             "match the input. Model n_features is %s and "
                             "input n_features is %s "
                             % (self.n_features_, n_features))

        return X

    def predict(self, X, check_input=True):
        """Predict class or regression value for X.

        For a classification model, the predicted class for each sample in X is
        returned. For a regression model, the predicted value based on X is
        returned.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : bool, default=True
            Allow bypassing several input checks.
            Don't use this parameter unless you know what you are doing.

        Returns
        -------
        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The predicted classes, or the predict values.
        """
        X = self._validate_X_predict(X, check_input)
        proba = self.tree_.predict(X)
        n_samples = X.shape[0]

        # Classification
        if self.n_outputs_ == 1:
            return self.classes_.take(np.argmax(proba, axis=1), axis=0)
        else:
            class_type = self.classes_[0].dtype
            predictions = np.zeros((n_samples, self.n_outputs_),
                                   dtype=class_type)
            for k in range(self.n_outputs_):
                predictions[:, k] = self.classes_[k].take(
                    np.argmax(proba[:, k], axis=1),
                    axis=0)
            
            return predictions


    def apply(self, X, check_input=True):
        """Return the index of the leaf that each sample is predicted as.

        .. versionadded:: 0.17

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : bool, default=True
            Allow bypassing several input checks.
            Don't use this parameter unless you know what you are doing.

        Returns
        -------
        X_leaves : array-like of shape (n_samples,)
            For each datapoint x in X, return the index of the leaf x
            ends up in. Leaves are numbered within
            ``[0; self.tree_.node_count)``, possibly with gaps in the
            numbering.
        """
        X = self._validate_X_predict(X, check_input)
        return self.tree_.apply(X)

    def decision_path(self, X, check_input=True):
        """Return the decision path in the tree.

        .. versionadded:: 0.18

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : bool, default=True
            Allow bypassing several input checks.
            Don't use this parameter unless you know what you are doing.

        Returns
        -------
        indicator : sparse matrix of shape (n_samples, n_nodes)
            Return a node indicator CSR matrix where nonzero elements
            indicate that the sample goes through the corresponding node.
        """
        X = self._validate_X_predict(X, check_input)
        return self.tree_.decision_path(X)

    def _prune_tree(self):
        """Prune tree using Minimal Cost-Complexity Pruning."""
        if self.ccp_alpha < 0.0:
            raise ValueError("ccp_alpha must be greater than or equal to 0")

        if self.ccp_alpha == 0.0:
            return

        # build pruned tree
        n_classes = np.atleast_1d(self.n_classes_)
        pruned_tree = Tree(self.n_features_, n_classes, self.n_outputs_)
        _build_pruned_tree_ccp(pruned_tree, self.tree_, self.ccp_alpha)

        self.tree_ = pruned_tree

    def cost_complexity_pruning_path(self, X, y, sample_weight=None):
        """Compute the pruning path during Minimal Cost-Complexity Pruning.

        See :ref:`minimal_cost_complexity_pruning` for details on the pruning
        process.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csc_matrix``.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The target values (class labels) as integers or strings.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted. Splits
            that would create child nodes with net zero or negative weight are
            ignored while searching for a split in each node. Splits are also
            ignored if they would result in any single class carrying a
            negative weight in either child node.

        Returns
        -------
        ccp_path : :class:`~sklearn.utils.Bunch`
            Dictionary-like object, with the following attributes.

            ccp_alphas : ndarray
                Effective alphas of subtree during pruning.

            impurities : ndarray
                Sum of the impurities of the subtree leaves for the
                corresponding alpha value in ``ccp_alphas``.
        """
        est = clone(self).set_params(ccp_alpha=0.0)
        est.fit(X, y, sample_weight=sample_weight)
        return Bunch(**ccp_pruning_path(est.tree_))

    @property
    def feature_importances_(self):
        """Return the feature importances.

        The importance of a feature is computed as the (normalized) total
        reduction of the criterion brought by that feature.
        It is also known as the Gini importance.

        Warning: impurity-based feature importances can be misleading for
        high cardinality features (many unique values). See
        :func:`sklearn.inspection.permutation_importance` as an alternative.

        Returns
        -------
        feature_importances_ : ndarray of shape (n_features,)
            Normalized total reduction of criteria by feature
            (Gini importance).
        """
        return self.tree_.compute_feature_importances()
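

# -----------------------------------------------------------------------------
# Illustration (not part of scikit-learn): a small sketch of how fit() above
# turns `min_samples_split` and `min_samples_leaf` (int or float) into absolute
# sample counts. The helper name `_resolve_min_samples_sketch` is hypothetical
# and the range validation done in fit() is omitted for brevity.
# -----------------------------------------------------------------------------
import math
import numbers


def _resolve_min_samples_sketch(min_samples_split, min_samples_leaf,
                                n_samples):
    """Return (min_samples_split, min_samples_leaf) as integer counts."""
    if isinstance(min_samples_leaf, numbers.Integral):
        leaf = min_samples_leaf
    else:  # float: fraction of the training set, rounded up
        leaf = int(math.ceil(min_samples_leaf * n_samples))

    if isinstance(min_samples_split, numbers.Integral):
        split = min_samples_split
    else:  # float: fraction of the training set, at least 2
        split = max(2, int(math.ceil(min_samples_split * n_samples)))

    # A split must be able to produce two children of at least `leaf` samples.
    split = max(split, 2 * leaf)
    return split, leaf


# With 80 samples, min_samples_leaf=0.25 becomes 20, and min_samples_split
# (0.125 * 80 = 10) is raised to 2 * 20 = 40.
# _resolve_min_samples_sketch(0.125, 0.25, 80)  ->  (40, 20)
# _resolve_min_samples_sketch(2, 1, 80)         ->  (2, 1)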


# =============================================================================
# Public estimators
# =============================================================================
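
# Illustration (not part of scikit-learn): the weighted impurity decrease that
# `min_impurity_decrease` is compared against, following the formula given in
# the DecisionTreeClassifier docstring below. The helper name
# `_weighted_impurity_decrease_sketch` is hypothetical; all counts are plain
# (unweighted) sample counts in this sketch.
def _weighted_impurity_decrease_sketch(n_total, n_node, n_left, n_right,
                                       impurity, impurity_left,
                                       impurity_right):
    """N_t / N * (impurity - N_t_R / N_t * right - N_t_L / N_t * left)."""
    return (n_node / n_total) * (impurity
                                 - (n_right / n_node) * impurity_right
                                 - (n_left / n_node) * impurity_left)


# A node holding 8 of 16 samples with Gini impurity 0.5 that splits into two
# pure children of 4 samples each yields 8 / 16 * (0.5 - 0 - 0) = 0.25.
# _weighted_impurity_decrease_sketch(16, 8, 4, 4, 0.5, 0.0, 0.0)  ->  0.25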

class DecisionTreeClassifier(BaseDecisionTree):
    """A decision tree classifier.

    Read more in the :ref:`User Guide <tree>`.

    Parameters
    ----------
    criterion : {"gini", "entropy"}, default="gini"
        The function to measure the quality of a split. Supported criteria are
        "gini" for the Gini impurity and "entropy" for the information gain.

    splitter : {"best", "random"}, default="best"
        The strategy used to choose the split at each node. Supported
        strategies are "best" to choose the best split and "random" to choose
        the best random split.

    max_depth : int, default=None
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        min_samples_split samples.

    min_samples_split : int or float, default=2
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` are the minimum
          number of samples for each split.

        .. versionchanged:: 0.18
           Added float values for fractions.

    min_samples_leaf : int or float, default=1
        The minimum number of samples required to be at a leaf node.
        A split point at any depth will only be considered if it leaves at
        least ``min_samples_leaf`` training samples in each of the left and
        right branches.  This may have the effect of smoothing the model,
        especially in regression.

        - If int, then consider `min_samples_leaf` as the minimum number.
        - If float, then `min_samples_leaf` is a fraction and
          `ceil(min_samples_leaf * n_samples)` are the minimum
          number of samples for each node.

        .. versionchanged:: 0.18
           Added float values for fractions.

    min_weight_fraction_leaf : float, default=0.0
        The minimum weighted fraction of the sum total of weights (of all
        the input samples) required to be at a leaf node. Samples have
        equal weight when sample_weight is not provided.

    max_features : int, float or {"auto", "sqrt", "log2"}, default=None
        The number of features to consider when looking for the best split:

            - If int, then consider `max_features` features at each split.
            - If float, then `max_features` is a fraction and
              `int(max_features * n_features)` features are considered at each
              split.
            - If "auto", then `max_features=sqrt(n_features)`.
            - If "sqrt", then `max_features=sqrt(n_features)`.
            - If "log2", then `max_features=log2(n_features)`.
            - If None, then `max_features=n_features`.

        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if it requires to
        effectively inspect more than ``max_features`` features.

    random_state : int, RandomState instance, default=None
        Controls the randomness of the estimator. The features are always
        randomly permuted at each split, even if ``splitter`` is set to
        ``"best"``. When ``max_features < n_features``, the algorithm will
        select ``max_features`` at random at each split before finding the best
        split among them. But the best found split may vary across different
        runs, even if ``max_features=n_features``. That is the case, if the
        improvement of the criterion is identical for several splits and one
        split has to be selected at random. To obtain a deterministic behaviour
        during fitting, ``random_state`` has to be fixed to an integer.
        See :term:`Glossary <random_state>` for details.

    max_leaf_nodes : int, default=None
        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
        Best nodes are defined as relative reduction in impurity.
        If None then unlimited number of leaf nodes.

    min_impurity_decrease : float, default=0.0
        A node will be split if this split induces a decrease of the impurity
        greater than or equal to this value.

        The weighted impurity decrease equation is the following::

            N_t / N * (impurity - N_t_R / N_t * right_impurity
                                - N_t_L / N_t * left_impurity)

        where ``N`` is the total number of samples, ``N_t`` is the number of
        samples at the current node, ``N_t_L`` is the number of samples in the
        left child, and ``N_t_R`` is the number of samples in the right child.

        ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
        if ``sample_weight`` is passed.

        .. versionadded:: 0.19

    min_impurity_split : float, default=0
        Threshold for early stopping in tree growth. A node will split
        if its impurity is above the threshold, otherwise it is a leaf.

        .. deprecated:: 0.19
           ``min_impurity_split`` has been deprecated in favor of
           ``min_impurity_decrease`` in 0.19. The default value of
           ``min_impurity_split`` has changed from 1e-7 to 0 in 0.23 and it
           will be removed in 0.25. Use ``min_impurity_decrease`` instead.

    class_weight : dict, list of dict or "balanced", default=None
        Weights associated with classes in the form ``{class_label: weight}``.
        If None, all classes are supposed to have weight one. For
        multi-output problems, a list of dicts can be provided in the same
        order as the columns of y.

        Note that for multioutput (including multilabel) weights should be
        defined for each class of every column in its own dict. For example,
        for four-class multilabel classification weights should be
        [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
        [{1:1}, {2:5}, {3:1}, {4:1}].

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``

        For multi-output, the weights of each column of y will be multiplied.

        Note that these weights will be multiplied with sample_weight (passed
        through the fit method) if sample_weight is specified.

    presort : deprecated, default='deprecated'
        This parameter is deprecated and will be removed in v0.24.

        .. deprecated:: 0.22

    ccp_alpha : non-negative float, default=0.0
        Complexity parameter used for Minimal Cost-Complexity Pruning. The
        subtree with the largest cost complexity that is smaller than
        ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
        :ref:`minimal_cost_complexity_pruning` for details.

        .. versionadded:: 0.22

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,) or list of ndarray
        The class labels (single output problem),
        or a list of arrays of class labels (multi-output problem).

    feature_importances_ : ndarray of shape (n_features,)
        The impurity-based feature importances.
        The higher, the more important the feature.
        The importance of a feature is computed as the (normalized)
        total reduction of the criterion brought by that feature.  It is also
        known as the Gini importance [4]_.

        Warning: impurity-based feature importances can be misleading for
        high cardinality features (many unique values). See
        :func:`sklearn.inspection.permutation_importance` as an alternative.

    max_features_ : int
        The inferred value of max_features.

    n_classes_ : int or list of int
        The number of classes (for single output problems),
        or a list containing the number of classes for each
        output (for multi-output problems).

    n_features_ : int
        The number of features when ``fit`` is performed.

    n_outputs_ : int
        The number of outputs when ``fit`` is performed.

    tree_ : Tree
        The underlying Tree object. Please refer to
        ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and
        :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
        for basic usage of these attributes.

    See Also
    --------
    DecisionTreeRegressor : A decision tree regressor.

    Notes
    -----
    The default values for the parameters controlling the size of the trees
    (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
    unpruned trees which can potentially be very large on some data sets. To
    reduce memory consumption, the complexity and size of the trees should be
    controlled by setting those parameter values.

    References
    ----------

    .. [1] https://en.wikipedia.org/wiki/Decision_tree_learning

    .. [2] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification
           and Regression Trees", Wadsworth, Belmont, CA, 1984.

    .. [3] T. Hastie, R. Tibshirani and J. Friedman. "Elements of Statistical
           Learning", Springer, 2009.

    .. [4] L. Breiman, and A. Cutler, "Random Forests",
           https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import cross_val_score
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> clf = DecisionTreeClassifier(random_state=0)
    >>> iris = load_iris()
    >>> cross_val_score(clf, iris.data, iris.target, cv=10)
    ...                             # doctest: +SKIP
    ...
    array([ 1.     ,  0.93...,  0.86...,  0.93...,  0.93...,
            0.93...,  0.93...,  1.     ,  0.93...,  1.      ])
    """
    def __init__(self, *,
                 criterion="gini",
                 splitter="best",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 min_weight_fraction_leaf=0.,
                 max_features=None,
                 random_state=None,
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 class_weight=None,
                 presort='deprecated',
                 ccp_alpha=0.0):
        super().__init__(
            criterion=criterion,
            splitter=splitter,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            class_weight=class_weight,
            random_state=random_state,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            presort=presort,
            ccp_alpha=ccp_alpha)

    def fit(self, X, y, sample_weight=None, check_input=True,
            X_idx_sorted=None):
        """Build a decision tree classifier from the training set (X, y).

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The training input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csc_matrix``.

        y : array-like of shape (n_samples,) or (n_samples, n_outputs)
            The target values (class labels) as integers or strings.

        sample_weight : array-like of shape (n_samples,), default=None
            Sample weights. If None, then samples are equally weighted. Splits
            that would create child nodes with net zero or negative weight are
            ignored while searching for a split in each node. Splits are also
            ignored if they would result in any single class carrying a
            negative weight in either child node.

        check_input : bool, default=True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you are doing.

        X_idx_sorted : array-like of shape (n_samples, n_features), \
                default=None
            The indexes of the sorted training input samples. If many trees
            are grown on the same dataset, this allows the ordering to be
            cached between trees. If None, the data will be sorted here.
            Don't use this parameter unless you know what you are doing.

        Returns
        -------
        self : DecisionTreeClassifier
            Fitted estimator.
        """
        super().fit(
            X, y,
            sample_weight=sample_weight,
            check_input=check_input,
            X_idx_sorted=X_idx_sorted)
        return self

    def predict_proba(self, X, check_input=True):
        """Predict class probabilities of the input samples X.

        The predicted class probability is the fraction of samples of the same
        class in a leaf.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        check_input : bool, default=True
            Allows bypassing several input checks.
            Don't use this parameter unless you know what you are doing.

        Returns
        -------
        proba : ndarray of shape (n_samples, n_classes) or list of n_outputs \
            such arrays if n_outputs > 1
            The class probabilities of the input samples. The order of the
            classes corresponds to that in the attribute :term:`classes_`.
        """
        X = self._validate_X_predict(X, check_input)
        proba = self.tree_.predict(X)

        if self.n_outputs_ == 1:
            proba = proba[:, :self.n_classes_]
            normalizer = proba.sum(axis=1)[:, np.newaxis]
            normalizer[normalizer == 0.0] = 1.0
            proba /= normalizer

            return proba

        else:
            all_proba = []

            for k in range(self.n_outputs_):
                proba_k = proba[:, k, :self.n_classes_[k]]
                normalizer = proba_k.sum(axis=1)[:, np.newaxis]
                normalizer[normalizer == 0.0] = 1.0
                proba_k /= normalizer
                all_proba.append(proba_k)

            return all_proba

    def predict_log_proba(self, X):
        """Predict class log-probabilities of the input samples X.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            The input samples. Internally, it will be converted to
            ``dtype=np.float32`` and if a sparse matrix is provided
            to a sparse ``csr_matrix``.

        Returns
        -------
        proba : ndarray of shape (n_samples, n_classes) or list of n_outputs \
            such arrays if n_outputs > 1
            The class log-probabilities of the input samples. The order of the
            classes corresponds to that in the attribute :term:`classes_`.
        """
        proba = self.predict_proba(X)

        if self.n_outputs_ == 1:
            return np.log(proba)

        else:
            for k in range(self.n_outputs_):
                proba[k] = np.log(proba[k])

            return proba


class ExtraTreeClassifier(DecisionTreeClassifier):
    """An extremely randomized tree classifier.

    Extra-trees differ from classic decision trees in the way they are built.
    When looking for the best split to separate the samples of a node into two
    groups, random splits are drawn for each of the `max_features` randomly
    selected features and the best split among those is chosen. When
    `max_features` is set to 1, this amounts to building a totally random
    decision tree.

    Warning: Extra-trees should only be used within ensemble methods.

    Read more in the :ref:`User Guide <tree>`.

    Parameters
    ----------
    criterion : {"gini", "entropy"}, default="gini"
        The function to measure the quality of a split. Supported criteria are
        "gini" for the Gini impurity and "entropy" for the information gain.

    splitter : {"random", "best"}, default="random"
        The strategy used to choose the split at each node. Supported
        strategies are "best" to choose the best split and "random" to choose
        the best random split.

    max_depth : int, default=None
        The maximum depth of the tree. If None, then nodes are expanded until
        all leaves are pure or until all leaves contain less than
        min_samples_split samples.

    min_samples_split : int or float, default=2
        The minimum number of samples required to split an internal node:

        - If int, then consider `min_samples_split` as the minimum number.
        - If float, then `min_samples_split` is a fraction and
          `ceil(min_samples_split * n_samples)` are the minimum
          number of samples for each split.

        .. versionchanged:: 0.18
           Added float values for fractions.

    min_samples_leaf : int or float, default=1
        The minimum number of samples required to be at a leaf node.
        A split point at any depth will only be considered if it leaves at
        least ``min_samples_leaf`` training samples in each of the left and
        right branches.  This may have the effect of smoothing the model,
        especially in regression.

        - If int, then consider `min_samples_leaf` as the minimum number.
        - If float, then `min_samples_leaf` is a fraction and
          `ceil(min_samples_leaf * n_samples)` are the minimum
          number of samples for each node.

        .. versionchanged:: 0.18
           Added float values for fractions.

    min_weight_fraction_leaf : float, default=0.0
        The minimum weighted fraction of the sum total of weights (of all
        the input samples) required to be at a leaf node. Samples have
        equal weight when sample_weight is not provided.

    max_features : int, float, {"auto", "sqrt", "log2"} or None, default="auto"
        The number of features to consider when looking for the best split:

            - If int, then consider `max_features` features at each split.
            - If float, then `max_features` is a fraction and
              `int(max_features * n_features)` features are considered at each
              split.
            - If "auto", then `max_features=sqrt(n_features)`.
            - If "sqrt", then `max_features=sqrt(n_features)`.
            - If "log2", then `max_features=log2(n_features)`.
            - If None, then `max_features=n_features`.

        Note: the search for a split does not stop until at least one
        valid partition of the node samples is found, even if that means
        effectively inspecting more than ``max_features`` features.

    random_state : int, RandomState instance, default=None
        Used to pick randomly the `max_features` used at each split.
        See :term:`Glossary <random_state>` for details.

    max_leaf_nodes : int, default=None
        Grow a tree with ``max_leaf_nodes`` in best-first fashion.
        Best nodes are defined as relative reduction in impurity.
        If None then unlimited number of leaf nodes.

    min_impurity_decrease : float, default=0.0
        A node will be split if this split induces a decrease of the impurity
        greater than or equal to this value.

        The weighted impurity decrease equation is the following::

            N_t / N * (impurity - N_t_R / N_t * right_impurity
                                - N_t_L / N_t * left_impurity)

        where ``N`` is the total number of samples, ``N_t`` is the number of
        samples at the current node, ``N_t_L`` is the number of samples in the
        left child, and ``N_t_R`` is the number of samples in the right child.

        ``N``, ``N_t``, ``N_t_R`` and ``N_t_L`` all refer to the weighted sum,
        if ``sample_weight`` is passed.

        .. versionadded:: 0.19

    min_impurity_split : float, default=None
        Threshold for early stopping in tree growth. A node will split
        if its impurity is above the threshold, otherwise it is a leaf.

        .. deprecated:: 0.19
           ``min_impurity_split`` has been deprecated in favor of
           ``min_impurity_decrease`` in 0.19. The default value of
           ``min_impurity_split`` has changed from 1e-7 to 0 in 0.23 and it
           will be removed in 0.25. Use ``min_impurity_decrease`` instead.

    class_weight : dict, list of dict or "balanced", default=None
        Weights associated with classes in the form ``{class_label: weight}``.
        If None, all classes are supposed to have weight one. For
        multi-output problems, a list of dicts can be provided in the same
        order as the columns of y.

        Note that for multioutput (including multilabel) weights should be
        defined for each class of every column in its own dict. For example,
        for four-class multilabel classification weights should be
        [{0: 1, 1: 1}, {0: 1, 1: 5}, {0: 1, 1: 1}, {0: 1, 1: 1}] instead of
        [{1:1}, {2:5}, {3:1}, {4:1}].

        The "balanced" mode uses the values of y to automatically adjust
        weights inversely proportional to class frequencies in the input data
        as ``n_samples / (n_classes * np.bincount(y))``

        For multi-output, the weights of each column of y will be multiplied.

        Note that these weights will be multiplied with sample_weight (passed
        through the fit method) if sample_weight is specified.

    ccp_alpha : non-negative float, default=0.0
        Complexity parameter used for Minimal Cost-Complexity Pruning. The
        subtree with the largest cost complexity that is smaller than
        ``ccp_alpha`` will be chosen. By default, no pruning is performed. See
        :ref:`minimal_cost_complexity_pruning` for details.

        .. versionadded:: 0.22

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,) or list of ndarray
        The class labels (single output problem),
        or a list of arrays of class labels (multi-output problem).

    max_features_ : int
        The inferred value of max_features.

    n_classes_ : int or list of int
        The number of classes (for single output problems),
        or a list containing the number of classes for each
        output (for multi-output problems).

    feature_importances_ : ndarray of shape (n_features,)
        The impurity-based feature importances.
        The higher, the more important the feature.
        The importance of a feature is computed as the (normalized)
        total reduction of the criterion brought by that feature.  It is also
        known as the Gini importance.

        Warning: impurity-based feature importances can be misleading for
        high cardinality features (many unique values). See
        :func:`sklearn.inspection.permutation_importance` as an alternative.

    n_features_ : int
        The number of features when ``fit`` is performed.

    n_outputs_ : int
        The number of outputs when ``fit`` is performed.

    tree_ : Tree
        The underlying Tree object. Please refer to
        ``help(sklearn.tree._tree.Tree)`` for attributes of Tree object and
        :ref:`sphx_glr_auto_examples_tree_plot_unveil_tree_structure.py`
        for basic usage of these attributes.

    See Also
    --------
    ExtraTreeRegressor : An extremely randomized tree regressor.
    sklearn.ensemble.ExtraTreesClassifier : An extra-trees classifier.
    sklearn.ensemble.ExtraTreesRegressor : An extra-trees regressor.

    Notes
    -----
    The default values for the parameters controlling the size of the trees
    (e.g. ``max_depth``, ``min_samples_leaf``, etc.) lead to fully grown and
    unpruned trees which can potentially be very large on some data sets. To
    reduce memory consumption, the complexity and size of the trees should be
    controlled by setting those parameter values.

    References
    ----------

    .. [1] P. Geurts, D. Ernst., and L. Wehenkel, "Extremely randomized trees",
           Machine Learning, 63(1), 3-42, 2006.

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.ensemble import BaggingClassifier
    >>> from sklearn.tree import ExtraTreeClassifier
    >>> X, y = load_iris(return_X_y=True)
    >>> X_train, X_test, y_train, y_test = train_test_split(
    ...    X, y, random_state=0)
    >>> extra_tree = ExtraTreeClassifier(random_state=0)
    >>> cls = BaggingClassifier(extra_tree, random_state=0).fit(
    ...    X_train, y_train)
    >>> cls.score(X_test, y_test)
    0.8947...
    """
    def __init__(self, *,
                 criterion="gini",
                 splitter="random",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 min_weight_fraction_leaf=0.,
                 max_features="auto",
                 random_state=None,
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 class_weight=None,
                 ccp_alpha=0.0):
        super().__init__(
            criterion=criterion,
            splitter=splitter,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            class_weight=class_weight,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            random_state=random_state,
            ccp_alpha=ccp_alpha)
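
The normalization step inside `predict_proba` above can be tried on its own. The following is a minimal sketch, assuming the tree has already returned one row of per-class (weighted) sample counts for the leaf each input sample lands in; the helper name `normalize_leaf_counts` and the sample counts are hypothetical and only serve the illustration.

```python
import numpy as np


def normalize_leaf_counts(counts):
    """Convert per-sample leaf class counts into class probabilities.

    Hypothetical helper mirroring the single-output branch of
    predict_proba: each row is divided by its sum, and rows that sum
    to zero are left as zeros instead of producing 0/0.
    """
    counts = np.asarray(counts, dtype=np.float64)
    normalizer = counts.sum(axis=1)[:, np.newaxis]
    normalizer[normalizer == 0.0] = 1.0  # guard against division by zero
    return counts / normalizer


# Made-up counts: three input samples, three classes.
leaf_counts = [[3.0, 1.0, 0.0],
               [0.0, 0.0, 5.0],
               [0.0, 0.0, 0.0]]
proba = normalize_leaf_counts(leaf_counts)
print(proba)
```

Each row of `proba` is the fraction of training samples of each class in the corresponding leaf, which is what the docstring of `predict_proba` describes. The multi-output branch applies the same step once per output column.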

 

Put all of the code files above into a single folder named DT, then place the test file below next to the DT folder, and it is ready to run.

dtree_test.py

from DT.DecisionTree import DecisionTreeClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

print("Number of mislabeled points out of a total %d points : %d"
      % (X_test.shape[0], (y_test != y_pred).sum()))

 

The output is as follows:

 

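Note that, because of limited time, not all of the code has been checked. Only the parts exercised in the simplest case were modified, to make sure the program runs and that its result matches the original scikit-learn algorithm. Classification-related code was kept where appropriate, but all regression-related code was removed. Training with other parameters may raise errors or produce wrong results, so check and adjust the code before relying on it.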
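
One way to check the agreement with scikit-learn is to run the library's own `DecisionTreeClassifier` on the same split and compare predictions. The sketch below only runs the scikit-learn side. The comparison with the pure-Python port is left as comments because it assumes the DT package from this post is importable. Fixing `random_state=0` on the classifier is an assumption added here so that ties between equally good splits are broken reproducibly; the test script above does not set it.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier as SkDecisionTreeClassifier

# Same data and split as in dtree_test.py.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=0)

# Reference model from scikit-learn itself.
ref = SkDecisionTreeClassifier(random_state=0)
ref.fit(X_train, y_train)
y_ref = ref.predict(X_test)

n_wrong = int((y_test != y_ref).sum())
print("scikit-learn reference: %d of %d points mislabeled"
      % (n_wrong, X_test.shape[0]))

# Comparison with the port (assumes the DT folder is on the import path):
# from DT.DecisionTree import DecisionTreeClassifier
# port = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
# print("predictions identical:", (port.predict(X_test) == y_ref).all())
```

Repeating the comparison with other settings, such as a finite `max_depth` or `criterion="entropy"`, is a simple way to find the parameter combinations the port does not yet handle.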
