Coverage for mlprodict/onnxrt/ops_cpu/op_dequantize_linear.py: 82%

28 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class DequantizeLinear(OpRun):
    """
    Runtime operator *DequantizeLinear*: computes
    ``y = (x - x_zero_point) * x_scale`` and returns it as float32.

    Positional inputs are ``x``, ``x_scale`` and an optional
    ``x_zero_point``. A 1-D ``x_scale`` / ``x_zero_point`` with several
    elements is broadcast along attribute ``axis`` of ``x`` (default 1).
    """

    atts = {'axis': 1}
    python_inputs = ['*inputs']

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=DequantizeLinear.atts,
                       **options)

    def _axis_shape(self, x, param):
        """
        Returns a shape with the rank of *x*, filled with 1 except at
        position ``self.axis`` which receives ``len(param)``, so that the
        1-D *param* broadcasts along that axis of *x*.
        """
        new_shape = [1] * len(x.shape)
        new_shape[self.axis] = len(param)
        return new_shape

    def _run(self, *args, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Dequantizes ``args[0]``.

        :param args: ``(x, x_scale)`` or ``(x, x_scale, x_zero_point)``;
            a zero point equal to None is treated as missing
        :return: tuple holding one float32 array
        :raises RuntimeError: if *x_scale* has more than one dimension,
            or if the zero point dtype differs from the dtype of *x*
        """
        x, x_scale = args[0], args[1]
        if len(x_scale.shape) > 1:
            raise RuntimeError(  # pragma: no cover
                "Input 2 must be a vector or a number.")

        x_float = x.astype(numpy.float32)

        # x_zero_point is optional: without it, y = x * x_scale.
        x_zero_point = args[2] if len(args) > 2 else None
        if x_zero_point is None:
            if len(x_scale.shape) > 0:
                y = x_float * x_scale.reshape(self._axis_shape(x, x_scale))
            else:
                y = x_float * x_scale
            return (y.astype(numpy.float32), )

        # A one-element zero point is applied as a scalar.
        if len(x_zero_point.shape) > 0 and x_zero_point.size == 1:
            x_zero_point = x_zero_point[0]
        if x_zero_point.dtype != x.dtype:
            raise RuntimeError(  # pragma: no cover
                "Type mismatch {} != {} in DequantizeLinear.".format(
                    x.dtype, x_zero_point.dtype))

        if len(x_zero_point.shape) > 0:
            # Per-axis case: the scale is reshaped with the zero point's
            # length, so both are expected to have the same length.
            new_shape = self._axis_shape(x, x_zero_point)
            y = ((x_float - x_zero_point.reshape(new_shape)) *
                 x_scale.reshape(new_shape))
        else:
            y = (x_float - x_zero_point) * x_scale
        return (y.astype(numpy.float32), )