Coverage for mlprodict/onnxrt/ops_cpu/op_unique.py: 89%

28 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

def _specify_int64(indices, inverse_indices, counts):
    """
    Converts the three auxiliary outputs used by :class:`Unique` into
    freshly allocated ``numpy.int64`` arrays.

    :param indices: array-like of indices
    :param inverse_indices: array-like of inverse indices
    :param counts: array-like of counts
    :return: tuple ``(indices, inverse_indices, counts)`` as int64 arrays
    """
    as_int64 = tuple(
        numpy.array(values, dtype=numpy.int64)
        for values in (indices, inverse_indices, counts))
    return as_int64

15 

16 

class Unique(OpRun):
    """
    Runtime implementation of the ONNX ``Unique`` operator.

    Attributes (defaults in ``atts``):

    * ``axis``: axis along which unique slices are searched; the default
      ``numpy.nan`` means the input is flattened first.
    * ``sorted``: when non-zero, outputs follow the ascending order returned
      by :func:`numpy.unique`; when zero, unique values are reordered by
      their first occurrence in the input.
    """

    atts = {'axis': numpy.nan, 'sorted': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Unique.atts,
                       **options)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes the unique values of *x* and the auxiliary outputs.

        :param x: input array
        :return: ``(y,)``, ``(y, indices)``, ``(y, indices, inverse_indices)``
            or ``(y, indices, inverse_indices, counts)`` depending on the
            number of outputs declared by the node; the last three are
            1-D int64 arrays
        """
        axis = self.axis
        flatten = axis is None or numpy.isnan(axis)
        if flatten:
            y, indices, inverse_indices, counts = numpy.unique(
                x, return_index=True, return_inverse=True,
                return_counts=True)
            # y is 1-D when the input is flattened.
            take_axis = 0
        else:
            y, indices, inverse_indices, counts = numpy.unique(
                x, return_index=True, return_inverse=True,
                return_counts=True, axis=axis)
            take_axis = axis
        if len(self.onnx_node.output) == 1:
            return (y, )

        # Some numpy versions (2.x) return the inverse shaped like the
        # input; ONNX expects a 1-D tensor of indices into y.
        inverse_indices = numpy.asarray(inverse_indices).reshape(-1)

        if not self.sorted:
            # order[j] is the sorted position of the j-th unique value to
            # appear in the input (indices holds first-occurrence positions).
            order = numpy.argsort(indices)
            # rank[k] is the new position of the k-th sorted unique value.
            rank = numpy.empty_like(order)
            rank[order] = numpy.arange(order.shape[0], dtype=order.dtype)
            # Reorder y along the axis it was built on. Taking from x with
            # axis=0 would be wrong for a flattened multi-dimensional input
            # or for an explicit axis other than 0.
            y = numpy.take(y, order, axis=take_axis)
            indices = indices[order]
            inverse_indices = rank[inverse_indices]
            counts = counts[order]

        indices, inverse_indices, counts = _specify_int64(
            indices, inverse_indices, counts)
        if len(self.onnx_node.output) == 2:
            return (y, indices)
        if len(self.onnx_node.output) == 3:
            return (y, indices, inverse_indices)
        return (y, indices, inverse_indices, counts)