Coverage for src/mlstatpy/ml/_neural_tree_api.py: 97%
31 statements
coverage.py v7.1.0, created at 2023-02-27 05:59 +0100
# -*- coding: utf-8 -*-
"""
@file
@brief Conversion from tree to neural network.
"""
import numpy
from ..optim import SGDOptimizer


class _TrainingAPI:
    """
    Declaration of the functions needed to train a model.
    """

    @property
    def training_weights(self):
        "Returns the weights."
        raise NotImplementedError(  # pragma: no cover
            "This should be overwritten.")

    def update_training_weights(self, grad, add=True):
        """
        Updates weights.

        :param grad: vector to add to the weights, such as a gradient
        :param add: if True, adds the vector to the weights,
            otherwise replaces them
        """
        raise NotImplementedError(  # pragma: no cover
            "This should be overwritten.")

    def fill_cache(self, X):
        """
        Creates a cache with intermediate results.
        """
        return None  # pragma: no cover

    def loss(self, X, y, cache=None):
        """
        Computes the loss. Returns a float.
        """
        raise NotImplementedError(  # pragma: no cover
            "This should be overwritten.")

    def dlossds(self, X, y, cache=None):
        """
        Computes the derivative of the loss with respect to the prediction.
        """
        raise NotImplementedError(  # pragma: no cover
            "This should be overwritten.")

    def gradient_backward(self, graddx, X, inputs=False, cache=None):
        """
        Computes the gradient in X.

        :param graddx: existing gradient with respect to the outputs
        :param X: input at which the gradient is computed
        :param inputs: if False, derivative with respect to the coefficients,
            otherwise with respect to the inputs
        :param cache: cache of intermediate results to avoid recomputing them
        :return: gradient
        """
        raise NotImplementedError(  # pragma: no cover
            "This should be overwritten.")

    def gradient(self, X, y, inputs=False):
        """
        Computes the gradient in *X* knowing the expected value *y*.

        :param X: input at which the gradient is computed
        :param y: expected values
        :param inputs: if False, derivative with respect to the coefficients,
            otherwise with respect to the inputs
        :return: gradient
        """
        if len(X.shape) != 1:
            raise ValueError(  # pragma: no cover
                f"X must be a one-dimensional vector but has shape {X.shape}.")
        cache = self.fill_cache(X)  # pylint: disable=E1128
        dlossds = self.dlossds(X, y, cache=cache)
        return self.gradient_backward(dlossds, X, inputs=inputs, cache=cache)

    def fit(self, X, y, optimizer=None, max_iter=100, early_th=None, verbose=False,
            lr=None, lr_schedule=None, l1=0., l2=0., momentum=0.9):
        """
        Fits a neuron.

        :param X: training set
        :param y: training labels
        :param optimizer: optimizer, by default, it is
            :class:`SGDOptimizer <mlstatpy.optim.sgd.SGDOptimizer>`
        :param max_iter: maximum number of iterations
        :param early_th: early stopping threshold
        :param verbose: more verbose
        :param lr: overwrites *learning_rate_init* if
            *optimizer* is None (unused otherwise)
        :param lr_schedule: overwrites *lr_schedule* if
            *optimizer* is None (unused otherwise)
        :param l1: L1 regularization if *optimizer* is None
            (unused otherwise)
        :param l2: L2 regularization if *optimizer* is None
            (unused otherwise)
        :param momentum: used if *optimizer* is None
        :return: self
        """
        if optimizer is None:
            optimizer = SGDOptimizer(
                self.training_weights, learning_rate_init=lr or 0.002,
                lr_schedule=lr_schedule or 'invscaling',
                l1=l1, l2=l2, momentum=momentum)

        # Both callbacks push the candidate coefficients back into the model
        # before evaluating the loss or the gradient.
        def fct_loss(coef, lx, ly, neuron=self):
            neuron.update_training_weights(coef, False)
            loss = neuron.loss(lx, ly)
            if loss.shape[0] > 1:
                return numpy.sum(loss)
            return loss

        def fct_grad(coef, lx, ly, i, neuron=self):
            neuron.update_training_weights(coef, False)
            return neuron.gradient(lx, ly).ravel()

        optimizer.train(
            X, y, fct_loss, fct_grad, max_iter=max_iter,
            early_th=early_th, verbose=verbose)

        self.update_training_weights(optimizer.coef, False)
        return self
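
The class above only declares the training API; a concrete model has to override the abstract methods. The sketch below is illustrative and not part of the module: a hypothetical _LinearNeuron with a square loss, trained through fit and the default SGDOptimizer. It assumes mlstatpy is installed and that the optimizer evaluates fct_loss on whole batches and fct_grad on one sample at a time; the class name and the toy data are invented for the example.

# Illustrative sketch only, not part of mlstatpy.
import numpy
from mlstatpy.ml._neural_tree_api import _TrainingAPI


class _LinearNeuron(_TrainingAPI):
    "Hypothetical example: prediction coef @ x with a square loss."

    def __init__(self, n_features):
        self.coef = numpy.zeros(n_features, dtype=numpy.float64)

    @property
    def training_weights(self):
        "Returns the trainable coefficients."
        return self.coef

    def update_training_weights(self, grad, add=True):
        "Adds *grad* to the coefficients or replaces them."
        if add:
            self.coef += grad
        else:
            numpy.copyto(self.coef, grad)

    def loss(self, X, y, cache=None):
        "Square loss, one value per sample."
        pred = X @ self.coef
        return numpy.atleast_1d((pred - y) ** 2)

    def dlossds(self, X, y, cache=None):
        "Derivative of the square loss with respect to the prediction."
        return 2. * (X @ self.coef - y)

    def gradient_backward(self, graddx, X, inputs=False, cache=None):
        "Chain rule: gradient against the coefficients or the inputs."
        return graddx * (self.coef if inputs else X)


# Toy usage on a noiseless linear problem (illustrative only).
rng = numpy.random.RandomState(0)
X = rng.randn(100, 3)
y = X @ numpy.array([1., -2., .5])
neuron = _LinearNeuron(3).fit(X, y, max_iter=20)
print(neuron.training_weights)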