Coverage for mlprodict/onnxrt/ops_cpu/op_label_encoder.py: 94%

53 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class LabelEncoder(OpRun):
    """
    Runtime for ONNX operator *LabelEncoder*: every element of the input
    is looked up in a dictionary built from one ``keys_*`` attribute and
    one ``values_*`` attribute; elements missing from the dictionary are
    replaced by the matching ``default_*`` attribute.
    Version 1 of the operator (attribute *classes_strings*) is not supported.
    """

    atts = {'default_float': 0., 'default_int64': -1,
            'default_string': b'',
            'keys_floats': numpy.empty(0, dtype=numpy.float32),
            'keys_int64s': numpy.empty(0, dtype=numpy.int64),
            'keys_strings': numpy.empty(0, dtype=numpy.str_),
            'values_floats': numpy.empty(0, dtype=numpy.float32),
            'values_int64s': numpy.empty(0, dtype=numpy.int64),
            'values_strings': numpy.empty(0, dtype=numpy.str_),
            }

    # Supported (keys attribute, values attribute, default attribute,
    # output dtype) combinations. They are tried in this order and the first
    # one whose keys and values are both non-empty defines the mapping.
    # String outputs use an object array so that no value (or default) is
    # truncated to a fixed-width unicode dtype.
    _encodings = (
        ('keys_floats', 'values_floats', 'default_float', numpy.float32),
        ('keys_floats', 'values_int64s', 'default_int64', numpy.int64),
        ('keys_int64s', 'values_int64s', 'default_int64', numpy.int64),
        ('keys_int64s', 'values_floats', 'default_float', numpy.float32),
        ('keys_strings', 'values_floats', 'default_float', numpy.float32),
        ('keys_strings', 'values_int64s', 'default_int64', numpy.int64),
        ('keys_strings', 'values_strings', 'default_string',
         numpy.dtype(object)),
        ('keys_floats', 'values_strings', 'default_string',
         numpy.dtype(object)),
        ('keys_int64s', 'values_strings', 'default_string',
         numpy.dtype(object)),
    )

    def __init__(self, onnx_node, desc=None, **options):
        """
        Loads the attributes and builds the lookup table.

        Sets ``classes_`` (dict key -> value), ``default_`` (value for
        unknown keys) and ``dtype_`` (dtype of the output array).

        :raises RuntimeError: if no supported keys/values pair is defined
            (including version 1 of the operator) or the mapping is empty
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=LabelEncoder.atts,
                       **options)
        encoding = self._find_encoding()
        if encoding is None:
            if hasattr(self, 'classes_strings'):
                raise RuntimeError(  # pragma: no cover
                    "This runtime does not implement version 1 of "
                    "operator LabelEncoder.")
            raise RuntimeError(
                f"No encoding was defined in {onnx_node}.")

        keys_name, values_name, default_name, dtype = encoding
        keys = getattr(self, keys_name)
        values = getattr(self, values_name)
        default = getattr(self, default_name)
        if keys_name == 'keys_strings':
            keys = [LabelEncoder._decode(k) for k in keys]
        if values_name == 'values_strings':
            values = [LabelEncoder._decode(v) for v in values]
            # The default must have the same type (str) as the mapped values,
            # otherwise the output mixes bytes and str.
            default = LabelEncoder._decode(default)

        self.classes_ = dict(zip(keys, values))
        self.default_ = default
        self.dtype_ = dtype
        if len(self.classes_) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty classes for LabelEncoder, (onnx_node='{}')\n{}.".format(
                    self.onnx_node.name, onnx_node))

    def _find_encoding(self):
        """
        Returns the first entry of ``_encodings`` whose keys and values
        attributes are both non-empty, or None if there is none.
        """
        for encoding in LabelEncoder._encodings:
            keys_name, values_name = encoding[:2]
            if (len(getattr(self, keys_name)) > 0 and
                    len(getattr(self, values_name)) > 0):
                return encoding
        return None

    @staticmethod
    def _decode(value):
        """
        Returns *value* decoded from UTF-8 when it is bytes,
        unchanged otherwise (already a string).
        """
        return value.decode('utf-8') if isinstance(value, bytes) else value

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Applies the mapping to every element of *x*.

        :param x: input array; inputs with more than one dimension are
            squeezed first
        :return: tuple with one array of dtype ``self.dtype_`` whose shape
            is the squeezed shape of *x* (always at least one dimension)
        """
        x = numpy.asarray(x)
        if x.ndim > 1:
            x = numpy.squeeze(x)
        if x.ndim == 0:
            # A scalar input, or a (1, 1) input squeezed to 0-d, still
            # produces a 1-D result.
            x = x.reshape((1, ))
        res = numpy.empty(x.shape, dtype=self.dtype_)
        # res is freshly allocated (contiguous) so reshape(-1) is a view:
        # writing into flat_res fills res, whatever its rank.
        flat_res = res.reshape(-1)
        lookup = self.classes_.get
        default = self.default_
        for i, value in enumerate(x.reshape(-1)):
            flat_res[i] = lookup(value, default)
        return (res, )