Coverage for mlprodict/onnxrt/ops_cpu/op_linear_classifier.py: 94%

34 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from scipy.special import expit  # pylint: disable=E0611
from ._op import OpRunClassifierProb
from ._op_classifier_string import _ClassifierCommon
from ._op_numpy_helper import numpy_dot_inplace


class LinearClassifier(OpRunClassifierProb, _ClassifierCommon):

    atts = {'classlabels_ints': [], 'classlabels_strings': [],
            'coefficients': None, 'intercepts': None,
            'multi_class': 0, 'post_transform': b'NONE'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunClassifierProb.__init__(self, onnx_node, desc=desc,
                                     expected_attributes=LinearClassifier.atts,
                                     **options)
        self._post_process_label_attributes()
        if not isinstance(self.coefficients, numpy.ndarray):
            raise TypeError(  # pragma: no cover
                f"coefficients must be an array not {type(self.coefficients)}.")
        if len(getattr(self, 'classlabels_ints', [])) == 0 and \
                len(getattr(self, 'classlabels_strings', [])) == 0:
            raise ValueError(  # pragma: no cover
                "Fields classlabels_ints or classlabels_strings must be specified.")
        self.nb_class = max(len(getattr(self, 'classlabels_ints', [])),
                            len(getattr(self, 'classlabels_strings', [])))
        if len(self.coefficients.shape) != 1:
            raise ValueError(  # pragma: no cover
                "coefficients must be a one-dimensional array but has shape {}\n{}.".format(
                    self.coefficients.shape, desc))
        # The flat coefficients vector of length nb_class * n_features is
        # reshaped into a (n_features, nb_class) matrix so that
        # scores = X @ coefficients.
        n = self.coefficients.shape[0] // self.nb_class
        self.coefficients = self.coefficients.reshape(self.nb_class, n).T

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # Raw scores: X @ coefficients (+ intercepts), computed in place
        # when the runtime allows it.
        scores = numpy_dot_inplace(self.inplaces, x, self.coefficients)
        if self.intercepts is not None:
            scores += self.intercepts

        if self.post_transform == b'NONE':
            pass
        elif self.post_transform == b'LOGISTIC':
            expit(scores, out=scores)
        elif self.post_transform == b'SOFTMAX':
            # Numerically stable softmax: subtract the row-wise maximum
            # before exponentiating, then normalise each row.
            numpy.subtract(scores, scores.max(axis=1)[
                :, numpy.newaxis], out=scores)
            numpy.exp(scores, out=scores)
            numpy.divide(scores, scores.sum(axis=1)[
                :, numpy.newaxis], out=scores)
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unknown post_transform: '{self.post_transform}'.")

        if self.nb_class == 1:
            # Single score column: positive scores map to class 1.
            label = numpy.zeros((scores.shape[0],), dtype=x.dtype)
            label[scores > 0] = 1
        else:
            label = numpy.argmax(scores, axis=1)
        return self._post_process_predicted_label(label, scores)
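
For readers who want to check the arithmetic outside the runtime, here is a minimal NumPy sketch of what _run computes when post_transform is b'SOFTMAX'. The values of coefficients, intercepts and x below are invented for illustration only; they are not taken from this module or from any real model.

import numpy

# Hypothetical example: 2 classes, 3 features (values invented for illustration).
nb_class, n_features = 2, 3
coefficients = numpy.array([0.5, -1.0, 2.0, -0.5, 1.0, -2.0])  # flat, length nb_class * n_features
intercepts = numpy.array([0.1, -0.1])
x = numpy.array([[1.0, 0.0, 2.0],
                 [0.0, 1.0, 0.5]])

# Same reshaping as LinearClassifier.__init__: (nb_class * n_features,) -> (n_features, nb_class).
coef = coefficients.reshape(nb_class, n_features).T

# Same steps as _run with post_transform == b'SOFTMAX'.
scores = x @ coef + intercepts
scores -= scores.max(axis=1)[:, numpy.newaxis]
scores = numpy.exp(scores)
scores /= scores.sum(axis=1)[:, numpy.newaxis]
label = numpy.argmax(scores, axis=1)
print(label, scores)

In normal use this class is not called directly: the mlprodict Python runtime instantiates it when an ONNX graph contains a LinearClassifier node, and the surrounding machinery (OpRunClassifierProb, _ClassifierCommon) maps the predicted indices back to the integer or string class labels declared in the model.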