Coverage for mlprodict/onnxrt/ops_shape/__init__.py: 100%

39 statements  


1""" 

2@file 

3@brief Shortcut to *ops_shape*. 

4""" 

5import textwrap 

6from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611 

7from ...onnx_tools.onnx2py_helper import get_onnx_schema 

8from .shape_excs import ShapeInferenceMissing 

9from ._element_unary import ( 

10 shape_abs, shape_acos, shape_acosh, 

11 shape_asin, shape_asinh, shape_atan, shape_atanh, 

12 shape_castlike, shape_ceil, shape_celu, 

13 shape_clip, shape_cos, shape_cosh, 

14 shape_elu, shape_erf, shape_exp, shape_floor, 

15 shape_hardmax, shape_hardsigmoid, 

16 shape_identity, shape_isinf, shape_isnan, 

17 shape_leakyrelu, shape_log, shape_logsoftmax, 

18 shape_neg, shape_not, shape_reciprocal, shape_relu, shape_round, 

19 shape_selu, shape_shrink, 

20 shape_sigmoid, shape_sign, shape_sin, shape_sinh, shape_softmax, 

21 shape_softplus, shape_softsign, shape_sqrt, 

22 shape_tan, shape_tanh, shape_thresholdedrelu, shape_trilu) 

23from ._element_wise import ( 

24 shape_add, shape_and, 

25 shape_div, 

26 shape_equal, 

27 shape_greater, shape_greaterorequal, 

28 shape_less, shape_lessorequal, 

29 shape_max, shape_min, shape_mod, shape_mul, 

30 shape_or, 

31 shape_pow, 

32 shape_sub, 

33 shape_xor) 

34from ._op_shape_op import shape_det 

35 

36 

37_shape_functions = { 

38 k: v for k, v in globals().items() if k.startswith("shape_") 

39} 
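# *_shape_functions* is a registry built by introspecting this module's
# globals: every imported ``shape_<op>`` handler is keyed by its own name,
# so a node of type ``Abs`` resolves to ``shape_abs`` through
# ``_shape_functions["shape_" + node.op_type.lower()]``.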



count = [0]  # mutable module-level counter (a list so it can be updated in place)



def shape_dispatch(cache, known_shape, node, rt_class=None):
    """
    Calls the corresponding function for every node.

    :param cache: cache of already resolved functions,
        indexed by ``(domain, op_type)``
    :param known_shape: known shapes for all results
    :param node: onnx node
    :param rt_class: a node may be a predefined function in onnx,
        if no specific function is available, the predefined
        onnx definition is used and run through this runtime
    :return: True if *known_shape* was updated, False otherwise
    """

    key = node.domain, node.op_type
    fct_shape = None
    if key in cache:
        fct_shape = cache[key]
    else:
        op_type = "shape_" + node.op_type.lower()
        if op_type in _shape_functions:
            fct_shape = _shape_functions[op_type]
            cache[key] = fct_shape


    if fct_shape is None and rt_class is not None:
        # check this operator is a predefined function in ONNX.
        try:
            onnx_schema = get_onnx_schema(node.op_type, node.domain)
        except SchemaError:
            onnx_schema = None
        if onnx_schema is not None and onnx_schema.has_function:
            sess = rt_class(onnx_schema.function_body)
            if len(node.input) != len(sess.input_names):
                raise RuntimeError(  # pragma: no cover
                    "node and function must have the same number of inputs, "
                    "len(%r) != len(%r)." % (
                        node.input, sess.input_names))
            if len(node.output) != len(sess.output_names):
                raise RuntimeError(  # pragma: no cover
                    "node and function must have the same number of outputs, "
                    "len(%r) != len(%r)." % (
                        node.output, sess.output_names))

            def _shape_function(known_shape, node):
                # Runs the ONNX function body through *rt_class* after
                # mapping the node inputs to the function inputs, then
                # propagates the inferred output shapes back into
                # *known_shape*.
                inputs = {iname: known_shape[name] for name, iname in
                          zip(node.input, sess.input_names)}
                outputs = sess.run(inputs)
                res = False
                for name, oname in zip(node.output, sess.output_names):
                    r = known_shape.update(name, outputs[oname])
                    res = res or r
                return res

            fct_shape = _shape_function
            cache[key] = fct_shape


    if fct_shape is not None:
        return fct_shape(known_shape, node)

    raise ShapeInferenceMissing(  # pragma: no cover
        "Unable to find a corresponding function for operator type %r "
        "domain=%r, looking for %r among\n%s" % (
            node.op_type, node.domain, "shape_" + node.op_type.lower(),
            "\n".join(textwrap.wrap(
                " ".join(_ for _ in sorted(_shape_functions))))))