Coverage for mlprodict/onnx_tools/exports/tf2onnx_helper.py: 95%
309 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Helpers to run examples created with function
4@see fn export2tf2onnx.
5"""
6import collections
7import inspect
8import numpy
9from onnx.numpy_helper import from_array
10from onnx.helper import (
11 make_node, make_graph, make_model, set_model_props, make_tensor)
12from onnx import AttributeProto
13from ..onnx2py_helper import guess_dtype, guess_proto_dtype
14from ..onnx_tools import ensure_topological_order
# Module-level counter consumed by make_name(); incremented on every call so
# that generated node names stay unique within the process lifetime.
_make_name_id = 0
def make_tf2onnx_code(opset, name=None, op_type=None, domain='',
                      inputs=None, outputs=None, attributes=None,
                      used=None, context=None, mark_inits=None, indent=8,
                      **unused):
    """
    Converts an ONNX operators into :epkg:`tf2onnx` code.

    The function emits Python source (as a string) which, when executed in a
    tf2onnx conversion context (``ctx``, ``varx``, ``GraphBuilder``,
    ``make_name`` in scope), recreates the given node.  ``Unsqueeze``,
    ``Squeeze`` and ``Slice`` get special handling because tf2onnx uses a
    :class:`GraphBuilder` helper for them; every other operator goes through
    ``ctx.make_node``.

    :param opset: target opset for the conversion (usually unused)
    :param name: node name
    :param op_type: operator type
    :param domain: domain
    :param inputs: inputs
    :param outputs: outputs
    :param attributes: attributes
    :param used: dictionary `{k: v}`,
        list of nodes taking *k* as input
    :param context: whole context; ``context['initializers_dict']`` is
        expected to map initializer names to numpy arrays
    :param mark_inits: marks initializer as replaced (mutated in place:
        inlined initializers are appended under their name)
    :param indent: number of spaces to add on the second
        and following rows
    :return: code as str
    """
    def simplify(name, kind, force=False):
        # Tries to inline a small int64 initializer as a Python literal in the
        # generated code instead of referencing it through ``varx``.  Only done
        # when *name* is consumed by exactly one node (safe to fold) and the
        # constant is small (< 10 elements).
        value = None
        if (used is not None and name in used and
                len(used[name]) == 1 and context is not None):
            inits = context['initializers_dict']
            if name in inits:
                v = inits[name]
                if v.dtype == numpy.int64 and v.size < 10:
                    value = v
                    # Record that this initializer was folded into the code so
                    # the caller can drop it from the exported graph.
                    if name not in mark_inits:
                        mark_inits[name] = []
                    mark_inits[name].append(v)

        if value is None and force:
            # Caller insists on a literal value; the initializer must exist.
            inits = context['initializers_dict']
            if name not in inits:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find init %r in %r value=%r." % (
                        name, list(sorted(inits)), value))
            value = inits[name]
        if kind == 'list':  # pragma: no cover
            if value is None:
                return name
            if len(value.shape) == 0:
                return str(value)
            return str(list(value))
        if kind == 'list_var':
            # Either a reference into the variable dictionary ``varx`` or, if
            # the constant could be folded, its literal representation.
            if value is None:
                return f"varx[{name!r}]"
            if len(value.shape) == 0:
                return str(value)
            return str(list(value))
        raise NotImplementedError(  # pragma: no cover
            f"Unknown scenario to simplify ({kind!r}).")

    rows = []
    if op_type == 'Unsqueeze':
        # Opset >= 13 signature: axes is a second input, handled by
        # GraphBuilder.make_unsqueeze in the generated code.
        if len(inputs) == 2:
            rows.append(
                "node = GraphBuilder(ctx).make_unsqueeze("
                "{'data': varx[%r], 'axes': %s}, return_node=True)"
                "" % (inputs[0], simplify(inputs[1], 'list_var')))
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unable to create code for operator {op_type!r} (opset <= 12).")
    elif op_type == 'Squeeze':
        if len(inputs) == 1:
            rows.append(  # pragma: no cover
                "node = GraphBuilder(ctx).make_squeeze("
                "{'data': varx[%r]}, return_node=True)"
                "" % (inputs[0], ))
        elif len(inputs) == 2:
            rows.append(
                "node = GraphBuilder(ctx).make_squeeze("
                "{'data': varx[%r], 'axes': %s}, return_node=True)"
                "" % (inputs[0], simplify(inputs[1], 'list_var')))
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unable to create code for operator {op_type!r} (opset <= 12).")
    elif op_type == 'Slice':
        # Opset >= 10 Slice: starts/ends/axes/steps are inputs, mapped here to
        # the keyword dictionary expected by GraphBuilder.make_slice.
        atts = dict(zip(['starts', 'ends', 'axes', 'steps'],
                        inputs[1:]))
        text = ", ".join(f"'{k}': {simplify(v, 'list_var')}"
                         for k, v in atts.items())
        if len(inputs) in (3, 4, 5):
            rows.append(
                "node = GraphBuilder(ctx).make_slice("
                "{'data': varx[%r], %s}, return_node=True)"
                "" % (inputs[0], text))
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unable to create code for operator {op_type!r} (opset <= 12).")
    else:
        # Generic operator: emit a plain ctx.make_node call.
        if len(attributes) > 0:
            attributes_str = ", ".join(f"{k}={v}" for k, v in attributes)
            attr = f", attr=dict({attributes_str})"
        else:
            attr = ""
        rows.append(
            f"inputs = [{', '.join('varx[%r]' % n for n in inputs)}]")
        sdomain = '' if domain == '' else (f"domain={domain!r}, ")
        rows.append(
            "node = ctx.make_node(%r, inputs=inputs%s, %s"
            "name=make_name(%r))" % (
                op_type, attr, sdomain, name))
    # Bind every node output back into the variable dictionary.
    for i, n in enumerate(outputs):
        rows.append("varx[%r] = node.output[%d]" % (n, i))
    if indent > 0:
        # Indent every row but the first so the snippet can be inserted inside
        # an already-indented function body.
        sind = " " * indent
        for i in range(1, len(rows)):
            rows[i] = sind + rows[i]
    return "\n".join(rows)
def make_name(name):
    "Creates a unique name."
    global _make_name_id  # pylint: disable=W0603
    # Snapshot the counter, bump it, then build the suffixed name.
    suffix = _make_name_id
    _make_name_id = suffix + 1
    return "%s_%d" % (name, suffix)
def get_max_value(np_dtype):
    "Returns the maximum value for a specific type."
    info = numpy.iinfo(np_dtype)
    return info.max
def make_sure(cond, msg, *args):
    "Raises an exception if cond is not verified."
    if cond:
        return
    raise RuntimeError(msg % tuple(args))  # pragma: no cover
def map_onnx_to_numpy_type(onnx_dtype):
    "Converts ONNX type into numpy type."
    # None defaults to float32, otherwise delegate to guess_dtype.
    return numpy.float32 if onnx_dtype is None else guess_dtype(onnx_dtype)
class tf_op:
    """
    Decorator to register any new converter.

    :param name: type of the operator to rewrite
    :param domain: domain
    """
    # domain -> list indexed by opset version -> {op_type: (handler, kwargs)}
    _OPSETS = collections.OrderedDict()

    def __init__(self, name, domain='', **kwargs):
        # Accept a single operator name or a list of them.
        self.names = name if isinstance(name, list) else [name]
        self.domain = domain
        self.kwargs = kwargs

    def __call__(self, func):
        # Register every method of *func* following the 'version_<n>'
        # naming convention as the converter for opset <n>.
        for member_name, member in inspect.getmembers(func, inspect.ismethod):
            if not member_name.startswith("version_"):
                continue
            version = int(member_name.replace("version_", ""))
            self._register_handler(
                member, version, self.names, self.domain, self.kwargs)
        return func

    def _register_handler(self, func, version, names, domain, kwargs):
        opset = tf_op._OPSETS.get(domain)
        if not opset:
            opset = []
            tf_op._OPSETS[domain] = opset
        # Grow the per-version list until index `version` exists.
        while len(opset) <= version:
            opset.append({})
        registry = opset[version]
        for name in names:
            registry[name] = (func, kwargs)
class Tf2OnnxConvert:
    """
    Applies the converter on an ONNX graph.

    :param onnx_model: ONNX graph
    :param tf_op: class which register
    :param verbose: verbosity
    :param target_opset: targetted opsets
    """

    def __init__(self, onnx_model, _tf_op=None, verbose=None,
                 target_opset=None, max_iter=5):
        self._onnx_model = onnx_model
        self._tf_op = _tf_op or tf_op
        self.verbose = verbose
        self.max_iter = max_iter
        if isinstance(target_opset, int):
            self.target_opsets = {'': target_opset}  # pragma: no cover
        elif isinstance(target_opset, dict):
            self.target_opsets = target_opset
        elif target_opset is None:  # pragma: no cover
            # Fall back on the opsets declared by the model itself.
            opsets = {}
            for oimp in onnx_model.opset_import:
                if oimp.domain == '':
                    opsets[oimp.domain] = oimp.version
                    opset = oimp.version
                else:
                    opsets[oimp.domain] = opset
            self.target_opsets = opsets
        else:
            raise ValueError(  # pragma: no cover
                f"Unexepected value for target_opset={target_opset!r}.")
        # _names maps a node or initializer name to its proto object.
        self._names = {}
        for node in onnx_model.graph.node:
            self._names[node.name] = node
        for init in onnx_model.graph.initializer:
            self._names[init.name] = init
        # _forbidden_new_names contains current names and deleted names.
        self._forbidden_new_names = set(self._names)
        if '' in self.target_opsets:
            self.opset = self.target_opsets['']
        if not hasattr(self, 'opset'):
            raise RuntimeError(  # pragma: no cover
                f"Attribute opset is missing, target_opset={target_opset!r}.")

    def get_node_by_name(self, name):  # pragma: no cover
        """
        Retrieves a node by its name.

        :param name: node name
        :return: node proto
        :raises RuntimeError: if the name is unknown
        """
        if name not in self._names:
            raise RuntimeError(
                "Unable to find node name %r among %r." % (
                    name, ", ".join(sorted(self._names))))
        return self._names[name]

    def _add_node_name(self, obj):
        """
        Registers an object in the graph by its name.

        :param obj: node or initializer (anything with a ``name`` attribute)
        :raises RuntimeError: if the name was already used (including deleted
            names, to avoid accidental reuse)
        """
        if obj.name in self._forbidden_new_names:
            raise RuntimeError(  # pragma: no cover
                f"Name {obj.name!r} is already registered.")
        self._names[obj.name] = obj
        self._forbidden_new_names.add(obj.name)

    def make_node(self, op_type, inputs, attr=None, outputs=None,
                  name=None, domain='', output_count=1,
                  shapes=None, dtypes=None):
        """
        Adds a node to the list of nodes.

        :param op_type: operator type
        :param inputs: list of strings
        :param attr: dictionary of attributes
        :param outputs: None or list of strings
        :param output_count: used if outputs is None to guess
            the number of outputs of this node
        :param name: name of the node
        :param domain: domain
        :param shapes: unused
        :param dtypes: unused
        :return: created node
        """
        if self.verbose:
            print(  # pragma: no cover
                f"[Tf2OnnxConvert.make_node] op_type={op_type!r} inputs={inputs!r}")

        if attr is None:
            attr = {}
        if name is None:
            name = make_name(op_type)
        if name in self._names:
            raise RuntimeError(  # pragma: no cover
                "Node name %r already exists in %r." % (
                    name, ", ".join(sorted(self._names))))

        if outputs is None:
            # Default output names follow the tf2onnx convention "<name>:<i>".
            outputs = [(name + ":" + str(i)) for i in range(output_count)]

        output_count = len(outputs)
        # Split attributes between already-built AttributeProto objects and
        # raw Python values forwarded to onnx.helper.make_node.
        raw_attr = {}
        onnx_attrs = []
        for a, v in attr.items():
            if isinstance(v, AttributeProto):
                onnx_attrs.append(v)  # pragma: no cover
            else:
                raw_attr[a] = v

        onnx_node = make_node(
            op_type, inputs, outputs, name=name, domain=domain, **raw_attr)

        self._add_node_name(onnx_node)
        return onnx_node

    def make_const(self, name, np_val, skip_conversion=False, raw=True):
        """
        Make a new constants in the graph.

        :param name: const node name, must be unique.
        :param np_val: value of type numpy ndarray.
        :param skip_conversion:
            bool, indicate whether this created node would be mapped
            during conversion
        :param raw: whether to store data at field of raw_data or the
            specific field according to its dtype
        :return: create initializer
        """
        if name in self._names:
            raise RuntimeError(  # pragma: no cover
                "Initializer name %r already exists in %r." % (
                    name, ", ".join(sorted(self._names))))
        np_val_flat = np_val.flatten()
        # ``numpy.object`` was deprecated in numpy 1.20 and removed in 1.24;
        # comparing the dtype with the builtin ``object`` is the portable
        # equivalent and keeps the original semantics.
        is_bytes = (np_val.dtype == object and len(np_val_flat) > 0 and
                    isinstance(np_val_flat[0], bytes))
        if raw and not is_bytes:
            onnx_tensor = from_array(np_val, name)
        else:  # pragma: no cover
            onnx_tensor = make_tensor(
                name, guess_proto_dtype(np_val.dtype),
                np_val.shape, np_val_flat, raw=False)

        self._add_node_name(onnx_tensor)
        return onnx_tensor

    def get_dtype(self, input_name):
        """
        Returns the type of one graph input or None if unknown.

        :param input_name: result name
        :return: ONNX element type (``tensor_type.elem_type``) or None
        """
        inputs = self._onnx_model.graph.input
        names = [_.name for _ in inputs]
        if input_name not in names:
            return None  # pragma: no cover
        ind = names.index(input_name)
        return inputs[ind].type.tensor_type.elem_type

    def replace_all_inputs(self, old_name, new_name):
        """
        Every taking *old_name* as inputs will take *new_name* instead.
        Looks in the output as well but in that case, it creates an identity
        node to avoid changing an output name.

        :param old_name: name to replace
        :param new_name: new name
        :return: list of impacted nodes
        """
        if self.verbose:
            print(  # pragma: no cover
                "[Tf2OnnxConvert.replace_all_inputs] replace %r by %r" % (
                    old_name, new_name))
        res = []
        for node in self._names.values():
            # Initializers have no 'input' attribute and are skipped.
            if not hasattr(node, 'input'):
                continue
            if old_name not in node.input:
                continue
            new_inputs = [  # pragma: no cover
                new_name if i == old_name else i for i in node.input]
            node.input[:] = new_inputs[:]  # pragma: no cover
            res.append(node)  # pragma: no cover
            if self.verbose:  # pragma: no cover
                print(
                    "[Tf2OnnxConvert.replace_all_inputs] replace %r by %r in node %r" % (
                        old_name, new_name, node.name))
        for o in self._onnx_model.graph.output:
            if o.name != old_name:
                continue  # pragma: no cover
            # A graph output name cannot be renamed; bridge with Identity.
            n = self.make_node("Identity", [new_name], outputs=[old_name],
                               name=make_name("IdOutputReplaced"))
            res.append(n)
            if self.verbose:
                print(  # pragma: no cover
                    "[Tf2OnnxConvert.replace_all_inputs] add id node from %r to %r "
                    "with node %r." % (
                        old_name, new_name, n.name))  # pylint: disable=E1101
        if self.verbose:
            print(  # pragma: no cover
                "[Tf2OnnxConvert.replace_all_inputs] end")
        return res

    def remove_node(self, name):
        """
        Removes a node name from the list.

        :param name: registered node name
        :raises RuntimeError: if the name is unknown
        """
        if name not in self._names:
            raise RuntimeError(  # pragma: no cover
                f"Unable to delete name {name!r} because it does not exists.")
        del self._names[name]
        if self.verbose:
            print(  # pragma: no cover
                f"[Tf2OnnxConvert.remove_node] delete name {name!r}")

    def get_shape(self, input_name):
        """
        Returns the shape of one graph input or None if unknown.

        :param input_name: result name
        :return: tuple of dimension protos (``tensor_type.shape.dim``) or None
        """
        inputs = self._onnx_model.graph.input
        names = [_.name for _ in inputs]
        if input_name not in names:
            return None  # pragma: no cover
        ind = names.index(input_name)
        dims = inputs[ind].type.tensor_type.shape.dim
        return tuple(dims)

    def run(self):
        """
        Calls the registered converters on the graph
        held by this instance. Returns the new onnx graph.

        :return: ONNX graph
        :raises RuntimeError: if no converter is registered or if the graph
            does not stabilize within *max_iter* passes
        """
        if len(self._tf_op._OPSETS) == 0:
            raise RuntimeError(  # pragma: no cover
                "No converter was registered.")
        if self.verbose:
            print("[Tf2OnnxConvert.run]")  # pragma: no cover

        done = {}
        modif = 1
        turn = 0
        # Iterate until a full pass makes no modification (fixed point) or the
        # iteration budget is exhausted.
        while modif > 0 and turn < self.max_iter:
            modif = 0
            turn += 1
            # The converter may alter the current list of nodes, we freeze it.
            current_values = list(self._names.values())
            for node in current_values:
                if not hasattr(node, 'domain'):
                    # initializer
                    continue
                if done.get(node.name, False):
                    continue  # pragma: no cover
                domain = node.domain
                if domain not in self._tf_op._OPSETS:
                    continue  # pragma: no cover

                # look for a converter
                rews = self._tf_op._OPSETS[domain]
                target = min(self.target_opsets[domain], len(rews))
                conv = None
                # NOTE(review): the search starts at the highest registered
                # version, not at `target`, so a converter registered above the
                # target opset could be picked — confirm this is intended.
                for i in range(len(rews) - 1, -1, -1):
                    if node.op_type in rews[i]:
                        conv = rews[i][node.op_type]
                        break
                if conv is None:
                    continue

                # applies the converter
                if self.verbose:
                    print(  # pragma: no cover
                        "[Tf2OnnxConvert.run] convert node type=%r opset=%r name=%r"
                        "" % (node.op_type, target, node.name))
                fct, kwargs = conv
                fct(self, node, target_opset=target, **kwargs)
                modif += 1

        if turn >= self.max_iter:
            raise RuntimeError(  # pragma: no cover
                "Too many iterations and no stable ONNX was reached, "
                "iter=%d\n%s" % (turn, str(self.make_model())))
        return self.make_model()

    def make_model(self):
        """
        Produces the new ONNX graph with the updated sets of nodes.

        :return: ONNX model proto carrying over metadata, opsets and
            producer information from the original model
        """
        inputs = self._onnx_model.graph.input
        outputs = self._onnx_model.graph.output
        # Initializers have no 'domain' attribute, nodes do.
        inits = [init[1] for init in sorted(self._names.items())
                 if not hasattr(init[1], 'domain')]
        nodes = [node[1] for node in sorted(self._names.items())
                 if hasattr(node[1], 'domain')]
        nodes = ensure_topological_order(inputs, inits, nodes)

        if self.verbose:
            print(  # pragma: no cover
                "[Tf2OnnxConvert.make_node] %d nodes %d inputs %d "
                "outputs %d initializers"
                "" % (len(nodes), len(inputs), len(outputs), len(inits)))
        graph = make_graph(nodes, self._onnx_model.graph.name,
                           inputs, outputs, inits)
        onnx_model = make_model(graph, functions=self._onnx_model.functions)
        onnx_model.ir_version = self._onnx_model.ir_version
        onnx_model.producer_name = self._onnx_model.producer_name + "-mlprodict"
        onnx_model.producer_version = self._onnx_model.producer_version
        onnx_model.domain = self._onnx_model.domain
        onnx_model.model_version = self._onnx_model.model_version
        onnx_model.doc_string = self._onnx_model.doc_string
        metadata = {p.key: p.value for p in self._onnx_model.metadata_props}
        set_model_props(onnx_model, metadata)

        # opsets
        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for dom, value in self.target_opsets.items():
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = dom
            op_set.version = value
        return onnx_model
class GraphBuilder:
    """
    Helpers to build graph.

    :param graph: graph wrapper; expected to expose ``opset``, ``make_node``,
        ``make_const``, ``get_dtype`` and ``get_node_by_output`` (the same
        interface as the converter class in this module)
    """

    def __init__(self, graph):
        self._g = graph

    @property
    def graph(self):
        "Returns the graph."
        return self._g

    def make_slice(self, kwargs, name=None, shapes=None, dtypes=None,
                   return_node=False):
        """
        slice changes its schema at opset 10: it treats some
        attributes as dynamic input so this function has to process
        inputs according to graph's opset version
        to get "inputs" and "attr" to feed "make_node"
        kwargs: key could be `["data", "starts", "ends",
        "axes", "steps", "outputs"]`.
        """
        outputs = kwargs.pop("outputs", None)

        if self.graph.opset < 10:
            # "data" is string
            # "starts", "ends" and "axes" are attributes,
            # and "axes" is optional.
            data = kwargs.pop("data")  # pragma: no cover
            starts = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("starts"))
            ends = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("ends"))
            axes = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("axes", None), is_optional=True)
            attr = {"starts": starts, "ends": ends,
                    "axes": axes}  # pragma: no cover
            inputs = [data]  # pragma: no cover
        else:
            # slice-10 has 3 required inputs "data", "starts", "ends"
            # and 2 optional inputs "axes", "steps"
            # input sequence should be "data", "starts", "ends",
            # "axes", "steps"
            attr = {}
            data = kwargs.pop("data")
            starts = self._convert_to_input(
                kwargs.pop("starts"), "const_starts", dtype=numpy.int64)
            ends = self._convert_to_input(
                kwargs.pop("ends"), "const_ends", dtype=numpy.int64)
            axes = self._convert_to_input(
                kwargs.pop("axes", None), "const_axes",
                is_optional=True, dtype=numpy.int64)
            steps = self._convert_to_input(
                kwargs.pop("steps", None), "const_steps",
                is_optional=True, dtype=numpy.int64)
            inputs = [data, starts, ends, axes, steps]

        # pro-process inputs and attr
        # Any key left in kwargs at this point was not consumed above.
        make_sure(not kwargs, "kwargs contains un-used key")

        # Drop attributes whose value is None (unset optional attributes).
        new_attr = {}
        for key, val in attr.items():
            if val is not None:  # pragma: no cover
                new_attr[key] = val
        attr = new_attr

        for ind, val in enumerate(inputs):
            if val is None:
                inputs[ind] = ""  # empty string means no connection in ONNX
        # remove tailing ""
        while inputs[-1] == "":
            inputs = inputs[:-1]

        if self.graph.opset >= 10:
            # starts/ends/axes/steps must all share one dtype (int64 here).
            dtype = self.graph.get_dtype(inputs[1])
            for input_data in inputs[1:]:
                if input_data != "":
                    make_sure(dtype == self.graph.get_dtype(
                        input_data), "dtype should be same")

        node = self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr,
                                    name=name, outputs=outputs, shapes=shapes,
                                    dtypes=dtypes)
        if return_node:
            return node
        raise NotImplementedError(  # pragma: no cover
            "return_node must be True")

    def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None,
                     return_node=False, op_name_scope=None):
        """
        Squeeze changes its schema at opset 13: it treats axes as a dynamic input
        kwargs: key could be ["data", "axes"].
        """
        outputs = kwargs.pop("outputs", None)

        if self.graph.opset < 13:  # pragma: no cover
            # Before opset 13, "axes" is an attribute.
            data = kwargs.pop("data")
            axes = self._convert_to_attribute(
                kwargs.pop("axes", None), is_optional=True)
            attr = {"axes": axes}
            inputs = [data]
        else:
            # From opset 13 on, "axes" becomes a second (optional) input.
            data = kwargs.pop("data")
            axes = self._convert_to_input(
                kwargs.pop("axes", None), "const_axes",
                is_optional=True, dtype=numpy.int64)
            attr = {}
            inputs = [data, axes]

        make_sure(not kwargs, "kwargs contains un-used key")

        # Drop attributes whose value is None (unset optional attributes).
        new_attr = {}
        for key, val in attr.items():
            if val is not None:  # pragma: no cover
                new_attr[key] = val
        attr = new_attr

        for ind, val in enumerate(inputs):
            if val is None:  # pragma: no cover
                inputs[ind] = ""  # empty string means no connection in ONNX
        # remove tailing ""
        while inputs[-1] == "":
            inputs = inputs[:-1]  # pragma: no cover

        node = self.graph.make_node(
            op_type="Squeeze", inputs=inputs, attr=attr, name=name,
            outputs=outputs)
        if return_node:
            return node
        raise NotImplementedError(  # pragma: no cover
            "return_node must be True")

    def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None,
                       return_node=False, op_name_scope=None):
        """
        Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input
        kwargs: key could be ["data", "axes"].
        """
        outputs = kwargs.pop("outputs", None)

        if self.graph.opset < 13:
            # Before opset 13, "axes" is an attribute.
            data = kwargs.pop("data")  # pragma: no cover
            axes = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("axes", None), is_optional=True)
            attr = {"axes": axes}  # pragma: no cover
            inputs = [data]  # pragma: no cover
        else:
            # From opset 13 on, "axes" becomes a second (optional) input.
            data = kwargs.pop("data")
            axes = self._convert_to_input(
                kwargs.pop("axes", None), "const_axes",
                is_optional=True, dtype=numpy.int64)
            attr = {}
            inputs = [data, axes]

        make_sure(not kwargs, "kwargs contains un-used key")

        # Drop attributes whose value is None (unset optional attributes).
        new_attr = {}
        for key, val in attr.items():
            if val is not None:  # pragma: no cover
                new_attr[key] = val
        attr = new_attr

        for ind, val in enumerate(inputs):
            if val is None:  # pragma: no cover
                inputs[ind] = ""  # empty string means no connection in ONNX
        # remove tailing ""
        while inputs[-1] == "":
            inputs = inputs[:-1]  # pragma: no cover

        node = self.graph.make_node(
            op_type="Unsqueeze", inputs=inputs, attr=attr, name=name,
            outputs=outputs)
        if return_node:
            return node
        raise NotImplementedError(  # pragma: no cover
            "return_node must be True")

    def _convert_to_input(self, tensor, const_name,  # pragma: no cover
                          is_optional=False, dtype=None):
        """in ONNX, input shold come from node, so it must be a string"""
        # A list is materialized as a new constant in the graph; a string is
        # assumed to already name an existing result.
        if is_optional and tensor is None:
            return None

        make_sure(tensor is not None,
                  "input is required so it couldn't be None")

        res = tensor
        if isinstance(tensor, list):
            res = self.graph.make_const(
                make_name(const_name), numpy.array(tensor, dtype)).name
        return res

    def _convert_to_attribute(self, tensor, is_optional=False):
        # Converts *tensor* into a plain list usable as a node attribute; a
        # string is resolved through the graph to the constant it names.
        if is_optional and tensor is None:
            return None

        make_sure(tensor is not None,
                  "input is required so it couldn't be None")

        res = tensor
        if isinstance(tensor, str):
            const_node = self.graph.get_node_by_output(tensor)
            res = const_node.get_tensor_value(as_list=True)

        make_sure(isinstance(res, list),
                  "input is an attr, so a list is needed")

        return res