Coverage for mlprodict/onnx_tools/exports/skl2onnx_helper.py: 98%

63 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Helpers to run examples created with :epkg:`sklearn-onnx`. 

4""" 

5from onnx import helper, TensorProto, ValueInfoProto, TypeProto 

6 

7 

def get_tensor_shape(obj):
    """
    Returns the shape if that makes sense for this object.

    :param obj: a :class:`onnx.ValueInfoProto` or a :class:`onnx.TypeProto`
    :return: ``None`` if the tensor type has no dimension, otherwise a list
        with one item per dimension: the integer size when it is strictly
        positive, the symbolic name (``dim_param``) when one is set,
        ``None`` when the dimension is unknown
    :raises TypeError: if *obj* is neither of the accepted types

    NOTE: a rank-0 tensor and a tensor without any shape information both
    have no dimension and therefore both return ``None``.
    """
    if isinstance(obj, ValueInfoProto):
        return get_tensor_shape(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    shape = []
    for d in obj.tensor_type.shape.dim:
        if d.dim_value > 0:
            shape.append(d.dim_value)
        elif d.dim_param:
            shape.append(d.dim_param)
        else:
            # Neither a size nor a name: unknown dimension. Returning None
            # (instead of an empty string) lets make_tensor_value_info
            # leave the dimension unset.
            shape.append(None)
    if not shape:
        return None
    return shape

26 

27 

def get_tensor_elem_type(obj):
    """
    Returns the element type if that makes sense for this object.

    :param obj: a :class:`onnx.ValueInfoProto` or a :class:`onnx.TypeProto`
    :return: the integer ``elem_type`` stored in the tensor type
    :raises TypeError: if *obj* is neither of the accepted types
    """
    if isinstance(obj, ValueInfoProto):
        # A value info wraps its type, unwrap it.
        obj = obj.type
    elif not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    tensor_type = obj.tensor_type
    return tensor_type.elem_type

38 

39 

def _copy_inout(inout, scope, new_name):
    """
    Builds a new tensor value info named *new_name* with the same
    element type and shape as *inout*.

    :param inout: value info to copy
    :param scope: not used by this function
    :param new_name: name given to the copy
    :return: a new value info created by ``helper.make_tensor_value_info``
    """
    dims = get_tensor_shape(inout)
    return helper.make_tensor_value_info(
        new_name, get_tensor_elem_type(inout), dims)

46 

47 

def _clean_variable_name(name, scope):
    """
    Returns a unique variable name derived from *name*, obtained from
    ``scope.get_unique_variable_name``.
    """
    unique = scope.get_unique_variable_name
    return unique(name)

50 

51 

def _clean_operator_name(name, scope):
    """
    Returns a unique operator (node) name derived from *name*, obtained
    from ``scope.get_unique_operator_name``.
    """
    unique = scope.get_unique_operator_name
    return unique(name)

54 

55 

def _clean_initializer_name(name, scope):
    """
    Returns a unique name for an initializer derived from *name*.
    Initializers are tensors, so the name is requested with
    ``scope.get_unique_variable_name``.
    """
    unique = scope.get_unique_variable_name
    return unique(name)

58 

59 

def _map_tensor_names(graph, scope):
    """
    Gives every tensor name of *graph* a unique name obtained from *scope*.

    Each name is renamed exactly once, so a tensor keeps the same new name
    wherever it appears and no unique name is reserved for nothing.

    :param graph: ONNX graph (``onx.graph``)
    :param scope: scope (to get unique names)
    :return: dictionary ``{old name: new name}``
    """
    # An empty name denotes an omitted optional input/output in ONNX.
    # It must stay empty, otherwise the node would read a tensor nobody
    # produces.
    name_mapping = {'': ''}

    def register(name, rename):
        if name not in name_mapping:
            name_mapping[name] = rename(name, scope)

    # Graph inputs are registered explicitly so that an input no node
    # consumes still gets a name (it used to raise KeyError).
    for o in graph.input:
        register(o.name, _clean_variable_name)
    for o in graph.initializer:
        register(o.name, _clean_initializer_name)
    for node in graph.node:
        for o in node.input:
            register(o, _clean_variable_name)
        for o in node.output:
            register(o, _clean_variable_name)
    for o in graph.output:
        register(o.name, _clean_variable_name)
    return name_mapping


def add_onnx_graph(scope, operator, container, onx):
    """
    Adds a whole ONNX graph to an existing one following
    :epkg:`skl2onnx` API assuming this ONNX graph implements
    an `operator <http://onnx.ai/sklearn-onnx/api_summary.html?
    highlight=operator#skl2onnx.common._topology.Operator>`_.

    Every tensor of *onx* is renamed with a unique variable name and every
    named node with a unique operator name, both taken from *scope*.
    The operator inputs are connected to the renamed graph inputs, and the
    renamed graph outputs to the operator outputs, with `Identity` nodes.
    Nodes are appended to ``container.nodes``, initializers to
    ``container.initializers`` and the opsets of *onx* to
    ``container.node_domain_version_pair_sets``.

    :param scope: scope (to get unique names)
    :param operator: operator
    :param container: container
    :param onx: ONNX graph
    """
    graph = onx.graph
    name_mapping = _map_tensor_names(graph, scope)

    inputs = [_copy_inout(o, scope, name_mapping[o.name])
              for o in graph.input]
    outputs = [_copy_inout(o, scope, name_mapping[o.name])
               for o in graph.output]

    # Operator inputs -> renamed graph inputs.
    for inp, to in zip(operator.inputs, inputs):
        n = helper.make_node('Identity', [inp.onnx_name], [to.name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    # Renamed graph outputs -> operator outputs.
    for inp, to in zip(outputs, operator.outputs):
        n = helper.make_node('Identity', [inp.name], [to.onnx_name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    for node in graph.node:
        # Node names live in the operator namespace, not the variable one.
        # They are renamed per node so duplicated names do not collide.
        node_name = (_clean_operator_name(node.name, scope)
                     if node.name else None)
        n = helper.make_node(
            node.op_type,
            [name_mapping[o] for o in node.input],
            [name_mapping[o] for o in node.output],
            name=node_name,
            domain=node.domain if node.domain else None)
        # NOTE(review): attributes are copied verbatim, so tensor names
        # referenced inside subgraph attributes (If/Loop/Scan bodies) are
        # not renamed -- confirm such graphs are not expected here.
        n.attribute.extend(node.attribute)  # pylint: disable=E1101
        container.nodes.append(n)

    for o in graph.initializer:
        # Copy so that renaming does not modify the original graph.
        tensor = TensorProto()
        tensor.CopyFrom(o)
        tensor.name = name_mapping[o.name]
        container.initializers.append(tensor)

    # opset
    for oimp in onx.opset_import:
        container.node_domain_version_pair_sets.add(
            (oimp.domain, oimp.version))