Coverage for mlprodict/onnxrt/ops_cpu/op_linear_classifier.py: 94%
34 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from scipy.special import expit # pylint: disable=E0611
9from ._op import OpRunClassifierProb
10from ._op_classifier_string import _ClassifierCommon
11from ._op_numpy_helper import numpy_dot_inplace
class LinearClassifier(OpRunClassifierProb, _ClassifierCommon):
    """
    Runtime for a linear classifier: computes ``x @ coefficients``
    (plus optional intercepts), applies an optional post transform
    and derives a label from the resulting scores.

    Attributes read by this class:

    * ``classlabels_ints`` / ``classlabels_strings``: class labels,
      at least one of the two must be non empty,
    * ``coefficients``: flat (1-D) numpy array, reshaped in the
      constructor into a ``(n_features, nb_class)`` matrix,
    * ``intercepts``: optional value added to the scores,
    * ``post_transform``: one of ``b'NONE'``, ``b'LOGISTIC'``,
      ``b'SOFTMAX'``,
    * ``multi_class``: declared as an expected attribute but not read
      by the code of this class.
    """

    atts = {'classlabels_ints': [], 'classlabels_strings': [],
            'coefficients': None, 'intercepts': None,
            'multi_class': 0, 'post_transform': b'NONE'}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Validates the attributes and reshapes the coefficients.

        :param onnx_node: forwarded to ``OpRunClassifierProb.__init__``
        :param desc: forwarded to ``OpRunClassifierProb.__init__``,
            also included in one error message
        :param options: forwarded to ``OpRunClassifierProb.__init__``
        :raises TypeError: if ``coefficients`` is not a numpy array
        :raises ValueError: if no class label is specified, if
            ``coefficients`` is not 1-D, or if its size is not a
            multiple of the number of classes
        """
        OpRunClassifierProb.__init__(self, onnx_node, desc=desc,
                                     expected_attributes=LinearClassifier.atts,
                                     **options)
        self._post_process_label_attributes()
        if not isinstance(self.coefficients, numpy.ndarray):
            raise TypeError(  # pragma: no cover
                f"coefficient must be an array not {type(self.coefficients)}.")
        if len(getattr(self, "classlabels_ints", [])) == 0 and \
                len(getattr(self, 'classlabels_strings', [])) == 0:
            raise ValueError(  # pragma: no cover
                "Fields classlabels_ints or classlabels_strings must be specified.")
        self.nb_class = max(len(getattr(self, 'classlabels_ints', [])),
                            len(getattr(self, 'classlabels_strings', [])))
        if len(self.coefficients.shape) != 1:
            raise ValueError(  # pragma: no cover
                "coefficient must be an array but has shape {}\n{}.".format(
                    self.coefficients.shape, desc))
        size = self.coefficients.shape[0]
        if size % self.nb_class != 0:
            # reshape below would fail with a less explicit ValueError.
            raise ValueError(  # pragma: no cover
                f"coefficients has {size} elements which is not a multiple "
                f"of the number of classes ({self.nb_class}).")
        n = size // self.nb_class
        # The flat array is read as one row of n weights per class,
        # then transposed so that x @ coefficients gives one column
        # of scores per class.
        self.coefficients = self.coefficients.reshape(self.nb_class, n).T

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the scores and the predicted labels for *x*.

        :param x: input array, multiplied by the coefficient matrix
        :param attributes: unused here
        :param verbose: unused here
        :param fLOG: unused here
        :return: result of ``self._post_process_predicted_label(label, scores)``
        :raises NotImplementedError: if ``post_transform`` is not one of
            ``b'NONE'``, ``b'LOGISTIC'``, ``b'SOFTMAX'``
        """
        scores = numpy_dot_inplace(self.inplaces, x, self.coefficients)
        if self.intercepts is not None:
            scores += self.intercepts

        if self.post_transform == b'NONE':
            pass
        elif self.post_transform == b'LOGISTIC':
            expit(scores, out=scores)
        elif self.post_transform == b'SOFTMAX':
            # Subtracting the row maximum before exponentiating keeps
            # exp() from overflowing without changing the softmax value.
            numpy.subtract(scores, scores.max(axis=1)[
                :, numpy.newaxis], out=scores)
            numpy.exp(scores, out=scores)
            numpy.divide(scores, scores.sum(axis=1)[
                :, numpy.newaxis], out=scores)
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unknown post_transform: '{self.post_transform}'.")

        if self.nb_class == 1:
            # scores has a single column: flatten it so that the result
            # is 1-D with one entry per row. Indexing a 1-D label array
            # with the 2-D mask ``scores > 0`` raises IndexError.
            label = (scores.ravel() > 0).astype(x.dtype)
        else:
            label = numpy.argmax(scores, axis=1)
        return self._post_process_predicted_label(label, scores)