Coverage for mlprodict/tools/graphs.py: 98%
338 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Alternative to dot to display a graph.
5.. versionadded:: 0.7
6"""
7import pprint
8import hashlib
9import numpy
10import onnx
def make_hash_bytes(data, length=20):
    """
    Hashes *data* with SHA-256 and keeps the beginning
    of the hexadecimal digest.

    :param data: bytes to hash
    :param length: number of hexadecimal characters to keep
    :return: the first *length* characters of the hexadecimal digest
    """
    digest = hashlib.sha256(data).hexdigest()
    return digest[:length]
class AdjacencyGraphDisplay:
    """
    Structure which contains the necessary information to
    display a graph using an adjacency matrix.
    Actions are stored in the order they are added and
    rendered as text by method *to_text*.

    .. versionadded:: 0.7
    """

    class Action:
        "One action to do."

        def __init__(self, x, y, kind, label, orientation=None):
            self.x = x
            self.y = y
            self.kind = kind
            self.label = label
            self.orientation = orientation

        def __repr__(self):
            "usual"
            return "%s(%r, %r, %r, %r, %r)" % (
                self.__class__.__name__,
                self.x, self.y, self.kind, self.label,
                self.orientation)

    def __init__(self):
        self.actions = []

    def __iter__(self):
        "Iterates over actions."
        yield from self.actions

    def __str__(self):
        "usual"
        rows = [f"{self.__class__.__name__}("]
        rows.extend(f" {act!r}" for act in self)
        rows.append(")")
        return "\n".join(rows)

    def add(self, x, y, kind, label, orientation=None):
        """
        Adds an action to display the graph.

        :param x: x coordinate
        :param y: y coordinate
        :param kind: `'cross'` or `'text'`
        :param label: specific to kind, must be a string,
            a `'cross'` label must start with `'I'` or `'O'`
        :param orientation: a 2-uple `(i,j)` where *i* or *j* in `{-1,0,1}`
        :raises ValueError: if *kind* is unknown or if a `'cross'`
            label does not start with `'I'` or `'O'`
        :raises TypeError: if *label* is not a string
        """
        if kind not in {'cross', 'text'}:
            raise ValueError(  # pragma: no cover
                f"Unexpected value for kind {kind!r}.")
        # The type is checked before the label is indexed so that a wrong
        # type is always reported as a TypeError.
        if not isinstance(label, str):
            raise TypeError(  # pragma: no cover
                f"Unexpected label type {type(label)!r}.")
        # label[:1] is '' for an empty label: it is rejected with
        # a ValueError instead of failing with an IndexError.
        if kind == 'cross' and label[:1] not in {'I', 'O'}:
            raise ValueError(  # pragma: no cover
                "kind=='cross' and label[0]=%r not in {'I','O'}." % label[:1])
        self.actions.append(
            AdjacencyGraphDisplay.Action(x, y, kind, label=label,
                                         orientation=orientation))

    def to_text(self):
        """
        Displays the graph as a single string.
        See @see fn onnx2bigraph to see how the result
        looks like.

        :return: str, an empty string if no action was added
        :raises NotImplementedError: if a `'cross'` action has an
            orientation different from `(1, 0)` or a label which
            does not contain one or two characters
        """
        # Sparse grid: (column, row) -> character.
        # Every abscissa is multiplied by 3 to leave room
        # for the two characters of a cross label.
        mat = {}
        for act in self:
            if act.kind == 'cross':
                if act.orientation != (1, 0):
                    raise NotImplementedError(  # pragma: no cover
                        "Orientation for 'cross' must be (1, 0) not %r."
                        "" % act.orientation)
                if len(act.label) not in (1, 2):
                    raise NotImplementedError(
                        f"Unable to display long cross label ({act.label!r}).")
                for shift, c in enumerate(act.label):
                    mat[act.x * 3 + shift, act.y] = c
            elif act.kind == 'text':
                x = act.x * 3
                y = act.y
                orient = act.orientation
                # When the text is written backwards (towards the left
                # or the top), characters are emitted in reverse order
                # so that the label still reads naturally.
                charset = list(act.label if max(orient) == 1
                               else reversed(act.label))
                for c in charset:
                    mat[x, y] = c
                    x += orient[0]
                    y += orient[1]
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected kind value {act.kind!r}.")

        if not mat:
            # Nothing to draw, min() and max() below would fail.
            return ""

        # Shifts all coordinates so that the smallest one is 0.
        min_i = min(k[0] for k in mat)
        min_j = min(k[1] for k in mat)
        shifted = {(k[0] - min_i, k[1] - min_j): v for k, v in mat.items()}

        max_x = max(k[0] for k in shifted)
        max_y = max(k[1] for k in shifted)

        # Dense grid of characters, one row per line of text.
        grid = numpy.full((max_y + 1, max_x + 1), ' ')
        for k, v in shifted.items():
            grid[k[1], k[0]] = v
        return "\n".join(''.join(grid[i]) for i in range(grid.shape[0]))
class BiGraph:
    """
    BiGraph representation. The graph is made of two sets
    of vertices *v0* and *v1* and of edges which link a vertex
    of one set to a vertex of the other set.

    .. versionadded:: 0.7
    """

    class A:
        "Additional information for a vertex or an edge."

        def __init__(self, kind):
            self.kind = kind

        def __repr__(self):
            return f"A({self.kind!r})"

    class B:
        "Additional information for a vertex or an edge."

        def __init__(self, name, content, onnx_name):
            if not isinstance(content, str):
                raise TypeError(  # pragma: no cover
                    f"content must be str not {type(content)!r}.")
            self.name = name
            self.content = content
            self.onnx_name = onnx_name

        def __repr__(self):
            return f"B({self.name!r}, {self.content!r}, {self.onnx_name!r})"

    def __init__(self, v0, v1, edges):
        """
        :param v0: first set of vertices (dictionary)
        :param v1: second set of vertices (dictionary)
        :param edges: edges, dictionary whose keys are pairs
            `(origin, destination)`
        :raises TypeError: if one of the arguments is not a dictionary
        :raises ValueError: if *v0* and *v1* share a vertex or if
            an edge refers to unknown vertices

        The dictionaries are stored as they are (no copy). A vertex
        `A('ERROR')` is added to *v0* for every edge which ends in *v1*
        but starts from an unknown vertex.
        """
        if not isinstance(v0, dict):
            raise TypeError("v0 must be a dictionary.")
        if not isinstance(v1, dict):
            raise TypeError("v1 must be a dictionary.")
        if not isinstance(edges, dict):
            raise TypeError("edges must be a dictionary.")
        self.v0 = v0
        self.v1 = v1
        self.edges = edges
        common = set(self.v0).intersection(self.v1)
        if common:
            raise ValueError(
                f"Sets v0 and v1 have common nodes (forbidden): {common!r}.")
        for a, b in edges:
            if a in v0 and b in v1:
                continue
            if a in v1 and b in v0:
                continue
            if b in v1:
                # One operator is missing one input.
                # We add one.
                self.v0[a] = BiGraph.A('ERROR')
                continue
            raise ValueError(
                f"Edges ({a!r}, {b!r}) not found among the vertices.")

    def __str__(self):
        """
        usual
        """
        return "%s(%d v., %d v., %d edges)" % (
            self.__class__.__name__, len(self.v0),
            len(self.v1), len(self.edges))

    def __iter__(self):
        """
        Iterates over all vertices and edges.
        It produces 3-uples:

        * 0, name, A: vertices in *v0*
        * 1, name, A: vertices in *v1*
        * -1, name, A: edges
        """
        for k, v in self.v0.items():
            yield 0, k, v
        for k, v in self.v1.items():
            yield 1, k, v
        for k, v in self.edges.items():
            yield -1, k, v

    def __getitem__(self, key):
        """
        Returns an edge if *key* is a tuple, a vertex otherwise.
        Vertices are searched in *v0* first, then in *v1*.

        :param key: vertex or edge name
        :return: value
        """
        if isinstance(key, tuple):
            return self.edges[key]
        if key in self.v0:
            return self.v0[key]
        return self.v1[key]

    def order_vertices(self):
        """
        Orders the vertices from the input to the output.
        Every vertex starts at 0 and is moved after the origin
        of every edge which leads to it.

        :return: dictionary `{vertex name: order}`
        :raises RuntimeError: if the orders are still changing
            after more passes than vertices (the graph has a cycle)
        """
        order = dict.fromkeys(self.v0, 0)
        order.update(dict.fromkeys(self.v1, 0))
        modif = 1
        n_iter = 0
        while modif > 0:
            modif = 0
            for a, b in self.edges:
                if order[b] <= order[a]:
                    order[b] = order[a] + 1
                    modif += 1
            n_iter += 1
            if n_iter > len(order):
                # Without a cycle, the orders are stable by now.
                break
        if modif > 0:
            raise RuntimeError(
                f"The graph has a cycle.\n{pprint.pformat(self.edges)}")
        return order

    def adjacency_matrix(self):
        """
        Builds an adjacency matrix. Rows are the vertices of *v0*,
        columns the vertices of *v1*, both sorted by order
        (see *order_vertices*) then by name.

        :return: matrix, list of row vertices, list of column vertices
        """
        order = self.order_vertices()
        row = [k for _, k in sorted(
            (o, k) for k, o in order.items() if k in self.v0)]
        col = [k for _, k in sorted(
            (o, k) for k, o in order.items() if k in self.v1)]
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}
        matrix = numpy.zeros((len(row), len(col)), dtype=numpy.int32)
        for a, b in self.edges:
            # The direction of the edge is lost, only the link is kept.
            if a in row_id:
                matrix[row_id[a], col_id[b]] = 1
            else:
                matrix[row_id[b], col_id[a]] = 1
        return matrix, row, col

    def display_structure(self, grid=5, distance=5):
        """
        Creates a display structure which contains
        all the necessary steps to display a graph.
        Vertices and edges are expected to expose an attribute *kind*
        (class *A*).

        :param grid: align text to this grid
        :param distance: distance to the text
        :return: instance of @see cl AdjacencyGraphDisplay
        """
        def adjust(c, way):
            # Moves coordinate *c* away from the crosses (to the right or
            # the bottom if way is 1, to the left or the top otherwise)
            # and snaps it to a multiple of *grid*.
            if way == 1:
                d = grid * ((c + distance * 2 - (grid // 2 + 1)) // grid)
            else:
                d = -grid * ((-c + distance * 2 - (grid // 2 + 1)) // grid)
            return d

        matrix, row, col = self.adjacency_matrix()
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}

        # Range of columns (resp. rows) holding a cross
        # for every row (resp. column).
        interval_y_min = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_y_max = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_x_min = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_x_max = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_y_min[:] = max(matrix.shape)
        interval_x_min[:] = max(matrix.shape)

        graph = AdjacencyGraphDisplay()
        for key, value in self.edges.items():
            if key[0] in row_id:
                y = row_id[key[0]]
                x = col_id[key[1]]
            else:
                x = col_id[key[0]]
                y = row_id[key[1]]
            graph.add(x, y, 'cross', label=value.kind, orientation=(1, 0))
            if x < interval_y_min[y]:
                interval_y_min[y] = x
            if x > interval_y_max[y]:
                interval_y_max[y] = x
            if y < interval_x_min[x]:
                interval_x_min[x] = y
            if y > interval_x_max[x]:
                interval_x_max[x] = y

        # Rows: kind on the left of the crosses, name on the right.
        for k, v in self.v0.items():
            y = row_id[k]
            x = adjust(interval_y_min[y], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(-1, 0))
            x = adjust(interval_y_max[y], 1)
            graph.add(x, y, 'text', label=k, orientation=(1, 0))

        # Columns: kind above the crosses, name below.
        for k, v in self.v1.items():
            x = col_id[k]
            y = adjust(interval_x_min[x], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(0, -1))
            y = adjust(interval_x_max[x], 1)
            graph.add(x, y, 'text', label=k, orientation=(0, 1))

        return graph

    def order(self):
        """
        Order nodes. Depth first.
        Returns a dictionary `{key: position}` whose keys
        are vertices of *v0* and *v1*.

        :raises RuntimeError: if no vertex is free of predecessors
        """
        # Successors and predecessors of every vertex. Edges are the keys
        # of a dictionary, every pair is unique: one pass is enough.
        forwards = {k: [] for k in self.v0}
        forwards.update((k, []) for k in self.v1)
        backwards = {k: [] for k in forwards}
        for a, b in self.edges:
            forwards[a].append(b)
            backwards[b].append(a)

        # roots
        roots = [b for b, backs in backwards.items() if len(backs) == 0]
        if len(roots) == 0:
            raise RuntimeError(  # pragma: no cover
                "This graph has cycles. Not allowed.")

        # ordering
        # NOTE(review): only the last remaining successor of a node is
        # followed each time the node is visited, and a node reached twice
        # gets a new position. Vertices may be missing from the result.
        # Behaviour kept unchanged, to be confirmed against the intent.
        order = {}
        stack = roots
        while len(stack) > 0:
            node = stack.pop()
            order[node] = len(order)
            w = forwards[node]
            if len(w) == 0:
                continue
            last = w.pop()
            stack.append(last)

        return order

    def summarize(self):
        """
        Creates a text summary of the graph: one line per vertex
        of *v1* returned by method *order*, sorted by position.
        """
        order = self.order()
        keys = sorted((o, k) for k, o in order.items())
        return "\n".join(str(self.v1[k]) for _, k in keys if k in self.v1)

    @staticmethod
    def _onnx2bigraph_basic(model_onnx, recursive=False):
        """
        Implements graph type `'basic'` for function
        @see fn onnx2bigraph.
        Results (inputs, outputs, initializers, intermediate results)
        go to *v0*, nodes go to *v1*, original names are kept.
        """

        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs, outputs, initializers
        for i, o in enumerate(model_onnx.graph.input):
            v0[o.name] = BiGraph.A('Input-%d' % i)
        for i, o in enumerate(model_onnx.graph.output):
            v0[o.name] = BiGraph.A('Output-%d' % i)
        for o in model_onnx.graph.initializer:
            v0[o.name] = BiGraph.A('Init')

        # nodes
        for n in model_onnx.graph.node:
            # Unnamed nodes are identified by the id of the python object.
            nname = n.name if len(n.name) > 0 else "id%d" % id(n)
            v1[nname] = BiGraph.A(n.op_type)
            for i, o in enumerate(n.input):
                # An edge label is limited to two characters.
                c = str(i) if i < 10 else "+"
                edges[o, nname] = BiGraph.A(f'I{c}')
            for i, o in enumerate(n.output):
                c = str(i) if i < 10 else "+"
                if o not in v0:
                    v0[o] = BiGraph.A('inout')
                edges[nname, o] = BiGraph.A(f'O{c}')

        return BiGraph(v0, v1, edges)

    @staticmethod
    def _onnx2bigraph_simplified(model_onnx, recursive=False):
        """
        Implements graph type `'simplified'` for function
        @see fn onnx2bigraph.
        Vertices are renamed after their kind and their position,
        the original name is stored in attribute *onnx_name*.

        :raises KeyError: if a node consumes a result which is
            neither an input, an output, an initializer
            nor the output of a previous node
        """
        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs, outputs, initializers
        for o in model_onnx.graph.input:
            v0[f"I{len(v0)}"] = BiGraph.B(
                'In', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.output:
            v0[f"O{len(v0)}"] = BiGraph.B(
                'Ou', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.initializer:
            v0[f"C{len(v0)}"] = BiGraph.B(
                'Cs', make_hash_bytes(o.raw_data, 10), o.name)

        # original name -> key in v0
        names_v0 = {v.onnx_name: k for k, v in v0.items()}

        for n in model_onnx.graph.node:
            key_node = f"N{len(v1)}"
            if len(n.attribute) > 0:
                # The content of a node is a hash of its attributes.
                ats = [at.SerializeToString() for at in n.attribute]
                ct = make_hash_bytes(b"".join(ats), 10)
            else:
                ct = ""
            v1[key_node] = BiGraph.B(
                n.op_type, ct, n.name)
            for o in n.input:
                key_in = names_v0[o]
                edges[key_in, key_node] = BiGraph.A('I')
            for o in n.output:
                if o not in names_v0:
                    key = f"R{len(v0)}"
                    v0[key] = BiGraph.B('Re', n.op_type, o)
                    names_v0[o] = key
                edges[key_node, key] = BiGraph.A('O')

        return BiGraph(v0, v1, edges)

    @staticmethod
    def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
        """
        Computes a distance between two ONNX graphs. They must not
        be too big otherwise this function might take for ever.
        The function relies on package :epkg:`mlstatpy`.

        :param onx1: first graph (ONNX graph or model file name)
        :param onx2: second graph (ONNX graph or model file name)
        :param verbose: verbosity
        :param fLOG: logging function
        :return: distance and differences

        .. warning::

            This is very experimental and very slow.

        .. versionadded:: 0.7
        """
        from mlstatpy.graph.graph_distance import GraphDistance

        if isinstance(onx1, str):
            onx1 = onnx.load(onx1)
        if isinstance(onx2, str):
            onx2 = onnx.load(onx2)

        def make_hash(init):
            return make_hash_bytes(init.raw_data)

        def build_graph(onx):
            # Vertices are nodes and results, labels are operator types
            # for nodes, `type:position` for node outputs and a hash
            # of the content for initializers.
            edges = []
            labels = {}
            for node in onx.graph.node:
                if len(node.name) == 0:
                    name = str(id(node))
                else:
                    name = node.name
                for i in node.input:
                    edges.append((i, name))
                for p, i in enumerate(node.output):
                    edges.append((name, i))
                    labels[i] = "%s:%d" % (node.op_type, p)
                labels[name] = node.op_type
            for init in onx.graph.initializer:
                labels[init.name] = make_hash(init)

            g = GraphDistance(edges, vertex_label=labels)
            return g

        g1 = build_graph(onx1)
        g2 = build_graph(onx2)

        dist, gdist = g1.distance_matching_graphs_paths(
            g2, verbose=verbose, fLOG=fLOG, use_min=False)
        return dist, gdist
def onnx2bigraph(model_onnx, recursive=False, graph_type='basic'):
    """
    Converts an ONNX graph into a graph representation,
    edges, vertices.

    :param model_onnx: ONNX graph
    :param recursive: dig into subgraphs too
    :param graph_type: kind of graph it creates
    :return: see @cl BiGraph
    :raises ValueError: if *graph_type* is unknown

    About *graph_type*:

    * `'basic'`: basic graph structure, it returns an instance
      of type @see cl BiGraph. The structure keeps the original
      names.
    * `'simplified'`: simplified graph structure, names are removed
      as they could prevent the algorithm from finding any matching.

    .. exref::
        :title: Displays an ONNX graph as text

        The function uses an adjacency matrix of the graph.
        Results are displayed by rows, operators by columns.
        Result kinds are shown on the left,
        their names on the right. Operator types are displayed
        on the top, their names on the bottom.

        .. runpython::
            :showcode:

            import numpy
            from mlprodict.onnx_conv import to_onnx
            from mlprodict import __max_supported_opset__ as opv
            from mlprodict.tools.graphs import onnx2bigraph
            from mlprodict.npy.xop import loadop

            OnnxAdd, OnnxSub = loadop('Add', 'Sub')

            idi = numpy.identity(2).astype(numpy.float32)
            A = OnnxAdd('X', idi, op_version=opv)
            B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
            onx = B.to_onnx({'X': idi, 'W': idi})
            bigraph = onnx2bigraph(onx)
            graph = bigraph.display_structure()
            text = graph.to_text()
            print(text)

    .. versionadded:: 0.7
    """
    if graph_type == 'basic':
        convert = BiGraph._onnx2bigraph_basic
    elif graph_type == 'simplified':
        convert = BiGraph._onnx2bigraph_simplified
    else:
        raise ValueError(
            f"Unknown value for graph_type={graph_type!r}.")
    return convert(model_onnx, recursive=recursive)
def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
    """
    Computes a distance between two ONNX graphs.
    Shortcut for the static method `BiGraph.onnx_graph_distance`,
    every argument is forwarded unchanged.
    The graphs must not be too big otherwise this function
    might take for ever.
    The function relies on package :epkg:`mlstatpy`.

    :param onx1: first graph (ONNX graph or model file name)
    :param onx2: second graph (ONNX graph or model file name)
    :param verbose: verbosity
    :param fLOG: logging function
    :return: distance and differences

    .. warning::

        This is very experimental and very slow.

    .. versionadded:: 0.7
    """
    options = dict(verbose=verbose, fLOG=fLOG)
    return BiGraph.onnx_graph_distance(onx1, onx2, **options)