# pylint: disable=W0511,E1101,W1309,E0611,C0302,R0912,C0200,R1725,R0205,E0401,E1136,E1111
"""
@file
@brief Python implementation of `onnx.checker.check_model`.
"""
import os
import warnings
import numpy
from onnx import (  # pylint: disable=W0611
    TensorProto, TypeProto, ModelProto, AttributeProto, SequenceProto,
    OptionalProto)
from onnx.defs import onnx_opset_version, get_schema, OpSchema
from onnx.numpy_helper import to_array
from onnx.onnx_cpp2py_export.defs import SchemaError
from .. import get_ir_version

IR_VERSION = get_ir_version(onnx_opset_version())
ONNX_DOMAIN = ''
AI_ONNX_ML_DOMAIN = 'ai.onnx.ml'
AI_ONNX_TRAINING_DOMAIN = 'ai.onnx.ml.training'


class OnnxCheckError(RuntimeError):
    """
    Raised when a model fails check.

    :param msg: message
    :param proto: proto
    """

    def __init__(self, msg, proto):
        RuntimeError.__init__(self, msg)
        self.proto = proto


class UndefinedSchema:
    """
    Undefined schema.
    """

    def __init__(self, name, version, domain):
        self.name = name
        self.version = version
        self.domain = domain

    @property
    def deprecated_(self):
        "Returns False."
        return False

    def verify(self, node):
        "Verifies an undefined node is consistent with ONNX language."
        if self.deprecated_:
            raise OnnxCheckError(  # pragma: no cover
                f"Operator '{self.name}' has been deprecated since "
                f"version {self.version}.",
                node)


class Schema(object):
    """
    Wrapper around a schema.
    """

    def __init__(self, schema):
        self.schema = schema

    def __getattr__(self, attr):
        # Attributes ending with '_' mirror the C++ accessors of OpSchema:
        # `self.min_input_` reads `self.schema.min_input`.
        if attr.endswith('_') and hasattr(self.schema, attr[:-1]):
            return getattr(self.schema, attr[:-1])
        return super(Schema, self).__getattribute__(attr)

    def num_inputs_allowed(self, n):
        "Not implemented yet."
        # C++: return allowed_input_nums.count(n);
        return True

    def num_outputs_allowed(self, n):
        "Not implemented yet."
        # C++: return allowed_output_nums.count(n);
        return True

    def verify(self, node):
        "Verifies a node is consistent with ONNX language."
        if self.deprecated_:
            raise OnnxCheckError(  # pragma: no cover
                f"Operator '{self.name_}' has been deprecated since "
                f"version {self.since_version_}.",
                node)

        # Check the number of inputs.
        if (len(node.input) < self.min_input_ or
                len(node.input) > self.max_input_):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has input size {len(node.input)} "
                f"not in range [min={self.min_input_}, "
                f"max={self.max_input_}].",
                node)

        if not self.num_inputs_allowed(len(node.input)):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has input size {len(node.input)} "
                f"not in allowed input sizes.",
                node)

        # Check the number of outputs.
        if (len(node.output) < self.min_output_ or
                len(node.output) > self.max_output_):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has output size {len(node.output)} "
                f"not in range [min={self.min_output_}, "
                f"max={self.max_output_}].",
                node)

        if not self.num_outputs_allowed(len(node.output)):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has output size {len(node.output)} "
                f"not in allowed output sizes.",
                node)

        # Check the values of inputs / outputs.
        for in_idx in range(len(node.input)):
            if in_idx >= len(self.inputs_):
                if (self.inputs_ and
                        OpSchema.FormalParameterOption.Variadic ==
                        self.inputs_[-1].option):
                    # The last input formal parameter should be variadic.
                    break
                else:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Node '{node.name}' has more inputs "
                        f"({len(node.input)}) than declared "
                        f"({len(self.inputs_)}) in op definition.",
                        node)

            if (not node.input[in_idx] and
                    OpSchema.FormalParameterOption.Single ==
                    self.inputs_[in_idx].option):
                raise OnnxCheckError(  # pragma: no cover
                    f"Node '{node.name}' input[{in_idx}] is marked single "
                    f"but has an empty string in the graph.",
                    node)

        for out_idx in range(len(node.output)):
            if out_idx >= len(self.outputs_):
                if (self.outputs_ and
                        OpSchema.FormalParameterOption.Variadic ==
                        self.outputs_[-1].option):
                    # The last output formal parameter should be variadic.
                    break
                else:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Node '{node.name}' has more outputs "
                        f"({len(node.output)}) than declared "
                        f"({len(self.outputs_)}) in op definition.",
                        node)

            if (not node.output[out_idx] and
                    OpSchema.FormalParameterOption.Single ==
                    self.outputs_[out_idx].option):
                raise OnnxCheckError(  # pragma: no cover
                    f"Node '{node.name}' output[{out_idx}] is marked single "
                    f"but has an empty string in the graph.",
                    node)

        # An internal symbol is defined as starting with two underscores.
        # Attributes with names meeting this condition are considered
        # implementation details and should be ignored for the purpose
        # of schema checking.
        def isInternalSymbol(sym):  # pragma: no cover
            return sym.startswith('__')

        # Check attributes.
        seen_attr_names = set()
        for attr_proto in node.attribute:  # pragma: no cover
            name = attr_proto.name

            if name in seen_attr_names:
                raise OnnxCheckError(  # pragma: no cover
                    f"Attribute '{name}' appeared multiple times.",
                    node)
            seen_attr_names.add(name)

            # self.attributes_ maps attribute names to their definitions;
            # OpSchema.AttrType values mirror AttributeProto types.
            if name in self.attributes_:
                expected_type = int(self.attributes_[name].type)
            elif self.allows_unchecked_attributes_ or isInternalSymbol(name):
                continue
            else:
                raise OnnxCheckError(  # pragma: no cover
                    f"Unrecognized attribute '{name}' for operator "
                    f"'{node.op_type}'.", node)

            # Type would be UNDEFINED if not set.
            if attr_proto.type != expected_type:
                raise OnnxCheckError(  # pragma: no cover
                    f"Mismatched attribute type in '{node.name}' and "
                    f"attribute '{name}'.", node)

            # ref_attr_name is only valid when non-empty,
            # we simply read default value if not present.
            if not attr_proto.ref_attr_name:
                continue

            # if attr_proto.type != UNDEFINED,
            # we consider primitive types to be set even
            # if proto3 did not output default values into the stream
            # in which case we will read the default.
            if expected_type in (AttributeProto.FLOAT,
                                 AttributeProto.INT,
                                 AttributeProto.STRING):
                pass
            elif expected_type == AttributeProto.TENSOR:
                if attr_proto.t.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'t'.", node)
            elif expected_type == AttributeProto.SPARSE_TENSOR:
                if attr_proto.sparse_tensor.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'sparse_tensor'.", node)
            elif expected_type == AttributeProto.GRAPH:
                if attr_proto.g.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'g'.", node)
                if node.op_type == 'If' and len(attr_proto.g.input) > 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{attr_proto.name}' of "
                        f"operator If with name '{node.name}' must not "
                        f"have inputs.", node)
            elif expected_type == AttributeProto.TYPE_PROTO:
                if attr_proto.tp.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'tp'.", node)
            elif expected_type == AttributeProto.FLOATS:
                if len(attr_proto.floats) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'floats'.", node)
            elif expected_type == AttributeProto.INTS:
                if len(attr_proto.ints) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'ints'.", node)
            elif expected_type == AttributeProto.STRINGS:
                if len(attr_proto.strings) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'strings'.", node)
            elif expected_type == AttributeProto.TENSORS:
                if len(attr_proto.tensors) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'tensors'.", node)
            elif expected_type == AttributeProto.SPARSE_TENSORS:
                # Not adding check ... we should likely delete the check
                # in all other cases, which will not allow us to have an
                # empty list as a valid value for an attribute and this
                # seems undesirable.
                pass
            elif expected_type == AttributeProto.GRAPHS:
                if len(attr_proto.graphs) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'graphs'.", node)
            elif expected_type == AttributeProto.TYPE_PROTOS:
                if len(attr_proto.type_protos) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'type_protos'.", node)
            else:
                raise OnnxCheckError(  # pragma: no cover
                    f"Attribute '{name}' has unknown expected type.",
                    node)

        for attr_name, attr in self.attributes_.items():
            if not attr.required:
                continue
            if attr_name not in seen_attr_names:
                raise OnnxCheckError(  # pragma: no cover
                    f"Required attribute '{attr_name}' is missing.",
                    node)


class CheckerContextDefaultRegistry:
    """
    Registry.
    """

    def get_schema(self, op_type, version, domain):
        "Accessor."
        try:
            return Schema(get_schema(op_type, version, domain))
        except SchemaError:
            return UndefinedSchema(op_type, version, domain)

    def GetSchema(self, op_type, version, domain):
        "Accessor."
        return self.get_schema(op_type, version, domain)


class CheckerContext:
    """
    Class hosting information about a graph.
    """

    def __init__(self, ctx=None):
        if ctx is None:
            self.ir_version_ = -1
            self.opset_imports_ = {}
            self.schema_registry_ = CheckerContextDefaultRegistry()
            self.model_dir_ = None
            self.is_main_graph_ = True
        else:
            self.ir_version_ = ctx.ir_version_
            self.opset_imports_ = ctx.opset_imports_.copy()
            self.schema_registry_ = ctx.schema_registry_
            self.model_dir_ = ctx.model_dir_
            self.is_main_graph_ = ctx.is_main_graph_

    def get_ir_version(self):
        "Accessor."
        return self.ir_version_

    def set_ir_version(self, v):
        "Accessor."
        self.ir_version_ = v

    def get_opset_imports(self):
        "Accessor."
        return self.opset_imports_

    def set_opset_imports(self, imps):
        "Accessor."
        self.opset_imports_ = imps

    def is_main_graph(self):
        "Accessor."
        return self.is_main_graph_

    def set_is_main_graph(self, is_main_graph):
        "Accessor."
        self.is_main_graph_ = is_main_graph  # pragma: no cover

    def set_schema_registry(self, schema_registry):
        "Accessor."
        self.schema_registry_ = schema_registry  # pragma: no cover

    def get_schema_registry(self):
        "Accessor."
        return self.schema_registry_

    def set_model_dir(self, model_dir):
        "Accessor."
        self.model_dir_ = model_dir  # pragma: no cover

    def get_model_dir(self):
        "Accessor."
        return self.model_dir_  # pragma: no cover


class LexicalScopeContext:
    """
    Construct an instance with the lexical scope from the parent graph to allow
    lookup of names from that scope via this_or_ancestor_graph_has.
    The caller must ensure parent_context remains valid for the entire lifetime
    of the new instance. Alternatively, if that cannot be guaranteed, create an
    instance with the default constructor and populate output_names with the
    values from the parent scope so the values are copied instead.
    """

    def __init__(self, parent_context=None):
        if parent_context is None:
            self.parent_context_ = None
        else:
            self.parent_context_ = parent_context.copy()
        self.output_names = set()

    def add(self, name):
        "Adds a name to the context."
        self.output_names.add(name)

    def this_graph_has(self, name):
        "Checks the context includes a specific name."
        return name in self.output_names

    def this_or_ancestor_graph_has(self, name):
        "Checks the context and its ancestors include a specific name."
        return self.this_graph_has(name) or (
            self.parent_context_ is not None and
            self.parent_context_.this_or_ancestor_graph_has(name))

    def copy(self):
        "Copies the instance."
        ctx = LexicalScopeContext(self.parent_context_)
        ctx.output_names = set(self.output_names)
        return ctx
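

# Usage sketch for LexicalScopeContext: a child scope sees names added to
# its parent before construction, while names added to the child stay local
# (the parent scope is copied, not shared).
#
#   parent = LexicalScopeContext()
#   parent.add('x')
#   child = LexicalScopeContext(parent)
#   assert child.this_or_ancestor_graph_has('x')
#   assert not child.this_graph_has('x')
#   child.add('y')
#   assert not parent.this_or_ancestor_graph_has('y')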


def _enforce_has_field(proto, field):
    try:
        has = proto.HasField(field)
    except ValueError:  # pragma: no cover
        # Repeated fields and fields without presence cannot be tested
        # with HasField, fall back to a simple attribute lookup.
        has = hasattr(proto, field)
    if not has:
        raise OnnxCheckError(  # pragma: no cover
            f"Field '{field}' of '{proto}' is required but missing.", proto)


def _enforce_has_repeated_field(proto, field):
    if len(getattr(proto, field)) == 0:  # pragma: no cover
        raise OnnxCheckError(  # pragma: no cover
            f"Repeated field '{field}' of '{proto}' is required but missing.",
            proto)


def _enforce_non_empty_field(proto, field):
    if not getattr(proto, field):
        raise OnnxCheckError(  # pragma: no cover
            f"Field '{field}' of '{proto}' is required to be non-empty.",
            proto)


def _check_value_info(value_info, ctx):
    _enforce_non_empty_field(value_info, "name")
    # Relax constraint for subgraph input/output.
    if not ctx.is_main_graph():
        return  # pragma: no cover
    _enforce_has_field(value_info, "type")
    value_case = None
    for n in dir(value_info.type):
        if n.endswith('_type'):
            tt = getattr(value_info.type, n)
            if tt.ByteSize() > 0:
                if value_case is not None:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Value_info {value_info} has multiple types.",
                        value_info)
                value_case = n

    if value_case == "tensor_type":
        tt = value_info.type.tensor_type
        _enforce_has_field(tt, "elem_type")
        _enforce_has_field(tt, "shape")
    elif value_case == "optional_type":  # pragma: no cover
        tt = value_info.type.optional_type
        _enforce_has_field(tt, "elem_type")
    elif value_case == "sequence_type":  # pragma: no cover
        tt = value_info.type.sequence_type
        _enforce_has_field(tt, "elem_type")
    elif value_case == "map_type":  # pragma: no cover
        tt = value_info.type.map_type
        _enforce_has_field(tt, "key_type")
        _enforce_has_field(tt, "value_type")
    elif value_case == "opaque_type":  # pragma: no cover
        pass
    elif value_case == "sparse_tensor_type":  # pragma: no cover
        tt = value_info.type.sparse_tensor_type
        _enforce_has_field(tt, "elem_type")
        _enforce_has_field(tt, "shape")
    else:
        raise OnnxCheckError(  # pragma: no cover
            f"Unrecognized type value case (value_info name "
            f"'{value_info.name}', value_case={value_case!r}).", value_info)


def _check_data_field(tensor, field, num_value_fields):
    # Returns the field name when the field is filled in, so that error
    # messages can point at the field actually used.
    if len(getattr(tensor, field)):
        num_value_fields[0] += 1  # pylint: disable=E1137
        return field
    return None


def _check_field(tensor, field, value_field, nelem):
    if nelem != 0 and value_field != field:  # pragma: no cover
        raise OnnxCheckError(  # pragma: no cover
            f"values of data_type '{tensor.data_type}' "
            f"should be stored in field '{field}' "
            f"instead of '{value_field}'.",
            tensor)


def _check_tensor(tensor, ctx):
    _enforce_has_field(tensor, "data_type")
    if tensor.data_type == TensorProto.UNDEFINED:
        raise OnnxCheckError(  # pragma: no cover
            f"Setting data_type field (tensor name '{tensor.name}') "
            f"to UNDEFINED is not allowed.", tensor)

    num_value_fields = [0]

    value_field = (
        _check_data_field(tensor, "float_data", num_value_fields) or
        _check_data_field(tensor, "int32_data", num_value_fields) or
        _check_data_field(tensor, "string_data", num_value_fields) or
        _check_data_field(tensor, "int64_data", num_value_fields) or
        _check_data_field(tensor, "raw_data", num_value_fields) or
        _check_data_field(tensor, "double_data", num_value_fields) or
        _check_data_field(tensor, "uint64_data", num_value_fields))

    num_value_fields = num_value_fields[0]

    stored_externally = (
        hasattr(tensor, 'data_location') and
        tensor.data_location == TensorProto.EXTERNAL)
    if stored_externally:
        if num_value_fields != 0:  # pragma: no cover
            raise OnnxCheckError(  # pragma: no cover
                f"Data of TensorProto (tensor name: {tensor.name}) "
                f"is stored externally and should not have data field: "
                f"{value_field}.", tensor)

        has_location = False
        for entry in tensor.external_data:  # pragma: no cover
            if entry.key == "location" and entry.value:
                has_location = True
                data_path = os.path.join(ctx.get_model_dir(), entry.value)
                # Check whether the file exists.
                if not os.path.exists(data_path):
                    raise OnnxCheckError(  # pragma: no cover
                        f"Data of TensorProto (tensor name: {tensor.name}) "
                        f"should be stored in {data_path}, but it doesn't "
                        "exist or is not accessible.", tensor)
        if not has_location:
            raise OnnxCheckError(  # pragma: no cover
                f"TensorProto tensor name {tensor.name} is stored externally "
                f"but doesn't have a location.",
                tensor)
        return

    nelem = 1
    for x in tensor.dims:
        nelem *= x

    if nelem == 0 and num_value_fields != 0:
        raise OnnxCheckError(  # pragma: no cover
            f"TensorProto (tensor name: {tensor.name}) "
            f"is 0-element but contains data!",
            tensor)
    if nelem != 0 and num_value_fields != 1:
        raise OnnxCheckError(  # pragma: no cover
            f"TensorProto (tensor name: {tensor.name}) "
            f"should contain one and only one value field.",
            tensor)
    if hasattr(tensor, 'raw_data') and len(tensor.raw_data) > 0:
        if tensor.data_type == TensorProto.STRING:
            raise OnnxCheckError(  # pragma: no cover
                f"STRING data (tensor name: {tensor.name}) "
                f"should not be stored in raw_data field.",
                tensor)
    else:  # pragma: no cover
        if tensor.data_type in (TensorProto.FLOAT,
                                TensorProto.COMPLEX64):
            _check_field(tensor, "float_data", value_field, nelem)
        elif tensor.data_type in (TensorProto.DOUBLE,
                                  TensorProto.COMPLEX128):
            _check_field(tensor, "double_data", value_field, nelem)
        elif tensor.data_type in (TensorProto.INT32,
                                  TensorProto.UINT8,
                                  TensorProto.INT8,
                                  TensorProto.UINT16,
                                  TensorProto.INT16,
                                  TensorProto.BOOL,
                                  TensorProto.FLOAT16,
                                  TensorProto.BFLOAT16):
            _check_field(tensor, "int32_data", value_field, nelem)
        elif tensor.data_type == TensorProto.INT64:
            _check_field(tensor, "int64_data", value_field, nelem)
        elif tensor.data_type in (TensorProto.UINT32,
                                  TensorProto.UINT64):
            _check_field(tensor, "uint64_data", value_field, nelem)
        elif tensor.data_type == TensorProto.STRING:
            _check_field(tensor, "string_data", value_field, nelem)
        else:
            raise OnnxCheckError(  # pragma: no cover
                f"Unrecognized data_type (tensor name: {tensor.name}): "
                f"{tensor.data_type}.",
                tensor)


def _check_sequence(sequence, ctx):  # pragma: no cover
    _enforce_has_field(sequence, "elem_type")
    if sequence.elem_type == SequenceProto.TENSOR:
        for tensor in sequence.tensor_values:
            _check_tensor(tensor, ctx)
    elif sequence.elem_type == SequenceProto.SPARSE_TENSOR:
        for sparse_tensor in sequence.sparse_tensor_values:
            _check_sparse_tensor(sparse_tensor, ctx)
    elif sequence.elem_type == SequenceProto.SEQUENCE:
        for seq in sequence.sequence_values:
            _check_sequence(seq, ctx)
    elif sequence.elem_type == SequenceProto.MAP:
        for map_proto in sequence.map_values:
            _check_map(map_proto, ctx)
    else:
        raise OnnxCheckError(  # pragma: no cover
            f"Sequence (structure name: {sequence.name}, "
            f"elem_type: {sequence.elem_type}) does not have "
            f"a valid element type.",
            sequence)


def _check_optional(optional, ctx):  # pragma: no cover
    _enforce_has_field(optional, "elem_type")
    if optional.elem_type == OptionalProto.UNDEFINED:
        return
    if optional.elem_type == OptionalProto.TENSOR:
        if optional.HasField('tensor_value'):
            _check_tensor(optional.tensor_value, ctx)
    elif optional.elem_type == OptionalProto.SPARSE_TENSOR:
        if optional.HasField('sparse_tensor_value'):
            _check_sparse_tensor(optional.sparse_tensor_value, ctx)
    elif optional.elem_type == OptionalProto.SEQUENCE:
        if optional.HasField('sequence_value'):
            _check_sequence(optional.sequence_value, ctx)
    elif optional.elem_type == OptionalProto.MAP:
        if optional.HasField('map_value'):
            _check_map(optional.map_value, ctx)
    else:
        raise OnnxCheckError(  # pragma: no cover
            f"Optional (structure name: {optional.name}, "
            f"elem_type: {optional.elem_type}) does not "
            f"have a valid element type.",
            optional)


def _check_map(map_proto, ctx):  # pragma: no cover
    _enforce_has_field(map_proto, 'key_type')
    if map_proto.key_type == TensorProto.UNDEFINED:
        raise OnnxCheckError(  # pragma: no cover
            f"Setting key_type field (map name: '{map_proto.name}') "
            f"to UNDEFINED is not allowed.",
            map_proto)
    # Check if the key is a valid type, specifically INT8, INT16, INT32,
    # INT64, UINT8, UINT16, UINT32, UINT64, or STRING.
    if map_proto.key_type in (TensorProto.FLOAT, TensorProto.BOOL,
                              TensorProto.FLOAT16, TensorProto.COMPLEX64,
                              TensorProto.COMPLEX128):
        raise OnnxCheckError(  # pragma: no cover
            f"Setting key_type field (map name: {map_proto.name}) "
            f"to invalid TensorProto key_type {map_proto.key_type} "
            f"is not allowed.",
            map_proto)
    # MapProto will use either keys or string_keys, so only one should be > 0.
    if len(map_proto.keys) > 0 and len(map_proto.string_keys) > 0:
        raise OnnxCheckError(  # pragma: no cover
            f"Map (name: '{map_proto.name}') should not "
            f"contain more than one keys field.",
            map_proto)

    num_keys = len(map_proto.keys) + len(map_proto.string_keys)
    num_values = 0

    _enforce_has_field(map_proto, 'values')
    _check_sequence(map_proto.values, ctx)

    if map_proto.values.elem_type == SequenceProto.TENSOR:
        num_values = len(map_proto.values.tensor_values)
    elif map_proto.values.elem_type == SequenceProto.SPARSE_TENSOR:
        num_values = len(map_proto.values.sparse_tensor_values)
    elif map_proto.values.elem_type == SequenceProto.SEQUENCE:
        num_values = len(map_proto.values.sequence_values)
    elif map_proto.values.elem_type == SequenceProto.MAP:
        num_values = len(map_proto.values.map_values)

    if num_keys != num_values:
        raise OnnxCheckError(  # pragma: no cover
            f"Length of map keys and map values are not the same "
            f"(map name: '{map_proto.name}').",
            map_proto)


def _parse_data(dtype, indices):
    # Use onnx.numpy_helper.to_array to read the tensor content as
    # a numpy array, whatever field it is stored in.
    values = to_array(indices)
    if dtype != values.dtype:
        raise OnnxCheckError(  # pragma: no cover
            f"Wrong element type {values.dtype}, expected is {dtype}.",
            None)
    return values


def _check_sparse_tensor_indices_1(  # pragma: no cover
        indices, sparse_tensor_proto, nnz):
    """
    Check that the index data stored in a SparseTensorProto is valid.
    indices: a 1-dimensional tensor; indices[i] represents the
    linearized index value for the i-th nonzero value.
    """
    dense_rank = len(sparse_tensor_proto.dims)
    dense_size = 1
    for i in range(dense_rank):
        dense_size *= sparse_tensor_proto.dims[i]
    if indices.dims[0] != nnz:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' has "
            f"{indices.dims[0]} values, but NNZ is {nnz}.",
            sparse_tensor_proto)

    # Check if indices appear in ascending order, and if they have valid
    # values. The i-th value in index_data is the linear index of the i-th
    # non-zero value.
    index_data = _parse_data(numpy.int64, indices)

    prev_index = -1
    for i in range(nnz):
        curr_index = index_data[i]  # linearized index of i-th value
        if curr_index < 0 or curr_index >= dense_size:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{indices.name}' index value at "
                f"position [{i}] out of range [0, {dense_size - 1}].",
                sparse_tensor_proto)
        if curr_index <= prev_index:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{indices.name}' index value at "
                f"position [{i}] not in sorted order.",
                sparse_tensor_proto)
        prev_index = curr_index


def _check_sparse_tensor_indices_2(  # pragma: no cover
        indices, sparse_tensor_proto, nnz):
    """
    Check that the index data stored in a SparseTensorProto is valid.
    indices: a 2-dimensional tensor; indices[i,j] represents the j-th
    index value for the i-th nonzero value.
    """
    dense_rank = len(sparse_tensor_proto.dims)
    if indices.dims[0] != nnz:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' "
            f"first dimension size does not equal NNZ={nnz}.",
            sparse_tensor_proto)

    if indices.dims[1] != dense_rank:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' "
            f"second dimension size does not equal "
            f"dense_rank={dense_rank}.",
            sparse_tensor_proto)

    # Check if indices appear in ascending order, and if they have valid
    # values.
    index_data = _parse_data(numpy.int64, indices)
    prev_index = -1
    for i in range(nnz):
        curr_index = 0  # linearized index of i-th value
        for j in range(dense_rank):
            index_ij = index_data[i * dense_rank + j]
            if index_ij < 0 or index_ij >= sparse_tensor_proto.dims[j]:
                raise OnnxCheckError(  # pragma: no cover
                    f"Sparse tensor '{indices.name}' index value "
                    f"at position [{i}, {j}] out of range.",
                    sparse_tensor_proto)
            curr_index = curr_index * sparse_tensor_proto.dims[j] + index_ij
        if curr_index <= prev_index:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{indices.name}' index value "
                f"at position [{i}] not in lexicographic sorted "
                "order.", sparse_tensor_proto)
        prev_index = curr_index


def _check_sparse_tensor(sparse_tensor_proto, ctx):  # pragma: no cover
    _enforce_has_field(sparse_tensor_proto, "values")

    values = sparse_tensor_proto.values
    _check_tensor(values, ctx)

    # values must be a tensor of shape [NNZ].
    # Currently we restrict the value associated with a particular index-tuple
    # to be a single value. In the future, if there is a requirement,
    # we may extend this to permit the value to be a "sub-tensor", in which
    # case values will have dimension > 1.
    if len(values.dims) != 1:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor values '{values.name}' must have rank 1.",
            sparse_tensor_proto)

    nnz = values.dims[0]
    dense_rank = len(sparse_tensor_proto.dims)
    if dense_rank == 0:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor '{values.name}' must have a "
            f"dense-rank > 0.", sparse_tensor_proto)

    for i in range(dense_rank):
        if sparse_tensor_proto.dims[i] <= 0:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{values.name}' dimensions "
                f"are not positive.", sparse_tensor_proto)

    if sparse_tensor_proto.HasField('indices'):
        indices = sparse_tensor_proto.indices
        _check_tensor(indices, ctx)
        if indices.data_type != TensorProto.INT64:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor indices '{indices.name}' must have INT64 "
                f"type.", sparse_tensor_proto)

        if len(indices.dims) == 1:
            # Indices in linearized format.
            _check_sparse_tensor_indices_1(indices, sparse_tensor_proto, nnz)
            return
        if len(indices.dims) == 2:
            # Check COO-style index. E.g., an index for a 3D tensor
            # is a 3-tuple.
            _check_sparse_tensor_indices_2(indices, sparse_tensor_proto, nnz)
            return
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' must have rank 1 or 2.",
            sparse_tensor_proto)
    elif nnz != 0:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor '{values.name}' has no index values.",
            sparse_tensor_proto)


def check_attribute(attr, ctx, lex_ctx):  # pragma: no cover
    """
    NB: This is a generic "attribute well-formedness" check, it doesn't
    actually test if an attribute is valid per a schema.
    """
    _enforce_non_empty_field(attr, "name")

    if ctx.get_ir_version() >= 0x00000002:
        _enforce_has_field(attr, "type")

    used_fields = 0

    def check_type(expected_type):
        if attr.HasField('type') and attr.type != expected_type:
            raise OnnxCheckError(  # pragma: no cover
                f"Type field and data field mismatch in attribute "
                f"'{attr.name}'.",
                attr)

    def check_singular_field(field, itype):
        if attr.HasField(field):
            check_type(itype)
            return 1
        return 0

    def check_repeated_field(field, itype):
        if len(getattr(attr, field)) > 0:
            check_type(itype)
            return 1
        return 0

    used_fields += check_singular_field("f", AttributeProto.FLOAT)
    used_fields += check_singular_field("i", AttributeProto.INT)
    used_fields += check_singular_field("s", AttributeProto.STRING)
    used_fields += check_singular_field("t", AttributeProto.TENSOR)
    used_fields += check_singular_field("g", AttributeProto.GRAPH)
    used_fields += check_singular_field("tp", AttributeProto.TYPE_PROTO)
    used_fields += check_singular_field("sparse_tensor",
                                        AttributeProto.SPARSE_TENSOR)
    used_fields += check_repeated_field("floats", AttributeProto.FLOATS)
    used_fields += check_repeated_field("ints", AttributeProto.INTS)
    used_fields += check_repeated_field("strings", AttributeProto.STRINGS)
    used_fields += check_repeated_field("tensors", AttributeProto.TENSORS)
    used_fields += check_repeated_field("graphs", AttributeProto.GRAPHS)
    used_fields += check_repeated_field("sparse_tensors",
                                        AttributeProto.SPARSE_TENSORS)
    used_fields += check_repeated_field("type_protos",
                                        AttributeProto.TYPE_PROTOS)

    # Normally, used_fields is expected to be 1.
    # In proto3, when the value to be set is the type's default value
    # (say 0 for int), used_fields may be 0.
    if used_fields > 1:
        raise OnnxCheckError(  # pragma: no cover
            f"Attribute (name: '{attr.name}') should not "
            f"contain more than one value field.",
            attr)

    if not ctx.is_main_graph():
        # It's an attribute of a node in a function body.
        if attr.ref_attr_name and used_fields != 0:
            # The attribute proto is supposed to refer to data outside
            # and does not have its own value field set.
            raise OnnxCheckError(  # pragma: no cover
                f"Attribute (name: '{attr.name}') should refer "
                f"to an attribute in the parent node.",
                attr)

    if attr.HasField('t'):
        _check_tensor(attr.t, ctx)

    if attr.HasField('sparse_tensor'):
        _check_sparse_tensor(attr.sparse_tensor, ctx)

    if attr.HasField('g'):
        subgraph_ctx = CheckerContext(ctx)
        subgraph_ctx.set_is_main_graph(False)
        _check_graph(attr.g, subgraph_ctx, lex_ctx)

    for tensor in attr.tensors:
        _check_tensor(tensor, ctx)

    for sparse_tensor in attr.sparse_tensors:
        _check_sparse_tensor(sparse_tensor, ctx)

    if len(attr.graphs) > 0:
        subgraph_ctx = CheckerContext(ctx)
        subgraph_ctx.set_is_main_graph(False)
        for graph in attr.graphs:
            _check_graph(graph, subgraph_ctx, lex_ctx)


def _check_node(node, ctx, lex_ctx):
    _enforce_non_empty_field(node, "op_type")

    if not node.input and not node.output:
        raise OnnxCheckError(  # pragma: no cover
            f"NodeProto (name: '{node.name}', type: '{node.op_type}') "
            f"has zero inputs and zero outputs.",
            node)

    # If an experimental op is encountered, stop checking.
    if check_is_experimental_op(node.op_type):
        warnings.warn(  # pragma: no cover
            f"Warning: Checker does not support models "
            f"with experimental ops: '{node.op_type}'.")
        return  # pragma: no cover

    # Resolve domain for node.
    opset_imports = ctx.get_opset_imports()
    if node.domain not in opset_imports:
        raise OnnxCheckError(  # pragma: no cover
            f"No opset import for domain '{node.domain}'.",
            node)
    domain_version = opset_imports[node.domain]

    for attr in node.attribute:
        check_attribute(attr, ctx, lex_ctx)

    schema = ctx.get_schema_registry().GetSchema(
        node.op_type, domain_version, node.domain)
    # GetSchema returns an UndefinedSchema when no schema is registered.
    if isinstance(schema, UndefinedSchema):
        if node.domain in (ONNX_DOMAIN, AI_ONNX_ML_DOMAIN,  # pragma: no cover
                           "ai.onnx", AI_ONNX_TRAINING_DOMAIN):
            # Fail the checker if an op in a built-in domain has no schema.
            raise OnnxCheckError(  # pragma: no cover
                f"No Op registered for '{node.op_type}' with domain_version "
                f"of {domain_version}.",
                node)
        else:
            # TODO: expose the registration of the op schemas appropriately in
            # python, so we can load and register operators in other domains.
            # Before we complete the above todo, let's skip the schema check
            # for now.
            pass  # pragma: no cover
    elif schema.deprecated_:
        raise OnnxCheckError(  # pragma: no cover
            f"Op registered for '{node.op_type}' is deprecated "
            f"in domain_version of {domain_version}.",
            node)
    else:
        schema.verify(node)


def _check_graph(graph, ctx, parent_lex):
    _enforce_non_empty_field(graph, "name")

    for value_info in graph.input:
        _check_value_info(value_info, ctx)
    for value_info in graph.output:
        _check_value_info(value_info, ctx)

    # Inherit values available in outer scope.
    # Note that we do not allow shadowing, so the presence of an
    # already-defined name is always an error.
    lex_ctx = LexicalScopeContext(parent_lex)

    for value_info in graph.input:
        # TODO: If shadowing isn't allowed, this should maybe use
        # this_or_ancestor_graph_has.
        if lex_ctx.this_graph_has(value_info.name):
            raise OnnxCheckError(  # pragma: no cover
                f"Graph must be in single static assignment (SSA) form, "
                f"however '{value_info.name}' has been used as a "
                f"graph input name multiple times.",
                graph)
        lex_ctx.add(value_info.name)

    initializer_name_checker = set()

    for init in graph.initializer:
        _enforce_has_field(init, "name")
        name = init.name
        if not name:
            raise OnnxCheckError(  # pragma: no cover
                "Tensor initializers must have a non-empty name.",
                graph)

        if name in initializer_name_checker:
            raise OnnxCheckError(  # pragma: no cover
                f"'{name}' initializer name is not unique.",
                graph)
        initializer_name_checker.add(name)

        _check_tensor(init, ctx)

        if ctx.get_ir_version() <= 0x00000003:
            # Initializers are a subset of graph inputs for IR_VERSION <= 3.
            if not lex_ctx.this_graph_has(name):
                raise OnnxCheckError(  # pragma: no cover
                    f"'{name}' in initializer but not in graph input.",
                    graph)
        else:
            # An initializer is allowed to have the same name as an input,
            # but is not required to (for IR_VERSION >= 4).
            lex_ctx.add(name)

    for sparse_init in graph.sparse_initializer:  # pragma: no cover
        values = sparse_init.values
        _enforce_has_field(values, "name")
        name = values.name
        if not name:
            raise OnnxCheckError(  # pragma: no cover
                "Sparse tensor initializers must have a non-empty name.",
                graph)
        if name in initializer_name_checker:
            raise OnnxCheckError(  # pragma: no cover
                f"'{name}' initializer name is not unique across "
                f"initializers and sparse_initializers.",
                graph)
        initializer_name_checker.add(name)
        _check_sparse_tensor(sparse_init, ctx)
        lex_ctx.add(name)

    errors = []  # collected node errors; currently not re-raised
    for node in graph.node:
        # Nodes must be in topologically sorted order.
        for input in node.input:
            # Explicit optional input.
            if not input:
                continue  # pragma: no cover
            if not lex_ctx.this_or_ancestor_graph_has(input):
                raise OnnxCheckError(  # pragma: no cover
                    f"Nodes in a graph must be topologically sorted, however "
                    f"input '{input}' of node name '{node.name}', type "
                    f"'{node.op_type}' is not output of any previous nodes.",
                    node)

        # This needs to happen before the SSA check since we don't want to
        # recurse and find that outputs from control flow ops are colliding
        # with names in the inner block.
        try:
            _check_node(node, ctx, lex_ctx)
        except OnnxCheckError as e:
            errors.append(e)

        # Check for SSA form.
        for output in node.output:
            # Optional output.
            if not output:
                continue

            if lex_ctx.this_or_ancestor_graph_has(output):
                raise OnnxCheckError(  # pragma: no cover
                    f"Graph must be in single static assignment "
                    f"(SSA) form, however '{output}' "
                    f"has been used as an output name multiple times.",
                    graph)
            lex_ctx.add(output)


def _get_version_for_domain(domain, opset_imports):  # pragma: no cover
    # Utility function to get the imported version of a domain from opset
    # imports. Returns -1 if the requested domain is not in opset_imports.
    if domain not in opset_imports:
        return -1
    return opset_imports[domain]


def _check_opset_compatibility(  # pragma: no cover
        node, ctx, func_opset_imports, model_opset_imports):
    func_opset_version = _get_version_for_domain(
        node.domain, func_opset_imports)
    model_opset_version = _get_version_for_domain(
        node.domain, model_opset_imports)

    if func_opset_version == -1:
        raise OnnxCheckError(  # pragma: no cover
            f"No Opset registered for domain '{node.domain}'.",
            node)

    if model_opset_version == -1:
        # The model does not include an opset import for a node present in
        # the function body. This is ok as long as the opset import is
        # present in the function level opset imports.
        return

    if func_opset_version == model_opset_version:
        # Both versions are the same, no need to verify schema.
        return

    schema_for_model_import = ctx.get_schema_registry().GetSchema(
        node.op_type, model_opset_version, node.domain)
    schema_for_function_import = ctx.get_schema_registry().GetSchema(
        node.op_type, func_opset_version, node.domain)

    # GetSchema returns an UndefinedSchema when no schema is registered.
    model_undefined = isinstance(schema_for_model_import, UndefinedSchema)
    function_undefined = isinstance(schema_for_function_import, UndefinedSchema)
    if model_undefined and function_undefined:
        # The op belongs to a custom domain so we cannot verify schema.
        return

    # If a schema is present for one import but not the other, or the schema
    # since-versions do not match, then raise an error.
    if (model_undefined or function_undefined or
            schema_for_function_import.since_version_ !=
            schema_for_model_import.since_version_):
        raise OnnxCheckError(  # pragma: no cover
            f"Opset import for domain '{node.domain}' in function op "
            f"'{node.op_type}' is not compatible with the version "
            f"imported by the model. FunctionOp imports version "
            f"{func_opset_version} whereas the model imports version "
            f"{model_opset_version}.",
            node)


def _check_model_local_functions(model, ctx, parent_lex):  # pragma: no cover
    # Make a copy of the model opset imports to maintain a main copy of
    # opset imports across the model and all model local functions to
    # verify opset compatibility.
    model_opset_imports = ctx.get_opset_imports().copy()

    # Merge the opset imports from every function in model_opset_imports.
    # Only add an opset import if an entry for it does not exist in
    # model_opset_imports. If there is an entry then the compatibility
    # will be checked later on in _check_opset_compatibility,
    # called by _check_function.
    for function_proto in model.functions:
        for opset_import in function_proto.opset_import:
            if _get_version_for_domain(
                    opset_import.domain, model_opset_imports) == -1:
                model_opset_imports[opset_import.domain] = opset_import.version

    ctx_copy = CheckerContext(ctx)
    ctx_copy.set_opset_imports(model_opset_imports)

    for function_proto in model.functions:
        _check_function(function_proto, ctx_copy, parent_lex)


def _check_function(function, ctx, parent_lex):  # pragma: no cover
    _enforce_non_empty_field(function, "name")

    if ctx.get_ir_version() >= 0x00000008:
        _enforce_has_field(function, "domain")

    model_opset_imports = ctx.get_opset_imports()
    ctx_copy = CheckerContext(ctx)

    func_opset_imports = {}
    for relied_opset in function.opset_import:
        func_opset_imports[relied_opset.domain] = int(relied_opset.version)

    ctx_copy.set_opset_imports(func_opset_imports)

    lex_ctx = LexicalScopeContext(parent_lex)

    for input in function.input:
        # TODO: If shadowing isn't allowed, this should maybe use
        # this_or_ancestor_graph_has.
        if lex_ctx.this_graph_has(input):
            raise OnnxCheckError(  # pragma: no cover
                f"Graph must be in single static assignment (SSA) form, "
                f"however '{input}' has been used multiple times.",
                function)
        lex_ctx.add(input)

    outputs = set()
    for output in function.output:
        if output in outputs:
            raise OnnxCheckError(  # pragma: no cover
                f"Function '{function.name}' should not have "
                f"duplicate outputs specified.",
                function)
        outputs.add(output)

    attrs = set()
    for attr in function.attribute:
        if attr in attrs:
            raise OnnxCheckError(  # pragma: no cover
                f"Function '{function.name}' should not have "
                f"duplicate attributes specified.",
                function)
        attrs.add(attr)

    for node in function.node:
        # Nodes must be in topologically sorted order.
        for input in node.input:
            # Explicit optional input.
            if not input:
                continue
            if not lex_ctx.this_graph_has(input):
                raise OnnxCheckError(  # pragma: no cover
                    f"Nodes in a function must be topologically sorted, "
                    f"however input '{input}' of node name '{node.name}' "
                    f"and type '{node.op_type}' is neither output "
                    f"of any previous nodes nor input of the function.",
                    function)

        # Check whether the opset versions imported for a domain by the
        # function and the model are compatible.
        _check_opset_compatibility(
            node, ctx_copy, func_opset_imports, model_opset_imports)
        _check_node(node, ctx_copy, lex_ctx)

        # Check for SSA form.
        for output in node.output:
            # Optional output.
            if not output:
                continue

            if lex_ctx.this_or_ancestor_graph_has(output):
                raise OnnxCheckError(  # pragma: no cover
                    f"Function must be in single static assignment (SSA) "
                    f"form, however '{output}' has been used as an output "
                    f"name multiple times.",
                    function)
            lex_ctx.add(output)


def _check_model(model, ctx):
    if not model.ir_version:
        raise OnnxCheckError(  # pragma: no cover
            "The model does not have an ir_version set properly.",
            model)
    if model.ir_version > IR_VERSION:
        raise OnnxCheckError(  # pragma: no cover
            "Your model ir_version is higher than the checker's.",
            model)
    if len(model.metadata_props) > 1:  # pragma: no cover
        keys = set()
        for entry in model.metadata_props:
            if entry.key in keys:
                raise OnnxCheckError(  # pragma: no cover
                    f"Your model has duplicate keys '{entry.key}' "
                    f"in metadata_props.", model)
            keys.add(entry.key)

    ctx.set_ir_version(int(model.ir_version))
    opset_imports = {}
    for opset_import in model.opset_import:
        opset_imports[opset_import.domain] = int(opset_import.version)
    if model.ir_version >= 3:
        if not opset_imports:
            raise OnnxCheckError(  # pragma: no cover
                "Model with IR version >= 3 must specify opset_import for "
                "ONNX.",
                model)
    elif not opset_imports:  # pragma: no cover
        opset_imports[ONNX_DOMAIN] = 1
    else:
        raise OnnxCheckError(  # pragma: no cover
            "Model with IR version < 3 cannot have opset_import specified.",
            model)

    ctx.set_opset_imports(opset_imports)
    lex_ctx = LexicalScopeContext()
    _check_graph(model.graph, ctx, lex_ctx)

    if ctx.get_ir_version() >= 0x00000008:
        _check_model_local_functions(model, ctx, lex_ctx)


def check_model(model):
    """
    Checks a model is consistent with ONNX language.
    The function fails if the model is not consistent.

    :param model: :epkg:`ModelProto` or serialized bytes
    """
    ctx = CheckerContext()
    if isinstance(model, bytes):
        m = ModelProto()
        m.ParseFromString(model)
        _check_model(m, ctx)
    else:
        _check_model(model, ctx)


experimental_ops = {
    "ATen",
    "Affine",
    "ConstantFill",
    "Crop",
    "DynamicSlice",
    "GRUUnit",
    "GivenTensorFill",
    "ImageScaler",
    "ParametricSoftplus",
    "Scale",
    "ScaledTanh"}


def check_is_experimental_op(node_op_type):
    "Tells if an operator is experimental."
    return node_op_type in experimental_ops
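

# Minimal usage sketch (assumes the onnx package is installed): build a tiny
# one-node model with onnx.helper and validate it with check_model above.
# The names 'X', 'Y', 'n0' and 'demo' are arbitrary illustration choices.
if __name__ == '__main__':
    from onnx import helper

    X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    node = helper.make_node('Abs', ['X'], ['Y'], name='n0')
    graph = helper.make_graph([node], 'demo', [X], [Y])
    model = helper.make_model(graph, opset_imports=[
        helper.make_opsetid('', onnx_opset_version())])
    # Raises OnnxCheckError if the model is inconsistent.
    check_model(model)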