module mltree.tree_digitize

Short summary

module mlinsights.mltree.tree_digitize

Helpers to investigate a tree structure.

Functions

function | truncated documentation
digitize2tree | Builds a decision tree which returns the same result as lambda x: numpy.digitize(x, bins, right=right) (see numpy.digitize). …

Documentation

New in version 0.4.

mlinsights.mltree.tree_digitize.digitize2tree(bins, right=False)

Builds a decision tree which returns the same result as lambda x: numpy.digitize(x, bins, right=right) (see numpy.digitize).

Parameters
  • bins – array of bins. It has to be 1-dimensional and monotonic.

  • right – Indicates whether the intervals include the right or the left bin edge. The default (right=False) means that each interval includes its left edge and excludes its right edge, i.e., bins[i-1] <= x < bins[i] for monotonically increasing bins. With right=True, the condition becomes bins[i-1] < x <= bins[i].

Returns

decision tree

Note

The decision tree implementation in scikit-learn supports only one type of decision rule (<=). That is why the function raises an exception when right=False. This limitation could be overcome with ONNX, where all kinds of decision rules are implemented. The default value of right is still False to follow the numpy API, even though this value raises an exception in digitize2tree.
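
To make the difference between the two settings of right concrete, the short sketch below calls numpy.digitize directly; it does not use mlinsights. The bin edges and sample values are chosen only for illustration. A value equal to a bin edge falls into the lower interval when right=True, which is the behaviour a `<=` test can reproduce.

```python
import numpy

# Illustrative bin edges and values; several values sit exactly on an edge.
bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
x = numpy.array([-1.0, 0.0, 0.5, 1.0, 7.0, 8.0])

# right=False: index i such that bins[i-1] <= x < bins[i]
left_closed = numpy.digitize(x, bins, right=False)

# right=True: index i such that bins[i-1] < x <= bins[i]
right_closed = numpy.digitize(x, bins, right=True)

print("right=False:", left_closed)
print("right=True: ", right_closed)
```

The two results differ only for the values that coincide with a bin edge (0.0, 1.0 and 7.0 here).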

The following example shows what the tree looks like.

<<<

import numpy
from sklearn.tree import export_text
from mlinsights.mltree import digitize2tree

x = numpy.array([0.2, 6.4, 3.0, 1.6])
bins = numpy.array([0.0, 1.0, 2.5, 4.0, 7.0])
expected = numpy.digitize(x, bins, right=True)
tree = digitize2tree(bins, right=True)
pred = tree.predict(x.reshape((-1, 1)))
print("Comparison with numpy:")
print(expected, pred)
print("Tree:")
print(export_text(tree, feature_names=['x']))

>>>

    Comparison with numpy:
    [1 4 3 2] [1. 4. 3. 2.]
    Tree:
    |--- x <= 2.50
    |   |--- x <= 1.00
    |   |   |--- x <= 0.00
    |   |   |   |--- value: [0.00]
    |   |   |--- x >  0.00
    |   |   |   |--- value: [1.00]
    |   |--- x >  1.00
    |   |   |--- value: [2.00]
    |--- x >  2.50
    |   |--- x <= 4.00
    |   |   |--- x <= 2.50
    |   |   |   |--- value: [2.00]
    |   |   |--- x >  2.50
    |   |   |   |--- value: [3.00]
    |   |--- x >  4.00
    |   |   |--- x <= 7.00
    |   |   |   |--- x <= 4.00
    |   |   |   |   |--- value: [3.00]
    |   |   |   |--- x >  4.00
    |   |   |   |   |--- value: [4.00]
    |   |   |--- x >  7.00
    |   |   |   |--- value: [5.00]
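
The printed tree can be read as a cascade of `x <= threshold` tests. The sketch below is a hypothetical pure-Python illustration of that idea, not the code mlinsights uses: the helper names build_rules and apply_rules are invented here, and the structure it builds is a plain binary search without the repeated thresholds visible in the output above. It assumes bins is sorted in increasing order, in which case bisect.bisect_left returns the same index as numpy.digitize with right=True.

```python
from bisect import bisect_left


def build_rules(bins, lo=0, hi=None):
    """Hypothetical helper: nested (threshold, left, right) tuples.

    A leaf is a plain int, the bin index to return.
    The answer for any x is known to lie in [lo, hi].
    """
    if hi is None:
        hi = len(bins)
    if lo == hi:
        return lo  # only one possible index left: this is a leaf
    mid = (lo + hi) // 2
    # x <= bins[mid] means the index is at most mid, otherwise at least mid + 1
    return (bins[mid], build_rules(bins, lo, mid), build_rules(bins, mid + 1, hi))


def apply_rules(node, x):
    """Hypothetical helper: follow the `<=` tests down to a leaf."""
    while isinstance(node, tuple):
        threshold, left, right = node
        node = left if x <= threshold else right
    return node


bins = [0.0, 1.0, 2.5, 4.0, 7.0]
rules = build_rules(bins)
values = [0.2, 6.4, 3.0, 1.6, 0.0, 7.0, 8.0, -1.0]
pred = [apply_rules(rules, v) for v in values]
reference = [bisect_left(bins, v) for v in values]
print(pred)
print(reference)
```

Because every test uses `<=`, a value equal to a threshold always goes to the left branch, which is why this kind of structure matches right=True and not right=False.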

See also example numpy.digitize as a tree.
