Coverage for mlprodict/onnxrt/ops_cpu/op_softmax.py: 100%

53 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
# pylint: disable=E0203,E1101,C0111
"""
@file
@brief Runtime operator.
"""
import numpy
from onnx.defs import onnx_opset_version
from ._op import OpRunUnaryNum, OpRunBinaryNum
from ._new_ops import OperatorSchema


class _Softmax(OpRunUnaryNum):

    def __init__(self, onnx_node, desc=None, expected_attributes=None,
                 **options):
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=expected_attributes,
                               **options)

    def _run(self, X, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        if self.inplaces.get(0, False) and X.flags['WRITEABLE']:
            return self._run_inplace(X)
        # subtract the maximum before exponentiating for numerical stability
        tmp = X - X.max(axis=self.axis, keepdims=1)
        Y = numpy.exp(tmp)
        Y /= Y.sum(axis=self.axis, keepdims=1)
        return (Y, )

    def _run_inplace(self, X):
        # same computation, reusing X's buffer to avoid extra allocations
        X -= X.max(axis=self.axis, keepdims=1)
        numpy.exp(X, out=X)
        X /= X.sum(axis=self.axis, keepdims=1)
        return (X, )

    def to_python(self, inputs):
        lines = ["tmp = {0} - {0}.max(axis=axis)[:, numpy.newaxis]".format(
                    inputs[0]),
                 "Y = numpy.exp(tmp)",
                 "Y /= Y.sum(axis=axis)[:, numpy.newaxis]",
                 "return Y"]
        return ("import numpy", "\n".join(lines))

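
# --- Illustrative sketch (not part of the original module) ------------------
# Why _Softmax._run subtracts the maximum: numpy.exp overflows for large
# inputs, while softmax is invariant to shifting every entry by a constant.
# A minimal, hypothetical demonstration of that identity; the helper name
# _demo_stable_softmax does not exist in mlprodict.

def _demo_stable_softmax():
    X = numpy.array([[1000., 1001., 1002.]])   # naive numpy.exp(X) overflows
    tmp = X - X.max(axis=-1, keepdims=True)    # largest entry becomes 0
    Y = numpy.exp(tmp)
    Y /= Y.sum(axis=-1, keepdims=True)
    return Y                                   # ~[[0.090, 0.245, 0.665]]
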

class Softmax_1(_Softmax):

    atts = {'axis': 1}

    def __init__(self, onnx_node, desc=None, **options):
        _Softmax.__init__(self, onnx_node, desc=desc,
                          expected_attributes=Softmax_1.atts,
                          **options)


class Softmax_13(_Softmax):

    atts = {'axis': -1}

    def __init__(self, onnx_node, desc=None, **options):
        _Softmax.__init__(self, onnx_node, desc=desc,
                          expected_attributes=Softmax_13.atts,
                          **options)


class SoftmaxGrad_13(OpRunBinaryNum):
    """
    SoftmaxGrad computes :math:`dX = Y * (dY - ReduceSum(Y * dY))`.
    ONNX does not have a dot product; it can be simulated with a
    pointwise multiplication ("Mul") followed by a "ReduceSum".
    Unfortunately, the treatment of "axis" differs between "SoftmaxGrad"
    and "ReduceSum". If axis=k for SoftmaxGrad, we need to specify
    [k, ..., n-1] as the axes of reduction for "ReduceSum", after
    accounting for negative-axis specification.
    An alternative solution would be to flatten the inputs to 2D and then
    reshape the output back to the original shape. Hopefully, many of these
    ops can be optimized away in the common case of statically known shapes.
    """

    atts = {'axis': 1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=SoftmaxGrad_13.atts,
                                **options)

    def _find_custom_operator_schema(self, op_name):
        if op_name in ("SoftmaxGrad_13", "SoftmaxGrad"):
            return SoftmaxGradSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def _run(self, grad, prob, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # forward pass for reference:
        #   tmp = X - X.max(axis=self.axis)[:, numpy.newaxis]
        #   Y = numpy.exp(tmp)
        #   Y /= Y.sum(axis=self.axis)[:, numpy.newaxis]
        # backward pass: dX = Y * (dY - ReduceSum(Y * dY))
        pg = prob * grad
        if self.axis < 0:
            axis = len(pg.shape) + self.axis
        else:
            axis = self.axis
        # reduce over axes [axis, ..., n-1] as explained in the docstring
        axis = tuple(range(axis, len(pg.shape)))
        dg = grad - pg.sum(axis=axis, keepdims=1)
        return (prob * dg, )

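
# --- Illustrative sketch (not part of the original module) ------------------
# The closed form used by SoftmaxGrad_13._run, dX = Y * (dY - sum(Y * dY)),
# can be checked against a finite-difference gradient of L = sum(W * softmax(X)).
# The helper below is a hypothetical, self-contained verification with axis=1
# on a 2D input; the name _check_softmaxgrad_formula does not exist in mlprodict.

def _check_softmaxgrad_formula(rows=3, cols=4, eps=1e-6):
    rng = numpy.random.RandomState(0)
    X = rng.randn(rows, cols)
    W = rng.randn(rows, cols)                    # upstream gradient dL/dY

    def softmax(x):
        t = x - x.max(axis=1, keepdims=True)
        e = numpy.exp(t)
        return e / e.sum(axis=1, keepdims=True)

    Y = softmax(X)
    # analytic gradient, same formula as SoftmaxGrad_13._run with axis=1
    dX = Y * (W - (Y * W).sum(axis=1, keepdims=True))

    # numerical gradient of L = sum(W * softmax(X)) by central differences
    num = numpy.zeros_like(X)
    for i in range(rows):
        for j in range(cols):
            Xp = X.copy()
            Xp[i, j] += eps
            Xm = X.copy()
            Xm[i, j] -= eps
            num[i, j] = ((W * softmax(Xp)).sum()
                         - (W * softmax(Xm)).sum()) / (2 * eps)

    assert numpy.allclose(dX, num, atol=1e-6)
    return dX
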

class SoftmaxGradSchema(OperatorSchema):
    """
    Defines a schema for operators added in this package
    such as @see cl SoftmaxGrad_13.
    """

    def __init__(self):
        OperatorSchema.__init__(self, 'SoftmaxGrad')
        self.attributes = SoftmaxGrad_13.atts


if onnx_opset_version() >= 13:
    Softmax = Softmax_13
else:  # pragma: no cover
    Softmax = Softmax_1

SoftmaxGrad = SoftmaxGrad_13
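
# --- Illustrative sketch (not part of the original module) ------------------
# Softmax_1 defaults to axis=1 while Softmax_13 defaults to axis=-1; for
# inputs with more than two dimensions those defaults normalise over
# different axes, so the dispatch above matters. A hypothetical comparison
# (the helper name _compare_default_axes does not exist in mlprodict):

def _compare_default_axes():
    X = numpy.arange(24, dtype=numpy.float64).reshape((2, 3, 4))

    def softmax(x, axis):
        t = x - x.max(axis=axis, keepdims=True)
        e = numpy.exp(t)
        return e / e.sum(axis=axis, keepdims=True)

    y_opset1 = softmax(X, axis=1)     # opset-1 default: normalise axis 1
    y_opset13 = softmax(X, axis=-1)   # opset-13 default: normalise last axis
    return not numpy.allclose(y_opset1, y_opset13)   # True: results differ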