Coverage for mlprodict/onnxrt/ops_cpu/op_category_mapper.py: 100%

24 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class CategoryMapper(OpRun):
    """
    Runtime implementation of the ONNX-ML ``CategoryMapper`` operator.

    Integer inputs are mapped to strings and string inputs are mapped to
    int64 values, using the parallel attribute lists *cats_int64s* and
    *cats_strings*. Values absent from the mapping are replaced by
    *default_string* (integer -> string direction) or *default_int64*
    (string -> integer direction).
    """

    atts = {'cats_int64s': numpy.empty(0, dtype=numpy.int64),
            'cats_strings': numpy.empty(0, dtype=numpy.str_),
            'default_int64': -1,
            'default_string': b'',
            }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=CategoryMapper.atts,
                       **options)
        if len(self.cats_int64s) != len(self.cats_strings):
            raise RuntimeError(  # pragma: no cover
                "Lengths mismatch between cats_int64s (%d) and "
                "cats_strings (%d)." % (
                    len(self.cats_int64s), len(self.cats_strings)))
        self.int2str_ = {}
        self.str2int_ = {}
        for a, b in zip(self.cats_int64s, self.cats_strings):
            # attribute strings may arrive as bytes or as str; accept both
            be = self._decode(b)
            self.int2str_[a] = be
            self.str2int_[be] = a
        # The default is stored as bytes (b'') but mapped values are str;
        # decode it once so the int -> string output never mixes the two
        # (a mix would make the output dtype depend on lookup misses).
        self.default_string_ = self._decode(self.default_string)

    @staticmethod
    def _decode(value):
        """Returns *value* decoded as UTF-8 if it is bytes, else unchanged."""
        if isinstance(value, bytes):
            return value.decode('utf-8')
        return value

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Maps every element of *x* through the category table.

        :param x: integer array (mapped to strings) or string array
            (mapped to int64); bytes elements are decoded as UTF-8
            before lookup
        :return: 1-tuple holding an array with the same shape as *x*
        """
        xf = x.ravel()
        if numpy.issubdtype(x.dtype, numpy.integer):
            res = [self.int2str_.get(v, self.default_string_) for v in xf]
            # explicit dtype keeps a string dtype even when *x* is empty
            return (numpy.array(res, dtype=numpy.str_).reshape(x.shape), )

        decode = self._decode
        res = numpy.array(
            [self.str2int_.get(decode(v), self.default_int64) for v in xf],
            dtype=numpy.int64)
        return (res.reshape(x.shape), )