module mltree.tree_structure#

Short summary#

module mlinsights.mltree.tree_structure

Helpers to investigate a tree structure.

source on GitHub

Functions#

function – truncated documentation

  • _get_tree – Returns the tree object.

  • predict_leaves – Returns the leaf each observation of X falls into.

  • tree_find_common_node – Finds the common ancestor of nodes i and j.

  • tree_find_path_to_root – Lists the nodes on the path from the root to node i.

  • tree_leave_index – Returns the indices of every leaf in a tree.

  • tree_leave_neighbors – The function determines which leaves are neighbors. The method uses some memory as it creates a grid of …

  • tree_node_parents – Returns a dictionary {node_id: parent_id}.

  • tree_node_range – Determines the range of a node in every dimension. nan means infinity.

Documentation#

mlinsights.mltree.tree_structure._get_tree(obj)#

Returns the tree object.

mlinsights.mltree.tree_structure.predict_leaves(model, X)#

Returns the leaf each observation of X falls into.

Parameters:
  • model – a decision tree

  • X – observations

Returns:

array of leaves
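
To make the idea concrete, here is a minimal sketch of how an observation can be routed to its leaf. It does not use mlinsights: the arrays are hand-written and only mimic the layout of scikit-learn's `tree_` attributes (`children_left`, `children_right`, `feature`, `threshold`), and the helper `find_leaf` is a hypothetical name introduced for this illustration.

```python
# Hand-written arrays mimicking the layout of scikit-learn's tree_ attributes.
# children_left[i] == -1 marks node i as a leaf.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 1.5, -2.0, -2.0]


def find_leaf(x):
    """Follows the splits from the root (node 0) until a leaf is reached."""
    node = 0
    while children_left[node] != -1:
        # Same rule as scikit-learn: go left when x[feature] <= threshold.
        if x[feature[node]] <= threshold[node]:
            node = children_left[node]
        else:
            node = children_right[node]
    return node


X = [[0.0, 0.0], [1.0, 1.0], [1.0, 2.0]]
leaves = [find_leaf(x) for x in X]
print(leaves)  # [1, 3, 4]
```

With a fitted scikit-learn decision tree, the method `apply(X)` gives the leaf index of every observation; whether predict_leaves relies on it is not stated here.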

mlinsights.mltree.tree_structure.tree_find_common_node(tree, i, j, parents=None)#

Finds the common ancestor of nodes i and j.

Parameters:
  • tree – tree

  • i – node index (tree.nodes[i])

  • j – node index (tree.nodes[j])

  • parents – precomputed parents (None -> calls tree_node_parents)

Returns:

common root, remaining path to i, remaining path to j
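
The returned triple can be illustrated with a small sketch, assuming a parents dictionary of the form {node_id: parent_id}. The tree and the helper names `path_from_root` and `common_node` are hypothetical and written for this illustration only; this is not the library's implementation.

```python
# Hypothetical tree: node 0 is the root, nodes 1 and 2 are its children,
# nodes 3 and 4 are the children of node 2.
parents = {1: 0, 2: 0, 3: 2, 4: 2}


def path_from_root(i):
    """Lists the nodes from the root down to node i."""
    path = [i]
    while path[-1] in parents:
        path.append(parents[path[-1]])
    return path[::-1]


def common_node(i, j):
    """Returns (common node, remaining path to i, remaining path to j)."""
    path_i, path_j = path_from_root(i), path_from_root(j)
    k = 0
    # Both paths start at the root, so at least one node is shared.
    while k < min(len(path_i), len(path_j)) and path_i[k] == path_j[k]:
        k += 1
    return path_i[k - 1], path_i[k:], path_j[k:]


print(common_node(3, 4))  # (2, [3], [4])
print(common_node(1, 4))  # (0, [1], [2, 4])
```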

mlinsights.mltree.tree_structure.tree_find_path_to_root(tree, i, parents=None)#

Lists the nodes on the path from the root to node i.

Parameters:
  • tree – tree

  • i – node index (tree.nodes[i])

  • parents – precomputed parents (None -> calls tree_node_parents)

Returns:

the list of node indices on the path
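
As a rough illustration, the path can be rebuilt by following a {node_id: parent_id} dictionary upward until a node without parent (the root) is reached. The tree and the helper `path_to_node` are hypothetical, and listing the root first is a choice made for this sketch.

```python
# Hypothetical tree: node 0 is the root, nodes 1 and 2 are its children,
# nodes 3 and 4 are the children of node 2.
parents = {1: 0, 2: 0, 3: 2, 4: 2}


def path_to_node(i, parents):
    """Walks from node i up to the root, then reverses the result."""
    path = [i]
    current = i
    while current in parents:  # the root has no entry in parents
        current = parents[current]
        path.append(current)
    return list(reversed(path))


path = path_to_node(4, parents)
print(path)  # [0, 2, 4]
```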

mlinsights.mltree.tree_structure.tree_leave_index(model)#

Returns the indices of every leaf in a tree.

Parameters:

model – something which has a member tree_

Returns:

leaf indices
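
A minimal sketch of the idea, assuming the scikit-learn convention that a node i is a leaf when `tree_.children_left[i]` equals -1. The array below is hand-written for illustration and the code does not call mlinsights.

```python
# Hand-written stand-in for tree_.children_left: -1 marks a leaf.
children_left = [1, -1, 3, -1, -1]

# A node is a leaf when it has no left child.
leaf_indices = [i for i, child in enumerate(children_left) if child == -1]
print(leaf_indices)  # [1, 3, 4]
```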

mlinsights.mltree.tree_structure.tree_leave_neighbors(model)#

The function determines which leaves are neighbors. The method uses some memory as it creates a grid of the feature space; each split multiplies the number of cells by two.

Parameters:

model – a sklearn.tree.DecisionTreeRegressor, a sklearn.tree.DecisionTreeClassifier, or any model which has a member tree_

Returns:

a dictionary {(i, j): (dimension, x1, x2)} where i and j are node indices and i < j. If X_d * sign < th * sign, the observation goes to node i, otherwise to node j. The border lies somewhere in the segment [x1, x2].

The following example shows what the function returns in the case of a simple grid in two dimensions.

<<<

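import pprint

import numpy
from sklearn.tree import DecisionTreeClassifier
from mlinsights.mltree import tree_leave_neighbors

# Nine points on a 3x3 grid, each with its own class,
# so that the fitted tree isolates every point in its own leaf.
X = numpy.array([[0, 0], [0, 1], [0, 2],
                 [1, 0], [1, 1], [1, 2],
                 [2, 0], [2, 1], [2, 2]])
y = list(range(X.shape[0]))
clr = DecisionTreeClassifier(max_depth=4)
clr.fit(X, y)

# Maps each pair of neighboring leaves to the border between them.
nei = tree_leave_neighbors(clr)
pprint.pprint(nei)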

>>>

    somewhere/workspace/mlinsights/mlinsights_UT_39_std/_venv/lib/python3.9/site-packages/sklearn/utils/deprecation.py:103: FutureWarning: The attribute `n_features_` is deprecated in 1.0 and will be removed in 1.2. Use `n_features_in_` instead.
      warnings.warn(msg, category=FutureWarning)
    {(2, 4): [(0, (0.0, 0.0), (1.0, 0.0))],
     (2, 8): [(1, (0.0, 0.0), (0.0, 1.0))],
     (4, 5): [(0, (1.0, 0.0), (2.0, 0.0))],
     (4, 12): [(1, (1.0, 0.0), (1.0, 1.0))],
     (5, 13): [(1, (2.0, 0.0), (2.0, 1.0))],
     (8, 9): [(1, (0.0, 1.0), (0.0, 2.0))],
     (8, 12): [(0, (0.0, 1.0), (1.0, 1.0))],
     (9, 15): [(0, (0.0, 2.0), (1.0, 2.0))],
     (12, 13): [(0, (1.0, 1.0), (2.0, 1.0))],
     (12, 15): [(1, (1.0, 1.0), (1.0, 2.0))],
     (13, 16): [(1, (2.0, 1.0), (2.0, 2.0))],
     (15, 16): [(0, (1.0, 2.0), (2.0, 2.0))]}
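
The dictionary can be turned into an adjacency list of leaves. The sketch below copies the output printed above into a literal (it may differ with other versions of scikit-learn or mlinsights) and does not call the library; the name `adjacency` is introduced for this illustration.

```python
from collections import defaultdict

# Output copied from the example above: leaf pairs -> list of borders.
nei = {(2, 4): [(0, (0.0, 0.0), (1.0, 0.0))],
       (2, 8): [(1, (0.0, 0.0), (0.0, 1.0))],
       (4, 5): [(0, (1.0, 0.0), (2.0, 0.0))],
       (4, 12): [(1, (1.0, 0.0), (1.0, 1.0))],
       (5, 13): [(1, (2.0, 0.0), (2.0, 1.0))],
       (8, 9): [(1, (0.0, 1.0), (0.0, 2.0))],
       (8, 12): [(0, (0.0, 1.0), (1.0, 1.0))],
       (9, 15): [(0, (0.0, 2.0), (1.0, 2.0))],
       (12, 13): [(0, (1.0, 1.0), (2.0, 1.0))],
       (12, 15): [(1, (1.0, 1.0), (1.0, 2.0))],
       (13, 16): [(1, (2.0, 1.0), (2.0, 2.0))],
       (15, 16): [(0, (1.0, 2.0), (2.0, 2.0))]}

# Each key (i, j) says that leaves i and j share a border.
adjacency = defaultdict(set)
for i, j in nei:
    adjacency[i].add(j)
    adjacency[j].add(i)

# Leaf 12 holds the center of the grid and touches four other leaves.
print(sorted(adjacency[12]))  # [4, 8, 13, 15]
```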

mlinsights.mltree.tree_structure.tree_node_parents(tree)#

Returns a dictionary {node_id: parent_id}.

Parameters:

tree – tree

Returns:

parents
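
A minimal sketch of how such a dictionary can be derived, assuming arrays laid out like scikit-learn's `tree_.children_left` and `tree_.children_right` (where -1 marks a leaf). The arrays are hand-written for illustration; this is not the library's code.

```python
# Hand-written stand-ins for tree_.children_left and tree_.children_right.
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]

parents = {}
for node, (left, right) in enumerate(zip(children_left, children_right)):
    if left != -1:  # internal node: register both children
        parents[left] = node
        parents[right] = node

# The root (node 0) has no parent and therefore no entry.
print(parents)  # {1: 0, 2: 0, 3: 2, 4: 2}
```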

mlinsights.mltree.tree_structure.tree_node_range(tree, i, parents=None)#

Determines the range of a node in every dimension. nan means infinity, that is, no bound on that side.

Parameters:
  • tree – tree

  • i – node index (tree.nodes[i])

  • parents – precomputed parents (None -> calls tree_node_parents)

Returns:

one array of size (D, 2) where D is the number of dimensions

The following example shows what the function returns in the case of a simple grid in two dimensions.

<<<

import numpy
from sklearn.tree import DecisionTreeClassifier
from mlinsights.mltree import tree_leave_index, tree_node_range

# Nine points on a 3x3 grid, each with its own class.
X = numpy.array([[0, 0], [0, 1], [0, 2],
                 [1, 0], [1, 1], [1, 2],
                 [2, 0], [2, 1], [2, 2]])
y = list(range(X.shape[0]))
clr = DecisionTreeClassifier(max_depth=4)
clr.fit(X, y)

# Range of the first leaf: one row per dimension.
leaves = tree_leave_index(clr)
ra = tree_node_range(clr, leaves[0])

print(ra)

>>>

    [[nan 0.5]
     [nan 0.5]]
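
The way such a range arises can be sketched by walking down from the root and tightening one bound at every split. The code below is only an illustration under assumptions: the arrays are hand-written and mimic scikit-learn's `tree_` layout, the helper `node_range` is hypothetical, and each row is read as [lower, upper].

```python
import math

# Hand-written arrays mimicking scikit-learn's tree_ attributes (-1 = leaf).
children_left = [1, -1, 3, -1, -1]
children_right = [2, -1, 4, -1, -1]
feature = [0, -2, 1, -2, -2]
threshold = [0.5, -2.0, 1.5, -2.0, -2.0]
n_features = 2


def node_range(target):
    """Returns one [lower, upper] pair per feature; nan means unbounded."""
    def visit(node, bounds):
        if node == target:
            return bounds
        if children_left[node] == -1:
            return None  # reached another leaf, target is not below
        f, th = feature[node], threshold[node]
        left = [list(b) for b in bounds]
        left[f][1] = th  # left branch: x[f] <= th
        found = visit(children_left[node], left)
        if found is not None:
            return found
        right = [list(b) for b in bounds]
        right[f][0] = th  # right branch: x[f] > th
        return visit(children_right[node], right)

    return visit(0, [[math.nan, math.nan] for _ in range(n_features)])


ra = node_range(3)
print(ra)  # [[0.5, nan], [nan, 1.5]]
```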
