Coverage for onnxcustom/utils/onnx_helper.py: 100%

75 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1# pylint: disable=C0415,E0611,E1101 

2""" 

3@file 

4@brief Onnx implementation of common functions used to train a model. 

5""" 

6import math 

7import numpy 

8from onnx import TensorProto, numpy_helper, helper 

9from onnxruntime import OrtValue 

10from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue 

11 

12 

def onnx_rename_weights(onx):
    """
    Renames ONNX initializers to make sure their name
    follows the alphabetical order. The model is
    modified inplace. This function calls
    :func:`onnx_rename_names
    <mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names>`.

    Every initializer is renamed ``I<index>_<old name>`` where
    ``<index>`` is its position in ``onx.graph.initializer``,
    zero-padded to a fixed width so that sorting the new names
    alphabetically gives back the original initializer order.

    :param onx: ONNX model
    :return: same model

    .. note::
        The function does not go into subgraphs.
        A model without any initializer is accepted: the
        replacement dictionary given to `onnx_rename_names`
        is then empty.
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        onnx_rename_names)

    names = [init.name for init in onx.graph.initializer]
    if names:
        # Number of digits used to zero-pad the index, at least one.
        ninit = max(1, int(math.log(len(names)) / math.log(10) + 1))
    else:
        # math.log(0) raises ValueError; nothing to pad anyway.
        ninit = 1
    fmt = f"I%0{ninit}d_%s"
    new_names = [fmt % (i, name) for i, name in enumerate(names)]
    repl = dict(zip(names, new_names))
    return onnx_rename_names(onx, recursive=False, replace=repl)

36 

37 

def get_onnx_opset(onx, domain=''):
    """
    Returns the opset version a model imports for a given domain.

    :param onx: ONNX model (anything exposing ``opset_import``)
    :param domain: domain to look for, ``''`` by default
    :return: version of the first opset import whose domain is *domain*
    :raises ValueError: if no opset import matches *domain*
    """
    not_found = object()
    version = next(
        (imp.version for imp in onx.opset_import if imp.domain == domain),
        not_found)
    if version is not_found:
        raise ValueError(
            f"Unable to find opset for domain={domain!r}.")
    return version

51 

52 

def proto_type_to_dtype(proto_type):
    """
    Maps an ONNX element type onto the matching numpy scalar type.

    Both the integer values of :class:`onnx.TensorProto` and the
    string forms ``'tensor(float)'`` / ``'tensor(double)'`` are
    accepted. Only float and double are supported.

    :param proto_type: ``TensorProto.FLOAT``, ``TensorProto.DOUBLE``,
        ``'tensor(float)'`` or ``'tensor(double)'``
    :return: ``numpy.float32`` or ``numpy.float64``
    :raises ValueError: for any other value
    """
    candidates = (
        (TensorProto.FLOAT, numpy.float32),
        (TensorProto.DOUBLE, numpy.float64),
        # String forms are compared last, one by one (not efficient).
        ('tensor(float)', numpy.float32),
        ('tensor(double)', numpy.float64),
    )
    for expected, np_type in candidates:
        if proto_type == expected:
            return np_type
    raise ValueError(
        f"Unexpected value proto_type={proto_type!r} (type={type(proto_type)!r}).")

71 

72 

def dtype_to_var_type(dtype):
    """
    Returns the skl2onnx tensor type class matching a numpy dtype.

    :param dtype: numpy dtype, one of float32, float64, int64, int32
    :return: the class (not an instance) ``FloatTensorType``,
        ``DoubleTensorType``, ``Int64TensorType`` or ``Int32TensorType``
    :raises ValueError: for any other dtype
    """
    from skl2onnx.common.data_types import (
        FloatTensorType, DoubleTensorType,
        Int32TensorType, Int64TensorType)
    # Checked in this order; the first equal dtype wins.
    table = (
        (numpy.float32, FloatTensorType),
        (numpy.float64, DoubleTensorType),
        (numpy.int64, Int64TensorType),
        (numpy.int32, Int32TensorType),
    )
    for np_type, var_type in table:
        if dtype == np_type:
            return var_type
    raise ValueError(
        f"Unexpected value dtype={dtype!r}.")

90 

91 

def _finalize_new_onnx(graph, onx):
    """
    Wraps *graph* into a new model which inherits the model-level
    attributes of *onx*.

    Copied from *onx*: ``ir_version``, ``producer_name``,
    ``producer_version``, ``domain``, ``model_version``,
    ``doc_string``, metadata properties, opset imports and
    model-local functions (when there are any).

    :param graph: new ONNX graph
    :param onx: original ONNX model the attributes are taken from
    :return: new ONNX model
    """
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = onx.ir_version
    onnx_model.producer_name = onx.producer_name
    onnx_model.producer_version = onx.producer_version
    onnx_model.domain = onx.domain
    onnx_model.model_version = onx.model_version
    onnx_model.doc_string = onx.doc_string
    if len(onx.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in onx.metadata_props}
        helper.set_model_props(onnx_model, values)

    # Drop whatever opset import make_model set and copy the
    # original model's opset imports instead.
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in onx.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # Nodes of the graph may call model-local functions; without them
    # the new model would reference undefined operators.
    functions = getattr(onx, 'functions', None)
    if functions:  # pragma: no cover
        onnx_model.functions.extend(functions)  # pylint: disable=E1101
    return onnx_model

110 

111 

def add_initializer(model, name, value):
    """
    Adds an initializer to graph.

    The graph nodes, inputs, outputs, existing initializers,
    sparse initializers, value_info and doc_string are kept.

    :param model: onnx model
    :param name: initializer name, it must not be the name of an
        existing initializer
    :param value: value, a numpy array, an OrtValue or anything
        `numpy.asarray` can convert
    :return: new ONNX model, *model* itself is not modified
    :raises ValueError: if *name* is already an initializer name
    """
    inits = {i.name for i in model.graph.initializer}
    if name in inits:
        raise ValueError(  # pragma: no cover
            f"Name {name!r} is already taken among {inits!r}.")
    # Same conversion as replace_initializers_into_onnx.
    if isinstance(value, (C_OrtValue, OrtValue)):
        value = value.numpy()
    list_inits = list(model.graph.initializer)
    list_inits.append(
        numpy_helper.from_array(numpy.asarray(value), name=name))
    # Adding an initializer does not change the existing results,
    # so their value_info remains valid and is kept.
    graph_def = helper.make_graph(
        model.graph.node, model.graph.name,
        model.graph.input, model.graph.output,
        list_inits,
        doc_string=model.graph.doc_string,
        value_info=model.graph.value_info,
        sparse_initializer=model.graph.sparse_initializer)
    return _finalize_new_onnx(graph_def, model)

133 

134 

def replace_initializers_into_onnx(model, results):
    """
    Replaces initializers by other initializers,
    usually trained ones.

    Sparse initializers and the graph doc_string are kept.
    value_info is not kept, since a new initializer may have
    a different shape.

    :param model: onnx model
    :param results: dictionary ``{initializer name: new value}``,
        a value is a numpy array, a numpy scalar, an OrtValue or
        a TensorProto
    :return: new onnx model, *model* itself is not modified
    :raises RuntimeError: if a key of *results* is not the name of
        an initializer of *model*
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)

    # position of every initializer in inits
    inits_dict = {init.name: i for i, init in enumerate(inits)}
    for k, v in results.items():
        if k not in inits_dict:
            raise RuntimeError(  # pragma: no cover
                f"Unable to find initializer {k!r} in {inits_dict!r}.")
        if isinstance(v, (C_OrtValue, OrtValue)):
            v = v.numpy()
        if isinstance(v, (numpy.ndarray, numpy.generic)):
            # numpy scalars are converted into 0-d arrays
            v = numpy_helper.from_array(numpy.asarray(v), k)
        # any other value is assumed to be a TensorProto already
        inits[inits_dict[k]] = v

    graph = helper.make_graph(
        list(model.graph.node), model.graph.name, inputs,
        outputs, inits,
        doc_string=model.graph.doc_string,
        sparse_initializer=model.graph.sparse_initializer)
    return _finalize_new_onnx(graph, model)