Coverage for mlprodict/onnxrt/ops_cpu/op_one_hot_encoder.py: 72%

36 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class OneHotEncoder(OpRun):
    """
    Runtime for operator *OneHotEncoder*.

    Every input value is looked up in the categories given by
    attribute *cats_int64s* or, when that one is empty, by
    *cats_strings*. The output has the shape of the input plus one
    trailing axis whose length is the number of categories.

    :epkg:`ONNX` specifications do not mention
    the possibility to change the output type,
    sparse, dense, float, double. This runtime always returns
    a dense *float32* array.
    """

    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'zeros': 1,
            }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds ``self.classes_``, a dictionary mapping every
        category to its column index in the encoded output.

        :param onnx_node: forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        :raises RuntimeError: if both *cats_int64s* and
            *cats_strings* are empty
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=OneHotEncoder.atts,
                       **options)
        if len(self.cats_int64s) > 0:
            self.classes_ = {v: i for i, v in enumerate(self.cats_int64s)}
        elif len(self.cats_strings) > 0:
            # Categories are decoded when they are bytes; values which
            # are already text are used as they are instead of failing
            # on a missing ``decode`` method.
            self.classes_ = {
                (v.decode('utf-8') if isinstance(v, bytes) else v): i
                for i, v in enumerate(self.cats_strings)}
        else:
            raise RuntimeError("No encoding was defined.")  # pragma: no cover

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        One-hot encodes *x*.

        :param x: array of any rank, its values are searched
            in ``self.classes_``
        :param attributes: unused by this method
        :param verbose: unused by this method
        :param fLOG: unused by this method
        :return: a tuple with a single *float32* array of shape
            ``x.shape + (len(self.classes_), )``; a value missing
            from ``self.classes_`` produces only zeros
        :raises RuntimeError: if attribute *zeros* is false and at
            least one value of *x* is not a known category
        """
        classes = self.classes_
        res = numpy.zeros(x.shape + (len(classes), ), dtype=numpy.float32)
        # A single loop over every position handles all ranks,
        # *index* is the tuple of coordinates of *v* in *x*.
        for index, v in numpy.ndenumerate(x):
            j = classes.get(v, -1)
            if j >= 0:
                res[index + (j, )] = 1.

        if not self.zeros:
            red = res.sum(axis=-1)
            # ``any`` is safe on empty inputs, ``numpy.min`` is not.
            if (red == 0).any():
                rows = []
                # Iterating over scalars (not over rows) keeps the
                # comparison unambiguous for inputs of rank >= 2.
                for index, total in numpy.ndenumerate(red):
                    if total == 0:
                        pos = index[0] if len(index) == 1 else index
                        rows.append(dict(row=pos, value=x[index]))
                        if len(rows) > 5:
                            break
                # A 0-d array cannot be sliced.
                head = x[:5] if len(x.shape) > 0 else x
                raise RuntimeError(  # pragma: no cover
                    "One observation did not have any defined category.\n"
                    "classes: {}\nfirst rows:\n{}\nres:\n{}\nx:\n{}".format(
                        classes, "\n".join(str(_) for _ in rows),
                        res[:5], head))

        return (res, )