Coverage for mlprodict/onnxrt/ops_shape/__init__.py: 100%
39 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Shortcut to *ops_shape*.
4"""
5import textwrap
6from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611
7from ...onnx_tools.onnx2py_helper import get_onnx_schema
8from .shape_excs import ShapeInferenceMissing
9from ._element_unary import (
10 shape_abs, shape_acos, shape_acosh,
11 shape_asin, shape_asinh, shape_atan, shape_atanh,
12 shape_castlike, shape_ceil, shape_celu,
13 shape_clip, shape_cos, shape_cosh,
14 shape_elu, shape_erf, shape_exp, shape_floor,
15 shape_hardmax, shape_hardsigmoid,
16 shape_identity, shape_isinf, shape_isnan,
17 shape_leakyrelu, shape_log, shape_logsoftmax,
18 shape_neg, shape_not, shape_reciprocal, shape_relu, shape_round,
19 shape_selu, shape_shrink,
20 shape_sigmoid, shape_sign, shape_sin, shape_sinh, shape_softmax,
21 shape_softplus, shape_softsign, shape_sqrt,
22 shape_tan, shape_tanh, shape_thresholdedrelu, shape_trilu)
23from ._element_wise import (
24 shape_add, shape_and,
25 shape_div,
26 shape_equal,
27 shape_greater, shape_greaterorequal,
28 shape_less, shape_lessorequal,
29 shape_max, shape_min, shape_mod, shape_mul,
30 shape_or,
31 shape_pow,
32 shape_sub,
33 shape_xor)
34from ._op_shape_op import shape_det
# Registry of the shape inference functions imported above, indexed by
# their name ``shape_<operator type in lower case>``.
# Only callables are kept: importing a submodule whose name starts with
# ``shape_`` also binds the module object itself in this namespace and a
# module must never be dispatched to as if it were a shape function.
_shape_functions = {
    k: v for k, v in globals().items()
    if k.startswith("shape_") and callable(v)
}

# NOTE(review): not used by the code visible in this file -- presumably a
# counter read or updated from other modules, confirm before removing.
count = [0]
def shape_dispatch(cache, known_shape, node, rt_class=None):
    """
    Runs the shape inference function associated with an ONNX node.

    The function is searched in *cache* first, then among the
    ``shape_<op_type>`` functions registered in this module and finally,
    if *rt_class* is specified, it is built from the function body
    of the ONNX schema of the operator.

    :param cache: dictionary ``(domain, op_type) -> shape function``,
        it is updated every time a new function is resolved
    :param known_shape: known shapes for all results
    :param node: onnx node
    :param rt_class: a node may be a predefined function in onnx,
        if no specific function is available, the predefined
        onnx definition is used and run through this runtime
    :return: value returned by the shape function, it tells
        whether *known_shape* was updated or not
    :raises RuntimeError: if the node and the ONNX function do not
        have the same number of inputs or outputs
    :raises ShapeInferenceMissing: if no shape function was found
    """
    cache_key = (node.domain, node.op_type)

    if cache_key in cache:
        shape_fct = cache[cache_key]
    else:
        try:
            shape_fct = _shape_functions["shape_" + node.op_type.lower()]
        except KeyError:
            shape_fct = None
        else:
            cache[cache_key] = shape_fct

    if shape_fct is not None:
        return shape_fct(known_shape, node)

    if rt_class is not None:
        # check this operator is a predefined function in ONNX.
        shape_fct = _shape_from_function_body(node, rt_class)
        if shape_fct is not None:
            cache[cache_key] = shape_fct
            return shape_fct(known_shape, node)

    candidates = "\n".join(textwrap.wrap(" ".join(sorted(_shape_functions))))
    raise ShapeInferenceMissing(  # pragma: no cover
        "Unable to find a corresponding function for operator type %r "
        "domain=%r, looking for %r among\n%s" % (
            node.op_type, node.domain, "shape_" + node.op_type.lower(),
            candidates))


def _shape_from_function_body(node, rt_class):
    """
    Builds a shape function from the ONNX function body of an operator.

    :param node: onnx node, its operator type and domain select the schema
    :param rt_class: class called with the function body of the schema,
        the returned object must expose *input_names*, *output_names*
        and a method *run*
    :return: a function ``(known_shape, node)`` or None if the schema
        cannot be retrieved or does not define a function
    :raises RuntimeError: if the node and the function do not
        have the same number of inputs or outputs
    """
    try:
        schema = get_onnx_schema(node.op_type, node.domain)
    except SchemaError:
        return None
    if schema is None or not schema.has_function:
        return None

    session = rt_class(schema.function_body)
    if len(node.input) != len(session.input_names):
        raise RuntimeError(  # pragma: no cover
            "node and function must have the same number of inputs, "
            "len(%r) != len(%r)." % (
                node.input, session.input_names))
    if len(node.output) != len(session.output_names):
        raise RuntimeError(  # pragma: no cover
            "node and function must have the same number of outputs, "
            "len(%r) != len(%r)." % (
                node.output, session.output_names))

    def run_function_body(shapes, onnx_node):
        # Inputs and outputs of the node are mapped by position
        # to the names used inside the function body.
        feeds = {}
        for outer, inner in zip(onnx_node.input, session.input_names):
            feeds[inner] = shapes[outer]
        results = session.run(feeds)
        updated = False
        for outer, inner in zip(onnx_node.output, session.output_names):
            changed = shapes.update(outer, results[inner])
            updated = updated or changed
        return updated

    return run_function_body