# pylint: disable=W0511,E1101,W1309,E0611,C0302,R0912,C0200,R1725,R0205,E0401,E1136,E1111
"""
@file
@brief Python implementation of `onnx.checker.check_model`.
"""
import os
import warnings
import numpy
from onnx import (  # pylint: disable=W0611
    TensorProto, TypeProto, ModelProto, AttributeProto, SequenceProto,
    OptionalProto)
from onnx.defs import onnx_opset_version, get_schema, OpSchema
from onnx.numpy_helper import to_array
from onnx.onnx_cpp2py_export.defs import SchemaError
from .. import get_ir_version


IR_VERSION = get_ir_version(onnx_opset_version())
ONNX_DOMAIN = ''
AI_ONNX_ML_DOMAIN = 'ai.onnx.ml'
AI_ONNX_TRAINING_DOMAIN = 'ai.onnx.training'


class OnnxCheckError(RuntimeError):
    """
    Raised when a model fails the check.

    :param msg: message
    :param proto: proto
    """

    def __init__(self, msg, proto):
        RuntimeError.__init__(self, msg)
        self.proto = proto


class UndefinedSchema:
    """
    Undefined schema.
    """

    def __init__(self, name, version, domain):
        self.name = name
        self.version = version
        self.domain = domain

    @property
    def deprecated_(self):
        "Returns False."
        return False

    def verify(self, node):
        "Verifies an undefined node is consistent with the ONNX language."
        if self.deprecated_:
            raise OnnxCheckError(  # pragma: no cover
                f"Operator '{self.name}' has been deprecated since "
                f"version {self.version}.",
                node)


class Schema(object):
    """
    Wrapper around a schema.
    """

    def __init__(self, schema):
        self.schema = schema

    def __getattr__(self, attr):
        if attr.endswith('_') and hasattr(self.schema, attr[:-1]):
            return getattr(self.schema, attr[:-1])
        return super(Schema, self).__getattribute__(attr)
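
    # Illustration of the aliasing above: for ``s = Schema(op_schema)``,
    # ``s.min_input_`` resolves to ``op_schema.min_input`` and
    # ``s.deprecated_`` to ``op_schema.deprecated``; names without a
    # trailing underscore are looked up on the wrapper itself.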

    def num_inputs_allowed(self, n):
        "Not implemented yet."
        # return allowed_input_nums.count(n)
        return True

    def num_outputs_allowed(self, n):
        "Not implemented yet."
        # return allowed_output_nums.count(n)
        return True

    def verify(self, node):
        "Verifies a node is consistent with the ONNX language."
        if self.deprecated_:
            raise OnnxCheckError(  # pragma: no cover
                f"Operator '{self.name_}' has been deprecated since "
                f"version {self.since_version_}.",
                node)

        # Check the number of inputs.
        if (len(node.input) < self.min_input_ or
                len(node.input) > self.max_input_):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has input size {len(node.input)} "
                f"not in range [min={self.min_input_}, "
                f"max={self.max_input_}].",
                node)

        if not self.num_inputs_allowed(len(node.input)):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has input size {len(node.input)} "
                f"not in allowed input sizes.",
                node)

        # Check the number of outputs.
        if (len(node.output) < self.min_output_ or
                len(node.output) > self.max_output_):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has output size {len(node.output)} "
                f"not in range [min={self.min_output_}, "
                f"max={self.max_output_}].",
                node)

        if not self.num_outputs_allowed(len(node.output)):
            raise OnnxCheckError(  # pragma: no cover
                f"Node '{node.name}' has output size {len(node.output)} "
                f"not in allowed output sizes.",
                node)

        # Check the values of inputs / outputs.
        for in_idx in range(len(node.input)):
            if in_idx >= len(self.inputs_):
                if (self.inputs_ and
                        OpSchema.FormalParameterOption.Variadic ==
                        self.inputs_[-1].option):
                    # The last input formal parameter should be variadic.
                    break
                else:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Node '{node.name}' has more inputs "
                        f"({len(node.input)}) than declared "
                        f"({len(self.inputs_)}) in op definition.",
                        node)

            if (not node.input[in_idx] and
                    OpSchema.FormalParameterOption.Single ==
                    self.inputs_[in_idx].option):
                raise OnnxCheckError(  # pragma: no cover
                    f"Node '{node.name}' input[{in_idx}] is marked single "
                    f"but has an empty string in the graph.",
                    node)

        for out_idx in range(len(node.output)):
            if out_idx >= len(self.outputs_):
                if (self.outputs_ and
                        OpSchema.FormalParameterOption.Variadic ==
                        self.outputs_[-1].option):
                    # The last output formal parameter should be variadic.
                    break
                else:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Node '{node.name}' has more outputs "
                        f"({len(node.output)}) than declared "
                        f"({len(self.outputs_)}) in op definition.",
                        node)

            if (not node.output[out_idx] and
                    OpSchema.FormalParameterOption.Single ==
                    self.outputs_[out_idx].option):
                raise OnnxCheckError(  # pragma: no cover
                    f"Node '{node.name}' output[{out_idx}] is marked single "
                    f"but has an empty string in the graph.",
                    node)

        # An internal symbol is defined as starting with two underscores.
        # Attributes with names meeting this condition are considered
        # implementation details and should be ignored for the purpose
        # of schema checking.
        def isInternalSymbol(sym):  # pragma: no cover
            return len(sym) >= 2 and sym[0] == '_' and sym[1] == '_'

        # Check attributes.
        seen_attr_names = set()
        for attr_proto in node.attribute:  # pragma: no cover
            name = attr_proto.name

            if name in seen_attr_names:
                raise OnnxCheckError(  # pragma: no cover
                    f"Attribute '{name}' appeared multiple times.",
                    node)
            seen_attr_names.add(name)

            if name in self.attributes_:
                # *attributes_* maps an attribute name to its definition.
                expected_type = int(self.attributes_[name].type)
            elif self.allows_unchecked_attributes_ or isInternalSymbol(name):
                continue
            else:
                raise OnnxCheckError(  # pragma: no cover
                    f"Unrecognized attribute '{name}' for operator "
                    f"'{node.op_type}'.", node)

            # Type would be UNDEFINED if not set.
            if attr_proto.type != expected_type:
                raise OnnxCheckError(  # pragma: no cover
                    f"Mismatched attribute type in '{node.name}' and "
                    f"attribute '{name}'.", node)

            # ref_attr_name is only valid when non-empty; the value then
            # comes from the referenced attribute and there is no value
            # field to check here.
            if attr_proto.ref_attr_name:
                continue

            # if attr_proto.type != UNDEFINED
            # we consider primitive types to be set even
            # if proto3 did not output default values into the stream
            # in which case we will read the default
            if expected_type in (AttributeProto.FLOAT,
                                 AttributeProto.INT,
                                 AttributeProto.STRING):
                pass
            elif expected_type == AttributeProto.TENSOR:
                if attr_proto.t.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'t'.", node)
            elif expected_type == AttributeProto.SPARSE_TENSOR:
                if attr_proto.sparse_tensor.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'sparse_tensor'.", node)
            elif expected_type == AttributeProto.GRAPH:
                if attr_proto.g.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'g'.", node)
                if node.op_type == 'If' and len(attr_proto.g.input) > 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{attr_proto.name}' of "
                        f"operator If with name '{node.name}' must not have "
                        f"inputs.", node)
            elif expected_type == AttributeProto.TYPE_PROTO:
                if attr_proto.tp.ByteSize() == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'tp'.", node)
            elif expected_type == AttributeProto.FLOATS:
                if len(attr_proto.floats) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'floats'.", node)
            elif expected_type == AttributeProto.INTS:
                if len(attr_proto.ints) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'ints'.", node)
            elif expected_type == AttributeProto.STRINGS:
                if len(attr_proto.strings) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'strings'.", node)
            elif expected_type == AttributeProto.TENSORS:
                if len(attr_proto.tensors) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'tensors'.", node)
            elif expected_type == AttributeProto.SPARSE_TENSORS:
                # Not adding a check here; the checks in the other cases
                # reject an empty list as a valid attribute value, which
                # seems undesirable, and should likely be removed as well.
                pass
            elif expected_type == AttributeProto.GRAPHS:
                if len(attr_proto.graphs) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'graphs'.", node)
            elif expected_type == AttributeProto.TYPE_PROTOS:
                if len(attr_proto.type_protos) == 0:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Attribute '{name}' is expected to have field "
                        f"'type_protos'.", node)
            else:
                raise OnnxCheckError(  # pragma: no cover
                    f"Attribute '{name}' has unknown expected type.",
                    node)

        for attr in self.attributes_.values():
            if not attr.required:
                continue
            if attr.name not in seen_attr_names:
                raise OnnxCheckError(  # pragma: no cover
                    f"Required attribute '{attr.name}' is missing.",
                    node)


class CheckerContextDefaultRegistry:
    """
    Registry.
    """

    def get_schema(self, op_type, version, domain):
        "Accessor."
        try:
            return Schema(get_schema(op_type, version, domain))
        except SchemaError:
            return UndefinedSchema(op_type, version, domain)

    def GetSchema(self, op_type, version, domain):
        "Accessor."
        return self.get_schema(op_type, version, domain)
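
# For instance (illustrative), ``CheckerContextDefaultRegistry().get_schema(
# 'Relu', 14, '')`` returns a ``Schema`` wrapping the official operator
# schema, while an unknown operator type yields an ``UndefinedSchema``.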


class CheckerContext:
    """
    Class hosting information about a graph.
    """

    def __init__(self, ctx=None):
        if ctx is None:
            self.ir_version_ = -1
            self.opset_imports_ = {}
            self.schema_registry_ = CheckerContextDefaultRegistry()
            self.model_dir_ = None
            self.is_main_graph_ = True
        else:
            self.ir_version_ = ctx.ir_version_
            self.opset_imports_ = ctx.opset_imports_.copy()
            self.schema_registry_ = ctx.schema_registry_
            self.model_dir_ = ctx.model_dir_
            self.is_main_graph_ = ctx.is_main_graph_

    def get_ir_version(self):
        "Accessor."
        return self.ir_version_

    def set_ir_version(self, v):
        "Accessor."
        self.ir_version_ = v

    def get_opset_imports(self):
        "Accessor."
        return self.opset_imports_

    def set_opset_imports(self, imps):
        "Accessor."
        self.opset_imports_ = imps

    def is_main_graph(self):
        "Accessor."
        return self.is_main_graph_

    def set_is_main_graph(self, is_main_graph):
        "Accessor."
        self.is_main_graph_ = is_main_graph  # pragma: no cover

    def set_schema_registry(self, schema_registry):
        "Accessor."
        self.schema_registry_ = schema_registry  # pragma: no cover

    def get_schema_registry(self):
        "Accessor."
        return self.schema_registry_

    def set_model_dir(self, model_dir):
        "Accessor."
        self.model_dir_ = model_dir  # pragma: no cover

    def get_model_dir(self):
        "Accessor."
        return self.model_dir_  # pragma: no cover
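
# Typical setup, mirroring what ``_check_model`` does below
# (the opset version 17 is illustrative):
#
#     ctx = CheckerContext()
#     ctx.set_ir_version(8)
#     ctx.set_opset_imports({'': 17})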


class LexicalScopeContext:
    """
    Construct an instance with the lexical scope from the parent graph to allow
    lookup of names from that scope via `this_or_ancestor_graph_has`.
    The caller must ensure parent_context remains valid for the entire lifetime
    of the new instance. Alternatively, if that cannot be guaranteed, create an
    instance with the default constructor and populate output_names with the
    values from the parent scope so the values are copied instead.
    """

    def __init__(self, parent_context=None):
        if parent_context is None:
            self.parent_context_ = None
        else:
            self.parent_context_ = parent_context.copy()
        self.output_names = set()

    def add(self, name):
        "Adds a name to the context."
        self.output_names.add(name)

    def this_graph_has(self, name):
        "Checks the context includes a specific name."
        return name in self.output_names

    def this_or_ancestor_graph_has(self, name):
        "Checks the context or one of its ancestors includes a specific name."
        return self.this_graph_has(name) or (
            self.parent_context_ is not None and
            self.parent_context_.this_or_ancestor_graph_has(name))

    def copy(self):
        "Copies the instance."
        ctx = LexicalScopeContext(self.parent_context_)
        ctx.output_names = set(self.output_names)
        return ctx
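
# Example of scope chaining (illustrative): if ``outer`` contains name 'X'
# and ``inner = LexicalScopeContext(outer)`` then receives 'Y' via ``add``,
# ``inner.this_graph_has('X')`` is False while
# ``inner.this_or_ancestor_graph_has('X')`` is True.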


def _enforce_has_field(proto, field):
    if not hasattr(proto, field):
        raise OnnxCheckError(  # pragma: no cover
            f"Field '{field}' of '{proto}' is required but missing.", proto)


def _enforce_has_repeated_field(proto, field):
    if not getattr(proto, field):  # pragma: no cover
        raise OnnxCheckError(  # pragma: no cover
            f"Repeated field '{field}' of '{proto}' is required but missing.",
            proto)


def _enforce_non_empty_field(proto, field):
    if not getattr(proto, field):
        raise OnnxCheckError(  # pragma: no cover
            f"Field '{field}' of '{proto}' is required to be non-empty.", proto)


def _check_value_info(value_info, ctx):
    _enforce_non_empty_field(value_info, "name")
    # Relax constraint for subgraph input/output.
    if not ctx.is_main_graph():
        return  # pragma: no cover
    _enforce_has_field(value_info, "type")
    value_case = None
    for n in dir(value_info.type):
        if n.endswith('_type'):
            tt = getattr(value_info.type, n)
            if tt.ByteSize() > 0:
                if value_case is not None:
                    raise OnnxCheckError(  # pragma: no cover
                        f"Value_info {value_info} has multiple types.",
                        value_info)
                value_case = n

    if value_case == "tensor_type":
        tt = value_info.type.tensor_type
        _enforce_has_field(tt, "elem_type")
        _enforce_has_field(tt, "shape")
    elif value_case == "optional_type":  # pragma: no cover
        tt = value_info.type.optional_type
        _enforce_has_field(tt, "elem_type")
    elif value_case == "sequence_type":  # pragma: no cover
        tt = value_info.type.sequence_type
        _enforce_has_field(tt, "elem_type")
    elif value_case == "map_type":  # pragma: no cover
        tt = value_info.type.map_type
        _enforce_has_field(tt, "key_type")
        _enforce_has_field(tt, "value_type")
    elif value_case == "opaque_type":  # pragma: no cover
        pass
    elif value_case == "sparse_tensor_type":  # pragma: no cover
        tt = value_info.type.sparse_tensor_type
        _enforce_has_field(tt, "elem_type")
        _enforce_has_field(tt, "shape")
    else:
        raise OnnxCheckError(  # pragma: no cover
            f"Unrecognized type value case (value_info name "
            f"'{value_info.name}', value_case={value_case!r}).", value_info)


def _check_data_field(tensor, field, num_value_fields):
    # Returns the field name when the field carries data, None otherwise.
    if len(getattr(tensor, field)):
        num_value_fields[0] += 1  # pylint: disable=E1137
        return field
    return None


def _check_field(tensor, field, value_field, nelem):
    # *value_field* is the name of the field actually carrying the data,
    # *field* the one expected for this data_type.
    if nelem != 0 and value_field != field:  # pragma: no cover
        raise OnnxCheckError(  # pragma: no cover
            f"values of data_type '{tensor.data_type}' "
            f"should be stored in field '{field}' "
            f"instead of '{value_field}'.",
            tensor)


def _check_tensor(tensor, ctx):
    _enforce_has_field(tensor, "data_type")
    if tensor.data_type == TensorProto.UNDEFINED:
        raise OnnxCheckError(  # pragma: no cover
            f"Setting data_type field (tensor name '{tensor.name}') "
            f"to UNDEFINED is not allowed.", tensor)

    num_value_fields = [0]

    value_field = (
        _check_data_field(tensor, "float_data", num_value_fields) or
        _check_data_field(tensor, "int32_data", num_value_fields) or
        _check_data_field(tensor, "string_data", num_value_fields) or
        _check_data_field(tensor, "int64_data", num_value_fields) or
        _check_data_field(tensor, "raw_data", num_value_fields) or
        _check_data_field(tensor, "double_data", num_value_fields) or
        _check_data_field(tensor, "uint64_data", num_value_fields))

    num_value_fields = num_value_fields[0]

    stored_externally = (
        hasattr(tensor, 'data_location') and
        tensor.data_location == TensorProto.EXTERNAL)
    if stored_externally:
        if num_value_fields != 0:  # pragma: no cover
            raise OnnxCheckError(  # pragma: no cover
                f"Data of TensorProto (tensor name: '{tensor.name}') "
                f"is stored externally and should not have data field: "
                f"{value_field}.", tensor)

        has_location = False
        for entry in tensor.external_data:  # pragma: no cover
            if entry.key == "location" and entry.value:
                has_location = True
                data_path = os.path.join(ctx.get_model_dir(), entry.value)
                # Check whether the file exists.
                if not os.path.exists(data_path):
                    raise OnnxCheckError(  # pragma: no cover
                        f"Data of TensorProto (tensor name: '{tensor.name}') "
                        f"should be stored in {data_path}, but it doesn't "
                        "exist or is not accessible.", tensor)
        if not has_location:
            raise OnnxCheckError(  # pragma: no cover
                f"TensorProto '{tensor.name}' is stored externally "
                f"but doesn't have a location.",
                tensor)
        return

    nelem = 1
    for x in tensor.dims:
        nelem *= x

    if nelem == 0 and num_value_fields != 0:
        raise OnnxCheckError(  # pragma: no cover
            f"TensorProto (tensor name: '{tensor.name}') "
            f"is 0-element but contains data.",
            tensor)
    if nelem != 0 and num_value_fields != 1:
        raise OnnxCheckError(  # pragma: no cover
            f"TensorProto (tensor name: '{tensor.name}') "
            f"should contain one and only one value field.",
            tensor)
    if hasattr(tensor, 'raw_data') and len(tensor.raw_data) > 0:
        if tensor.data_type == TensorProto.STRING:
            raise OnnxCheckError(  # pragma: no cover
                f"STRING data (tensor name: '{tensor.name}') "
                f"should not be stored in raw_data field.",
                tensor)
    else:  # pragma: no cover
        if tensor.data_type in (TensorProto.FLOAT,
                                TensorProto.COMPLEX64):
            _check_field(tensor, "float_data", value_field, nelem)
        elif tensor.data_type in (TensorProto.DOUBLE,
                                  TensorProto.COMPLEX128):
            _check_field(tensor, "double_data", value_field, nelem)
        elif tensor.data_type in (TensorProto.INT32,
                                  TensorProto.UINT8,
                                  TensorProto.INT8,
                                  TensorProto.UINT16,
                                  TensorProto.INT16,
                                  TensorProto.BOOL,
                                  TensorProto.FLOAT16,
                                  TensorProto.BFLOAT16):
            _check_field(tensor, "int32_data", value_field, nelem)
        elif tensor.data_type == TensorProto.INT64:
            _check_field(tensor, "int64_data", value_field, nelem)
        elif tensor.data_type in (TensorProto.UINT32,
                                  TensorProto.UINT64):
            _check_field(tensor, "uint64_data", value_field, nelem)
        elif tensor.data_type == TensorProto.STRING:
            _check_field(tensor, "string_data", value_field, nelem)
        else:
            raise OnnxCheckError(  # pragma: no cover
                f"Unrecognized data_type (tensor name: '{tensor.name}'): "
                f"{tensor.data_type}.",
                tensor)


def _check_sequence(sequence, ctx):  # pragma: no cover
    _enforce_has_field(sequence, "elem_type")
    if sequence.elem_type == SequenceProto.TENSOR:
        for tensor in sequence.tensor_values:
            _check_tensor(tensor, ctx)
    elif sequence.elem_type == SequenceProto.SPARSE_TENSOR:
        for sparse_tensor in sequence.sparse_tensor_values:
            _check_sparse_tensor(sparse_tensor, ctx)
    elif sequence.elem_type == SequenceProto.SEQUENCE:
        for seq in sequence.sequence_values:
            _check_sequence(seq, ctx)
    elif sequence.elem_type == SequenceProto.MAP:
        for map_proto in sequence.map_values:
            _check_map(map_proto, ctx)
    else:
        raise OnnxCheckError(  # pragma: no cover
            f"Sequence (structure name: '{sequence.name}', "
            f"elem_type: {sequence.elem_type}) does not have "
            f"a valid element type.",
            sequence)


def _check_optional(optional, ctx):  # pragma: no cover
    _enforce_has_field(optional, "elem_type")
    if optional.elem_type == OptionalProto.UNDEFINED:
        return
    if optional.elem_type == OptionalProto.TENSOR:
        if optional.HasField('tensor_value'):
            _check_tensor(optional.tensor_value, ctx)
    elif optional.elem_type == OptionalProto.SPARSE_TENSOR:
        if optional.HasField('sparse_tensor_value'):
            _check_sparse_tensor(optional.sparse_tensor_value, ctx)
    elif optional.elem_type == OptionalProto.SEQUENCE:
        if optional.HasField('sequence_value'):
            _check_sequence(optional.sequence_value, ctx)
    elif optional.elem_type == OptionalProto.MAP:
        if optional.HasField('map_value'):
            _check_map(optional.map_value, ctx)
    else:
        raise OnnxCheckError(  # pragma: no cover
            f"Optional (structure name: '{optional.name}', "
            f"elem_type: {optional.elem_type}) does not have "
            f"a valid element type.",
            optional)


def _check_map(map, ctx):  # pragma: no cover
    _enforce_has_field(map, 'key_type')
    if map.key_type == TensorProto.UNDEFINED:
        raise OnnxCheckError(  # pragma: no cover
            f"Setting key_type field (map name: '{map.name}') "
            f"to UNDEFINED is not allowed.",
            map)
    # Check if key is a valid type, specifically INT8, INT16, INT32, INT64,
    # UINT8, UINT16, UINT32, UINT64, or STRING.
    if map.key_type in (TensorProto.FLOAT, TensorProto.BOOL,
                        TensorProto.FLOAT16, TensorProto.COMPLEX64,
                        TensorProto.COMPLEX128):
        raise OnnxCheckError(  # pragma: no cover
            f"Setting key_type field (map name: '{map.name}') "
            f"to invalid TensorProto key_type {map.key_type} "
            f"is not allowed.",
            map)
    # MapProto will use either keys or string_keys, so only one should be > 0.
    if len(map.keys) > 0 and len(map.string_keys) > 0:
        raise OnnxCheckError(  # pragma: no cover
            f"Map (name: '{map.name}') should not "
            f"contain more than one keys field.",
            map)

    num_keys = len(map.keys) + len(map.string_keys)
    num_values = 0

    _enforce_has_field(map, 'values')
    _check_sequence(map.values, ctx)

    if map.values.elem_type == SequenceProto.TENSOR:
        num_values = len(map.values.tensor_values)
    elif map.values.elem_type == SequenceProto.SPARSE_TENSOR:
        num_values = len(map.values.sparse_tensor_values)
    elif map.values.elem_type == SequenceProto.SEQUENCE:
        num_values = len(map.values.sequence_values)
    elif map.values.elem_type == SequenceProto.MAP:
        num_values = len(map.values.map_values)

    if num_keys != num_values:
        raise OnnxCheckError(  # pragma: no cover
            f"Length of map keys and map values are not the same "
            f"(map name: '{map.name}').",
            map)


def _parse_data(dtype, indices):
    # *indices* is a TensorProto; it is converted into a numpy array
    # before its element type is checked.
    data = to_array(indices)
    if dtype != data.dtype:
        raise OnnxCheckError(  # pragma: no cover
            f"Wrong element type {data.dtype}, expected is {dtype}.",
            None)
    return data


def _check_sparse_tensor_indices_1(  # pragma: no cover
        indices, sparse_tensor_proto, nnz):
    """
    Checks that the index data stored in a SparseTensorProto is valid.
    indices: a 1-dimensional tensor; indices[i] represents the
    linearized index value for the i-th nonzero value.
    """
    dense_rank = len(sparse_tensor_proto.dims)
    dense_size = 1
    for i in range(dense_rank):
        dense_size *= sparse_tensor_proto.dims[i]
    if indices.dims[0] != nnz:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' has "
            f"{indices.dims[0]} values, but NNZ is {nnz}.",
            sparse_tensor_proto)

    # Check if indices appear in ascending order, and if they have valid
    # values. The i-th value in index_data is the linear index of the i-th
    # non-zero value.
    index_data = _parse_data(numpy.int64, indices)

    prev_index = -1
    for i in range(nnz):
        curr_index = index_data[i]  # linearized index of i-th value
        if curr_index < 0 or curr_index >= dense_size:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{indices.name}' index value at "
                f"position [{i}] out of range [0, {dense_size - 1}].",
                sparse_tensor_proto)
        if curr_index <= prev_index:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{indices.name}' index value at "
                f"position [{i}] not in sorted order.",
                sparse_tensor_proto)
        prev_index = curr_index


def _check_sparse_tensor_indices_2(  # pragma: no cover
        indices, sparse_tensor_proto, nnz):
    """
    Checks that the index data stored in a SparseTensorProto is valid.
    indices: a 2-dimensional tensor; indices[i, j] represents the j-th
    index value for the i-th nonzero value.
    """
    dense_rank = len(sparse_tensor_proto.dims)
    if indices.dims[0] != nnz:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' "
            f"first dimension size does not equal NNZ={nnz}.",
            sparse_tensor_proto)

    if indices.dims[1] != dense_rank:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' "
            f"second dimension size does not equal "
            f"dense_rank={dense_rank}.",
            sparse_tensor_proto)

    # Check if indices appear in ascending order, and if they have valid
    # values.
    index_data = _parse_data(numpy.int64, indices)
    prev_index = -1
    for i in range(nnz):
        curr_index = 0  # linearized index of i-th value
        for j in range(dense_rank):
            index_ij = index_data[i, j]
            if index_ij < 0 or index_ij >= sparse_tensor_proto.dims[j]:
                raise OnnxCheckError(  # pragma: no cover
                    f"Sparse tensor '{indices.name}' index value "
                    f"at position [{i}, {j}] out of range.",
                    sparse_tensor_proto)
            curr_index = curr_index * sparse_tensor_proto.dims[j] + index_ij
        if curr_index <= prev_index:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{indices.name}' index value "
                f"at position [{i}] not in lexicographic sorted "
                "order.", sparse_tensor_proto)
        prev_index = curr_index
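
    # Worked example of the linearization above: with dims = [2, 3], the
    # index tuple (1, 2) maps to (0 * 2 + 1) * 3 + 2 = 5; successive rows
    # of *indices* must therefore be strictly increasing once linearized.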


def _check_sparse_tensor(sparse_tensor_proto, ctx):  # pragma: no cover
    _enforce_has_field(sparse_tensor_proto, "values")

    values = sparse_tensor_proto.values
    _check_tensor(values, ctx)

    # values must be a tensor of shape [NNZ].
    # Currently we restrict the value associated with a particular index-tuple
    # to be a single value. In the future, if there is a requirement,
    # we may extend this to permit the value to be a "sub-tensor", in which
    # case values will have dimension > 1.
    if len(values.dims) != 1:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor values '{values.name}' must have rank 1.",
            sparse_tensor_proto)

    nnz = values.dims[0]
    dense_rank = len(sparse_tensor_proto.dims)
    if dense_rank == 0:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor '{values.name}' must have a "
            f"dense-rank > 0.", sparse_tensor_proto)

    for i in range(dense_rank):
        if sparse_tensor_proto.dims[i] <= 0:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor '{values.name}' dimensions "
                f"are not positive.", sparse_tensor_proto)

    if sparse_tensor_proto.HasField('indices'):
        indices = sparse_tensor_proto.indices
        _check_tensor(indices, ctx)
        if indices.data_type != TensorProto.INT64:
            raise OnnxCheckError(  # pragma: no cover
                f"Sparse tensor indices '{indices.name}' must have INT64 type.",
                sparse_tensor_proto)

        if len(indices.dims) == 1:
            # Indices in linearized format.
            _check_sparse_tensor_indices_1(indices, sparse_tensor_proto, nnz)
            return
        if len(indices.dims) == 2:
            # Check COO-style index. E.g., an index for a 3D tensor is a 3-tuple.
            _check_sparse_tensor_indices_2(indices, sparse_tensor_proto, nnz)
            return
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor indices '{indices.name}' must have rank 1 or 2.",
            sparse_tensor_proto)
    elif nnz != 0:
        raise OnnxCheckError(  # pragma: no cover
            f"Sparse tensor '{values.name}' has no index values.",
            sparse_tensor_proto)


def check_attribute(attr, ctx, lex_ctx):  # pragma: no cover
    """
    NB: This is a generic "attribute well-formedness" check, it doesn't
    actually test if an attribute is valid per a schema.
    """
    _enforce_non_empty_field(attr, "name")

    if ctx.get_ir_version() >= 0x00000002:
        _enforce_has_field(attr, "type")

    used_fields = 0

    def check_type(expected_type):
        if attr.type and attr.type != expected_type:
            raise OnnxCheckError(  # pragma: no cover
                f"Type field and data field mismatch in attribute '{attr.name}'.",
                attr)

    def check_singular_field(field, itype):
        if attr.HasField(field):
            check_type(itype)
            return 1
        return 0

    def check_repeated_field(field, itype):
        if len(getattr(attr, field)) > 0:
            check_type(itype)
            return 1
        return 0

    used_fields += check_singular_field("f", AttributeProto.FLOAT)
    used_fields += check_singular_field("i", AttributeProto.INT)
    used_fields += check_singular_field("s", AttributeProto.STRING)
    used_fields += check_singular_field("t", AttributeProto.TENSOR)
    used_fields += check_singular_field("g", AttributeProto.GRAPH)
    used_fields += check_singular_field("tp", AttributeProto.TYPE_PROTO)
    used_fields += check_singular_field("sparse_tensor",
                                        AttributeProto.SPARSE_TENSOR)
    used_fields += check_repeated_field("floats", AttributeProto.FLOATS)
    used_fields += check_repeated_field("ints", AttributeProto.INTS)
    used_fields += check_repeated_field("strings", AttributeProto.STRINGS)
    used_fields += check_repeated_field("tensors", AttributeProto.TENSORS)
    used_fields += check_repeated_field("graphs", AttributeProto.GRAPHS)
    used_fields += check_repeated_field("sparse_tensors",
                                        AttributeProto.SPARSE_TENSORS)
    used_fields += check_repeated_field("type_protos",
                                        AttributeProto.TYPE_PROTOS)

    # Normally, used_fields is expected to be 1.
    # In proto3, when the value to be set is the type default value
    # (say 0 for int), used_fields may be 0.
    if used_fields > 1:
        raise OnnxCheckError(  # pragma: no cover
            f"Attribute (name: '{attr.name}') should not "
            f"contain more than one value field.",
            attr)

    if not ctx.is_main_graph():
        # It's an attribute of a node in a function body.
        if attr.ref_attr_name and used_fields != 0:
            # The attribute proto is supposed to refer to data outside and
            # does not have its own value field set.
            raise OnnxCheckError(  # pragma: no cover
                f"Attribute (name: '{attr.name}') should refer "
                f"to an attribute in the parent node.",
                attr)

    if attr.HasField('t'):
        _check_tensor(attr.t, ctx)

    if attr.HasField('sparse_tensor'):
        _check_sparse_tensor(attr.sparse_tensor, ctx)

    if attr.HasField('g'):
        subgraph_ctx = CheckerContext(ctx)
        subgraph_ctx.set_is_main_graph(False)
        _check_graph(attr.g, subgraph_ctx, lex_ctx)

    for tensor in attr.tensors:
        _check_tensor(tensor, ctx)

    for sparse_tensor in attr.sparse_tensors:
        _check_sparse_tensor(sparse_tensor, ctx)

    if len(attr.graphs) > 0:
        subgraph_ctx = CheckerContext(ctx)
        subgraph_ctx.set_is_main_graph(False)
        for graph in attr.graphs:
            _check_graph(graph, subgraph_ctx, lex_ctx)


def _check_node(node, ctx, lex_ctx):
    _enforce_non_empty_field(node, "op_type")

    if not node.input and not node.output:
        raise OnnxCheckError(  # pragma: no cover
            f"NodeProto (name: '{node.name}', type: '{node.op_type}') "
            f"has zero inputs and zero outputs.",
            node)

    # If an experimental op is encountered, stop checking.
    if check_is_experimental_op(node.op_type):
        warnings.warn(  # pragma: no cover
            f"Checker does not support models "
            f"with experimental ops: '{node.op_type}'.")
        return  # pragma: no cover

    # Resolve domain for node.
    opset_imports = ctx.get_opset_imports()
    if node.domain not in opset_imports:
        raise OnnxCheckError(  # pragma: no cover
            f"No opset import for domain '{node.domain}'.",
            node)
    domain_version = opset_imports[node.domain]

    for attr in node.attribute:
        check_attribute(attr, ctx, lex_ctx)

    schema = ctx.get_schema_registry().GetSchema(
        node.op_type, domain_version, node.domain)
    if isinstance(schema, UndefinedSchema):
        if node.domain in (ONNX_DOMAIN, AI_ONNX_ML_DOMAIN,  # pragma: no cover
                           "ai.onnx", AI_ONNX_TRAINING_DOMAIN):
            # Fail the checker if an op from a built-in domain has no schema.
            raise OnnxCheckError(  # pragma: no cover
                f"No Op registered for '{node.op_type}' with domain_version "
                f"of {domain_version}.",
                node)
        else:
            # TODO: expose the registration of the op schemas appropriately in
            # python, so we can load and register operators in other domains.
            # Until the above todo is completed, skip the schema check.
            pass  # pragma: no cover
    elif schema.deprecated_:
        raise OnnxCheckError(  # pragma: no cover
            f"Op registered for '{node.op_type}' is deprecated "
            f"in domain_version of {domain_version}.",
            node)
    else:
        schema.verify(node)


def _check_graph(graph, ctx, parent_lex):
    _enforce_non_empty_field(graph, "name")

    for value_info in graph.input:
        _check_value_info(value_info, ctx)
    for value_info in graph.output:
        _check_value_info(value_info, ctx)

    # Inherit values available in outer scope.
    # Note that we do not allow shadowing, so the presence of an
    # already-defined name is always an error.
    lex_ctx = LexicalScopeContext(parent_lex)

    for value_info in graph.input:
        # TODO: If shadowing isn't allowed, this should maybe use
        # this_or_ancestor_graph_has
        if lex_ctx.this_graph_has(value_info.name):
            raise OnnxCheckError(  # pragma: no cover
                f"Graph must be in single static assignment (SSA) form, "
                f"however '{value_info.name}' has been used as "
                f"a graph input name multiple times.",
                graph)
        lex_ctx.add(value_info.name)

    initializer_name_checker = set()

    for init in graph.initializer:
        _enforce_has_field(init, "name")
        name = init.name
        if not name:
            raise OnnxCheckError(  # pragma: no cover
                "Tensor initializers must have a non-empty name.",
                graph)

        if name in initializer_name_checker:
            raise OnnxCheckError(  # pragma: no cover
                f"'{name}' initializer name is not unique.",
                graph)
        initializer_name_checker.add(name)

        _check_tensor(init, ctx)

        if ctx.get_ir_version() <= 0x00000003:
            # Initializers are a subset of graph inputs for IR_VERSION <= 3.
            if not lex_ctx.this_graph_has(name):
                raise OnnxCheckError(  # pragma: no cover
                    f"'{name}' in initializer but not in graph input.",
                    graph)
        else:
            # An initializer is allowed to have the same name as an input,
            # but is not required to (for IR_VERSION >= 4).
            lex_ctx.add(name)

    for sparse_init in graph.sparse_initializer:  # pragma: no cover
        values = sparse_init.values
        _enforce_has_field(values, "name")
        name = values.name
        if not name:
            raise OnnxCheckError(  # pragma: no cover
                "Sparse tensor initializers must have a non-empty name.",
                graph)
        if name in initializer_name_checker:
            raise OnnxCheckError(  # pragma: no cover
                f"'{name}' initializer name is not unique across "
                f"initializers and sparse_initializers.",
                graph)
        initializer_name_checker.add(name)
        _check_sparse_tensor(sparse_init, ctx)
        lex_ctx.add(name)

    errors = []
    for node in graph.node:
        # Nodes must be in topologically sorted order.
        for input in node.input:
            # explicit optional input
            if not input:
                continue  # pragma: no cover
            if not lex_ctx.this_or_ancestor_graph_has(input):
                raise OnnxCheckError(  # pragma: no cover
                    f"Nodes in a graph must be topologically sorted, however "
                    f"input '{input}' of node name '{node.name}', type "
                    f"'{node.op_type}' is not an output of any previous node.",
                    node)

        # This needs to happen before the SSA check since we don't want to
        # recurse and find that outputs from control flow ops are colliding
        # with names in the inner block.

        try:
            _check_node(node, ctx, lex_ctx)
        except OnnxCheckError as e:
            errors.append(e)

        # Check for SSA form.
        for output in node.output:
            # optional output
            if not output:
                continue

            if lex_ctx.this_or_ancestor_graph_has(output):
                raise OnnxCheckError(  # pragma: no cover
                    f"Graph must be in single static assignment "
                    f"(SSA) form, however '{output}' "
                    f"has been used as an output name multiple times.",
                    graph)
            lex_ctx.add(output)

    if errors:
        # Node-level errors were collected while walking the graph;
        # re-raise the first one.
        raise errors[0]

def _get_version_for_domain(domain, opset_imports):  # pragma: no cover
    # Utility function to get the imported version of a domain from
    # opset imports. Returns -1 if the requested domain is not found
    # in opset_imports.
    if domain not in opset_imports:
        return -1
    return opset_imports[domain]


def _check_opset_compatibility(  # pragma: no cover
        node, ctx, func_opset_imports, model_opset_imports):
    func_opset_version = _get_version_for_domain(
        node.domain, func_opset_imports)
    model_opset_version = _get_version_for_domain(
        node.domain, model_opset_imports)

    if func_opset_version == -1:
        raise OnnxCheckError(  # pragma: no cover
            f"No Opset registered for domain '{node.domain}'.",
            node)

    if model_opset_version == -1:
        # The model does not include an opset import for a node present in
        # the function body. This is ok as long as the opset import is
        # present in the function level opset imports.
        return

    if func_opset_version == model_opset_version:
        # Both versions are the same, no need to verify the schema.
        return

    schema_for_model_import = ctx.get_schema_registry().GetSchema(
        node.op_type, model_opset_version, node.domain)
    schema_for_function_import = ctx.get_schema_registry().GetSchema(
        node.op_type, func_opset_version, node.domain)

    if (isinstance(schema_for_model_import, UndefinedSchema) and
            isinstance(schema_for_function_import, UndefinedSchema)):
        # The op belongs to a custom domain so we cannot verify the schema.
        return

    # If the schema is present for one version but not the other, or if
    # the schema since-versions do not match, then raise an error.
    if (isinstance(schema_for_model_import, UndefinedSchema) or
            isinstance(schema_for_function_import, UndefinedSchema) or
            schema_for_function_import.since_version_ !=
            schema_for_model_import.since_version_):
        raise OnnxCheckError(  # pragma: no cover
            f"Opset import for domain '{node.domain}' in function op "
            f"'{node.op_type}' is not compatible with the version "
            f"imported by the model. FunctionOp imports version "
            f"{func_opset_version} whereas the model imports version "
            f"{model_opset_version}.",
            node)


def _check_model_local_functions(model, ctx, parent_lex):  # pragma: no cover
    # Make a copy of the model opset imports to maintain a main copy of
    # opset imports across the model and all model local functions to
    # verify opset compatibility.
    model_opset_imports = ctx.get_opset_imports().copy()

    # Merge the opset imports from every function into model_opset_imports.
    # Only add an opset import if an entry for it does not exist in
    # model_opset_imports; if there is an entry, the compatibility will be
    # checked later on in _check_opset_compatibility, called by
    # _check_function.
    for function_proto in model.functions:
        for opset_import in function_proto.opset_import:
            if _get_version_for_domain(
                    opset_import.domain, model_opset_imports) == -1:
                model_opset_imports[opset_import.domain] = opset_import.version

    ctx_copy = CheckerContext(ctx)
    ctx_copy.set_opset_imports(model_opset_imports)

    for function_proto in model.functions:
        _check_function(function_proto, ctx_copy, parent_lex)


def _check_function(function, ctx, parent_lex):  # pragma: no cover
    _enforce_non_empty_field(function, "name")

    if ctx.get_ir_version() >= 0x00000008:
        _enforce_has_field(function, "domain")

    model_opset_imports = ctx.get_opset_imports()
    ctx_copy = CheckerContext(ctx)

    func_opset_imports = {}
    for relied_opset in function.opset_import:
        func_opset_imports[relied_opset.domain] = int(relied_opset.version)

    ctx_copy.set_opset_imports(func_opset_imports)

    lex_ctx = LexicalScopeContext(parent_lex)

    for input in function.input:
        # TODO: If shadowing isn't allowed, this should maybe use
        # this_or_ancestor_graph_has
        if lex_ctx.this_graph_has(input):
            raise OnnxCheckError(  # pragma: no cover
                f"Graph must be in single static assignment (SSA) form, "
                f"however '{input}' has been used multiple times.",
                function)
        lex_ctx.add(input)

    outputs = set()
    for output in function.output:
        if output in outputs:
            raise OnnxCheckError(  # pragma: no cover
                f"Function '{function.name}' should not have "
                f"duplicate outputs specified.",
                function)
        outputs.add(output)

    attrs = set()
    for attr in function.attribute:
        if attr in attrs:
            raise OnnxCheckError(  # pragma: no cover
                f"Function '{function.name}' should not have "
                f"duplicate attributes specified.",
                function)
        attrs.add(attr)

    for node in function.node:
        # Nodes must be in topologically sorted order.
        for input in node.input:
            # explicit optional input
            if not input:
                continue
            if not lex_ctx.this_graph_has(input):
                raise OnnxCheckError(  # pragma: no cover
                    f"Nodes in a function must be topologically sorted, "
                    f"however input '{input}' of node name '{node.name}' "
                    f"and type '{node.op_type}' is neither an output "
                    f"of any previous node nor an input of the function.",
                    function)

        # Check whether the opset versions imported for a domain by the
        # function and by the model are compatible.
        _check_opset_compatibility(
            node, ctx_copy, func_opset_imports, model_opset_imports)
        _check_node(node, ctx_copy, lex_ctx)

        # Check for SSA form.
        for output in node.output:
            # optional output
            if not output:
                continue

            if lex_ctx.this_or_ancestor_graph_has(output):
                raise OnnxCheckError(  # pragma: no cover
                    f"Function must be in single static assignment (SSA) "
                    f"form, however '{output}' has been used as an output "
                    f"name multiple times.",
                    function)
            lex_ctx.add(output)


def _check_model(model, ctx):
    if not model.ir_version:
        raise OnnxCheckError(  # pragma: no cover
            "The model does not have an ir_version set properly.",
            model)
    if model.ir_version > IR_VERSION:
        raise OnnxCheckError(  # pragma: no cover
            "Your model ir_version is higher than the checker's.",
            model)
    if len(model.metadata_props) > 1:  # pragma: no cover
        keys = set()
        for entry in model.metadata_props:
            if entry.key in keys:
                raise OnnxCheckError(  # pragma: no cover
                    f"Your model has duplicate keys '{entry.key}' "
                    f"in metadata_props.", model)
            keys.add(entry.key)

    ctx.set_ir_version(int(model.ir_version))
    opset_imports = {}
    for opset_import in model.opset_import:
        opset_imports[opset_import.domain] = int(opset_import.version)
    if model.ir_version >= 3:
        if not opset_imports:
            raise OnnxCheckError(  # pragma: no cover
                "Model with IR version >= 3 must specify opset_import for "
                "ONNX.",
                model)
    elif not opset_imports:  # pragma: no cover
        opset_imports[ONNX_DOMAIN] = 1
    else:
        raise OnnxCheckError(  # pragma: no cover
            "Model with IR version < 3 cannot have opset_import specified.",
            model)

    ctx.set_opset_imports(opset_imports)
    lex_ctx = LexicalScopeContext()
    _check_graph(model.graph, ctx, lex_ctx)

    if ctx.get_ir_version() >= 0x00000008:
        _check_model_local_functions(model, ctx, lex_ctx)

def check_model(model):
    """
    Checks a model is consistent with the ONNX language.
    The function fails if the model is not consistent.

    :param model: :epkg:`ModelProto`
    """
    ctx = CheckerContext()
    if isinstance(model, bytes):
        m = ModelProto()
        m.ParseFromString(model)
        _check_model(m, ctx)
    else:
        _check_model(model, ctx)
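
# Minimal usage sketch (the file name is illustrative):
#
#     import onnx
#     model = onnx.load("model.onnx")
#     check_model(model)                      # raises OnnxCheckError on failure
#     check_model(model.SerializeToString())  # serialized bytes are parsed first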


experimental_ops = {
    "ATen",
    "Affine",
    "ConstantFill",
    "Crop",
    "DynamicSlice",
    "GRUUnit",
    "GivenTensorFill",
    "ImageScaler",
    "ParametricSoftplus",
    "Scale",
    "ScaledTanh"}


def check_is_experimental_op(node_op_type):
    "Tells if an operator is experimental."
    return node_op_type in experimental_ops