Coverage for mlprodict/onnxrt/ops_cpu/op_pad.py: 88%

42 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10 

11 

12def _pad_impl(data, raw_pads, mode, constant_values=0.0, axes=None): 

13 if raw_pads is not None: 

14 old_raw_pads = raw_pads 

15 raw_pads = [] 

16 pos = 0 

17 for i in range(len(data.shape)): 

18 if axes is None or i in axes: 

19 raw_pads.extend(old_raw_pads[pos: pos + 2]) 

20 pos += 2 

21 else: 

22 raw_pads.extend([0, 0]) 

23 raw_pads = numpy.array(raw_pads) 

24 

25 input_rank = data.ndim 

26 if input_rank * 2 != raw_pads.size: 

27 raise RuntimeError( 

28 "The number of elements in raw_pads should be 2 * data_rank") 

29 

30 half = raw_pads.shape[0] // 2 

31 pad_width = tuple((raw_pads[i], raw_pads[i + half]) 

32 for i in range(0, half)) 

33 

34 if mode == "constant": 

35 return numpy.pad( 

36 data, pad_width=pad_width, mode=mode, 

37 constant_values=constant_values) 

38 return numpy.pad(data, pad_width=pad_width, mode=mode) 

39 

40 

def onnx_pad(data, pads, constant_value=None, mode='constant'):
    """
    Implements :epkg:`numpy:pad` based on ONNX signature.

    :param data: data to pad
    :param pads: tensor of integers indicating the number of
        padding elements to add or remove (if negative) at the
        beginning and end of each axis. For 2D input tensor, it
        is the number of pixels. `pads` should be a 1D tensor of
        shape `[2 * input_rank]`. `pads` format should be:
        `[x1_begin, x2_begin,...,x1_end, x2_end,...]`, where `xi_begin` is
        the number of pad values added at the beginning of axis `i`
        and xi_end, the number of pad values added at the end of axis `i`.
    :param constant_value: A scalar value to be used if the mode chosen is
        `constant`. When None, a zero of the dtype of *data* is used
        (0, empty string or False); an explicit value, even a falsy one
        such as `''`, is used as given.
    :param mode: Supported modes: `constant`(default), `reflect`, `edge`
    :return: tensor after padding

    NOTE(review): pads are forwarded unchanged to numpy.pad; negative
    (removing) pads presumably fail there - confirm before relying on them.
    """
    if constant_value is None:
        # numpy.zeros yields 0, '' or False depending on data.dtype.
        constant_value = numpy.zeros((1, ), dtype=data.dtype)
    return _pad_impl(
        data, pads, mode=mode, constant_values=constant_value)

63 

64 

class Pad_1(OpRun):
    """
    Runtime for ONNX operator Pad without the *axes* input: pads the
    first input according to the *pads* input and the *mode* attribute.
    """

    atts = {'mode': b'constant'}

    def __init__(self, onnx_node, desc=None, **options):
        # Use this class's own attribute table rather than the module
        # alias `Pad`, which is only bound at the end of the module.
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Pad_1.atts,
                       **options)
        # presumably OpRun stores the attribute as self.mode; the default
        # is bytes, but a str value is accepted too.
        mode = self.mode
        self.mode_ = mode.decode('ascii') if isinstance(mode, bytes) else mode

    def _run(self, data, pads, constant_value=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # A missing optional constant_value input means padding with 0.
        if constant_value is None:
            constant_value = 0
        return (_pad_impl(data, pads, mode=self.mode_,
                          constant_values=constant_value), )

80 

81 

class Pad_18(Pad_1):
    """
    Runtime for ONNX operator Pad selected when the opset is at least 18.
    It extends the base runtime with the optional *axes* input, which
    restricts the padding to the listed axes.
    """

    def _run(self, data, pads, constant_value=None, axes=None,  # pylint: disable=W0237
             attributes=None, verbose=0, fLOG=None):
        # Same default as the base runtime: pad with 0 when no value is given.
        fill = 0 if constant_value is None else constant_value
        padded = _pad_impl(data, pads, mode=self.mode_,
                           constant_values=fill, axes=axes)
        return (padded, )

91 

92 

# Pick the runtime matching the opset reported by onnx_opset_version():
# Pad_18 (with the axes input) from opset 18 on, Pad_1 otherwise.
Pad = Pad_18 if onnx_opset_version() >= 18 else Pad_1  # type: ignore