Coverage for mlprodict/tools/graphs.py: 98%

338 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Alternative to dot to display a graph. 

4 

5.. versionadded:: 0.7 

6""" 

7import pprint 

8import hashlib 

9import numpy 

10import onnx 

11 

12 

def make_hash_bytes(data, length=20):
    """
    Returns the first *length* hexadecimal characters of the
    SHA-256 digest of *data*.

    :param data: bytes to hash
    :param length: number of hexadecimal characters to keep
    :return: str
    """
    return hashlib.sha256(data).hexdigest()[:length]

21 

22 

class AdjacencyGraphDisplay:
    """
    Structure which contains the necessary information to
    display a graph using an adjacency matrix.
    It stores a list of drawing actions (see :meth:`add`)
    which :meth:`to_text` renders into a character grid.

    .. versionadded:: 0.7
    """

    class Action:
        "One action to do."

        def __init__(self, x, y, kind, label, orientation=None):
            self.x = x
            self.y = y
            self.kind = kind
            self.label = label
            self.orientation = orientation

        def __repr__(self):
            "usual"
            return "%s(%r, %r, %r, %r, %r)" % (
                self.__class__.__name__,
                self.x, self.y, self.kind, self.label,
                self.orientation)

    def __init__(self):
        self.actions = []

    def __iter__(self):
        "Iterates over actions."
        yield from self.actions

    def __str__(self):
        "usual"
        rows = [f"{self.__class__.__name__}("]
        rows.extend(f" {act!r}" for act in self)
        rows.append(")")
        return "\n".join(rows)

    def add(self, x, y, kind, label, orientation=None):
        """
        Adds an action to display the graph.

        :param x: x coordinate
        :param y: y coordinate
        :param kind: `'cross'` or `'text'`
        :param label: specific to kind, must be a string,
            a `'cross'` label must start with `'I'` or `'O'`
        :param orientation: a 2-uple `(i,j)` where *i* or *j* in `{-1,0,1}`
        :raises ValueError: unknown *kind*, or a `'cross'` label
            which is empty or does not start with `'I'` or `'O'`
        :raises TypeError: *label* is not a string
        """
        if kind not in {'cross', 'text'}:
            raise ValueError(  # pragma: no cover
                f"Unexpected value for kind {kind!r}.")
        # The type is checked before the content so that the label
        # is never subscripted unless it is known to be a string.
        if not isinstance(label, str):
            raise TypeError(  # pragma: no cover
                f"Unexpected label type {type(label)!r}.")
        # label[:1] is '' for an empty label: it fails the test below
        # instead of raising IndexError.
        if kind == 'cross' and label[:1] not in {'I', 'O'}:
            raise ValueError(  # pragma: no cover
                "kind=='cross' and label[0]=%r not in {'I','O'}." % (label, ))
        self.actions.append(
            AdjacencyGraphDisplay.Action(x, y, kind, label=label,
                                         orientation=orientation))

    def to_text(self):
        """
        Displays the graph as a single string.
        See @see fn onnx2bigraph to see how the result
        looks like.

        :return: str, an empty string if no action was added
        :raises NotImplementedError: a `'cross'` action has an
            orientation different from `(1, 0)` or a label longer
            than two characters
        """
        # Sparse grid: (column, row) -> character. A vertex takes
        # three columns, hence the factor 3 on x.
        cells = {}
        for act in self:
            if act.kind == 'cross':
                if act.orientation != (1, 0):
                    raise NotImplementedError(  # pragma: no cover
                        "Orientation for 'cross' must be (1, 0) not %r."
                        "" % act.orientation)
                if len(act.label) == 1:
                    cells[act.x * 3, act.y] = act.label
                elif len(act.label) == 2:
                    cells[act.x * 3, act.y] = act.label[0]
                    cells[act.x * 3 + 1, act.y] = act.label[1]
                else:
                    raise NotImplementedError(
                        f"Unable to display long cross label ({act.label!r}).")
            elif act.kind == 'text':
                x = act.x * 3
                y = act.y
                orient = act.orientation
                # A text written leftwards or upwards starts from its
                # last character so it still reads in natural order.
                charset = list(act.label if max(orient) == 1
                               else reversed(act.label))
                for c in charset:
                    cells[x, y] = c
                    x += orient[0]
                    y += orient[1]
            else:
                raise ValueError(  # pragma: no cover
                    f"Unexpected kind value {act.kind!r}.")

        if not cells:
            # Nothing to draw: min() and max() below would fail
            # on an empty sequence.
            return ""

        # Shifts the coordinates so that the smallest one becomes 0.
        min_i = min(k[0] for k in cells)
        min_j = min(k[1] for k in cells)
        shifted = {(k[0] - min_i, k[1] - min_j): v
                   for k, v in cells.items()}

        max_x = max(k[0] for k in shifted)
        max_y = max(k[1] for k in shifted)

        # Dense grid of single characters, rows are y, columns are x.
        grid = numpy.full((max_y + 1, max_x + 1), ' ')
        for k, v in shifted.items():
            grid[k[1], k[0]] = v
        return "\n".join(''.join(grid[i]) for i in range(grid.shape[0]))

140 

141 

class BiGraph:
    """
    BiGraph representation: two disjoint sets of vertices
    *v0* and *v1* and edges linking a vertex of one set to a vertex
    of the other one. Functions in this file use *v0* for results
    (inputs, outputs, initializers, intermediate results)
    and *v1* for operators.

    .. versionadded:: 0.7
    """

    class A:
        "Additional information for a vertex or an edge."

        def __init__(self, kind):
            self.kind = kind

        def __repr__(self):
            return f"A({self.kind!r})"

    class B:
        "Additional information for a vertex or an edge."

        def __init__(self, name, content, onnx_name):
            if not isinstance(content, str):
                raise TypeError(  # pragma: no cover
                    f"content must be str not {type(content)!r}.")
            self.name = name
            self.content = content
            self.onnx_name = onnx_name

        def __repr__(self):
            return f"B({self.name!r}, {self.content!r}, {self.onnx_name!r})"

    def __init__(self, v0, v1, edges):
        """
        :param v0: first set of vertices (dictionary)
        :param v1: second set of vertices (dictionary)
        :param edges: edges (dictionary indexed by pairs of vertex names)
        :raises TypeError: one of the arguments is not a dictionary
        :raises ValueError: *v0* and *v1* share a vertex or an edge
            refers to an unknown vertex

        An edge `(a, b)` where *b* is in *v1* and *a* is unknown
        adds vertex *a* to *v0* (the caller's dictionary is modified)
        with kind `'ERROR'`.
        """
        if not isinstance(v0, dict):
            raise TypeError("v0 must be a dictionary.")
        if not isinstance(v1, dict):
            raise TypeError("v1 must be a dictionary.")
        if not isinstance(edges, dict):
            raise TypeError("edges must be a dictionary.")
        self.v0 = v0
        self.v1 = v1
        self.edges = edges
        common = set(self.v0).intersection(set(self.v1))
        if len(common) > 0:
            raise ValueError(
                f"Sets v0 and v1 have common nodes (forbidden): {common!r}.")
        for a, b in edges:
            if a in v0 and b in v1:
                continue
            if a in v1 and b in v0:
                continue
            if b in v1:
                # One operator is missing one input.
                # We add one.
                self.v0[a] = BiGraph.A('ERROR')
                continue
            raise ValueError(
                f"Edges ({a!r}, {b!r}) not found among the vertices.")

    def __str__(self):
        """
        usual
        """
        return "%s(%d v., %d v., %d edges)" % (
            self.__class__.__name__, len(self.v0),
            len(self.v1), len(self.edges))

    def __iter__(self):
        """
        Iterates over all vertices and edges.
        It produces 3-uples:

        * 0, name, A: vertices in *v0*
        * 1, name, A: vertices in *v1*
        * -1, name, A: edges
        """
        for k, v in self.v0.items():
            yield 0, k, v
        for k, v in self.v1.items():
            yield 1, k, v
        for k, v in self.edges.items():
            yield -1, k, v

    def __getitem__(self, key):
        """
        Returns a vertex if key is a string or an edge
        if it is a tuple.

        :param key: vertex or edge name
        :return: value
        :raises KeyError: the vertex or the edge does not exist
        """
        if isinstance(key, tuple):
            return self.edges[key]
        if key in self.v0:
            return self.v0[key]
        return self.v1[key]

    def order_vertices(self):
        """
        Orders the vertices from the input to the output.
        Every vertex starts at 0 and is pushed after all
        its predecessors until nothing changes.

        :return: dictionary `{vertex name: order}`
        :raises RuntimeError: the graph has a cycle
        """
        order = {v: 0 for v in self.v0}
        order.update((v, 0) for v in self.v1)
        modif = 1
        n_iter = 0
        while modif > 0:
            modif = 0
            for a, b in self.edges:
                if order[b] <= order[a]:
                    order[b] = order[a] + 1
                    modif += 1
            n_iter += 1
            # Without a cycle, the loop stabilizes in at most
            # len(order) passes.
            if n_iter > len(order):
                break
        if modif > 0:
            raise RuntimeError(
                f"The graph has a cycle.\n{pprint.pformat(self.edges)}")
        return order

    def adjacency_matrix(self):
        """
        Builds an adjacency matrix. Rows are the vertices of *v0*,
        columns the vertices of *v1*, both sorted by
        (order returned by :meth:`order_vertices`, name).

        :return: matrix, list of row vertices, list of column vertices
        """
        order = self.order_vertices()
        row = [k for _, k in sorted(
            (v, k) for k, v in order.items() if k in self.v0)]
        col = [k for _, k in sorted(
            (v, k) for k, v in order.items() if k in self.v1)]
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}
        matrix = numpy.zeros((len(row), len(col)), dtype=numpy.int32)
        for a, b in self.edges:
            # An edge goes either from v0 to v1 or from v1 to v0,
            # the matrix does not keep the direction.
            if a in row_id:
                matrix[row_id[a], col_id[b]] = 1
            else:
                matrix[row_id[b], col_id[a]] = 1
        return matrix, row, col

    def display_structure(self, grid=5, distance=5):
        """
        Creates a display structure which contains
        all the necessary steps to display a graph.
        Vertices and edges must expose attribute *kind*
        (see class *A*).

        :param grid: align text to this grid
        :param distance: distance to the text
        :return: instance of @see cl AdjacencyGraphDisplay
        """
        def adjust(c, way):
            # Moves coordinate c away from the matrix (forward if
            # way == 1, backward otherwise) and rounds it to the grid.
            if way == 1:
                d = grid * ((c + distance * 2 - (grid // 2 + 1)) // grid)
            else:
                d = -grid * ((-c + distance * 2 - (grid // 2 + 1)) // grid)
            return d

        matrix, row, col = self.adjacency_matrix()
        row_id = {b: i for i, b in enumerate(row)}
        col_id = {b: i for i, b in enumerate(col)}

        # For every row (resp. column), first and last column
        # (resp. row) holding a cross.
        interval_y_min = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_y_max = numpy.zeros((matrix.shape[0], ), dtype=numpy.int32)
        interval_x_min = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_x_max = numpy.zeros((matrix.shape[1], ), dtype=numpy.int32)
        interval_y_min[:] = max(matrix.shape)
        interval_x_min[:] = max(matrix.shape)

        graph = AdjacencyGraphDisplay()
        for key, value in self.edges.items():
            if key[0] in row_id:
                y = row_id[key[0]]
                x = col_id[key[1]]
            else:
                x = col_id[key[0]]
                y = row_id[key[1]]
            graph.add(x, y, 'cross', label=value.kind, orientation=(1, 0))
            if x < interval_y_min[y]:
                interval_y_min[y] = x
            if x > interval_y_max[y]:
                interval_y_max[y] = x
            if y < interval_x_min[x]:
                interval_x_min[x] = y
            if y > interval_x_max[x]:
                interval_x_max[x] = y

        # Rows: kind on the left, name on the right.
        for k, v in self.v0.items():
            y = row_id[k]
            x = adjust(interval_y_min[y], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(-1, 0))
            x = adjust(interval_y_max[y], 1)
            graph.add(x, y, 'text', label=k, orientation=(1, 0))

        # Columns: kind on the top, name on the bottom.
        for k, v in self.v1.items():
            x = col_id[k]
            y = adjust(interval_x_min[x], -1)
            graph.add(x, y, 'text', label=v.kind, orientation=(0, -1))
            y = adjust(interval_x_max[x], 1)
            graph.add(x, y, 'text', label=k, orientation=(0, 1))

        return graph

    def order(self):
        """
        Order nodes. Depth first.
        Returns a dictionary `{vertex name: position}`
        mixing vertices of *v0* and *v1*.

        :raises RuntimeError: no vertex without predecessor
        """
        # Creates forwards and backwards nodes.
        forwards = {}
        backwards = {}
        for k in self.v0:
            forwards[k] = []
            backwards[k] = []
        for k in self.v1:
            forwards[k] = []
            backwards[k] = []
        # Edges are dictionary keys, hence unique: one pass is enough.
        for a, b in self.edges:
            if b not in forwards[a]:
                forwards[a].append(b)
            if a not in backwards[b]:
                backwards[b].append(a)

        # roots
        roots = [b for b, backs in backwards.items() if len(backs) == 0]
        if len(roots) == 0:
            raise RuntimeError(  # pragma: no cover
                "This graph has cycles. Not allowed.")

        # ordering
        # NOTE(review): only the last successor of a node is followed
        # each time the node is popped, vertices reachable only through
        # another successor are not part of the result. Kept as is,
        # confirm whether a full traversal is expected.
        order = {}
        stack = roots
        while len(stack) > 0:
            node = stack.pop()
            order[node] = len(order)
            w = forwards[node]
            if len(w) == 0:
                continue
            last = w.pop()
            stack.append(last)

        return order

    def summarize(self):
        """
        Creates a text summary of the graph: one line per vertex
        of *v1* returned by :meth:`order`, in that order.

        :return: str
        """
        order = self.order()
        keys = sorted((o, k) for k, o in order.items())
        return "\n".join(str(self.v1[k]) for _, k in keys if k in self.v1)

    @staticmethod
    def _onnx2bigraph_basic(model_onnx, recursive=False):
        """
        Implements graph type `'basic'` for function
        @see fn onnx2bigraph.

        :param model_onnx: ONNX model
        :param recursive: not implemented
        :return: @see cl BiGraph
        """

        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs, outputs, initializers
        for i, o in enumerate(model_onnx.graph.input):
            v0[o.name] = BiGraph.A('Input-%d' % i)
        for i, o in enumerate(model_onnx.graph.output):
            v0[o.name] = BiGraph.A('Output-%d' % i)
        for o in model_onnx.graph.initializer:
            v0[o.name] = BiGraph.A('Init')
        for n in model_onnx.graph.node:
            # Unnamed nodes get a name built from their id.
            nname = n.name if len(n.name) > 0 else "id%d" % id(n)
            v1[nname] = BiGraph.A(n.op_type)
            for i, o in enumerate(n.input):
                # A cross is displayed with two characters:
                # the position is replaced by '+' after 9.
                c = str(i) if i < 10 else "+"
                edges[o, nname] = BiGraph.A(f'I{c}')
            for i, o in enumerate(n.output):
                c = str(i) if i < 10 else "+"
                if o not in v0:
                    v0[o] = BiGraph.A('inout')
                edges[nname, o] = BiGraph.A(f'O{c}')

        return BiGraph(v0, v1, edges)

    @staticmethod
    def _onnx2bigraph_simplified(model_onnx, recursive=False):
        """
        Implements graph type `'simplified'` for function
        @see fn onnx2bigraph.

        :param model_onnx: ONNX model
        :param recursive: not implemented
        :return: @see cl BiGraph

        Vertices are renamed (`I<n>`, `O<n>`, `C<n>`, `R<n>`, `N<n>`),
        the original name is kept in attribute *onnx_name*.
        """
        if recursive:
            raise NotImplementedError(  # pragma: no cover
                "Option recursive=True is not implemented yet.")
        v0 = {}
        v1 = {}
        edges = {}

        # inputs, outputs, initializers
        for o in model_onnx.graph.input:
            v0[f"I{len(v0)}"] = BiGraph.B(
                'In', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.output:
            v0[f"O{len(v0)}"] = BiGraph.B(
                'Ou', make_hash_bytes(o.type.SerializeToString(), 2), o.name)
        for o in model_onnx.graph.initializer:
            v0[f"C{len(v0)}"] = BiGraph.B(
                'Cs', make_hash_bytes(o.raw_data, 10), o.name)

        # onnx name -> key in v0
        names_v0 = {v.onnx_name: k for k, v in v0.items()}

        for n in model_onnx.graph.node:
            key_node = f"N{len(v1)}"
            if len(n.attribute) > 0:
                ats = [at.SerializeToString() for at in n.attribute]
                ct = make_hash_bytes(b"".join(ats), 10)
            else:
                ct = ""
            v1[key_node] = BiGraph.B(
                n.op_type, ct, n.name)
            for o in n.input:
                key_in = names_v0[o]
                edges[key_in, key_node] = BiGraph.A('I')
            for o in n.output:
                # An output already registered (a graph output for
                # example) must reuse its own key and not the key
                # left by a previous iteration.
                key = names_v0.get(o)
                if key is None:
                    key = f"R{len(v0)}"
                    v0[key] = BiGraph.B('Re', n.op_type, o)
                    names_v0[o] = key
                edges[key_node, key] = BiGraph.A('O')

        return BiGraph(v0, v1, edges)

    @staticmethod
    def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
        """
        Computes a distance between two ONNX graphs. They must not
        be too big otherwise this function might take for ever.
        The function relies on package :epkg:`mlstatpy`.

        :param onx1: first graph (ONNX graph or model file name)
        :param onx2: second graph (ONNX graph or model file name)
        :param verbose: verbosity
        :param fLOG: logging function
        :return: distance and differences

        .. warning::

            This is very experimental and very slow.

        .. versionadded:: 0.7
        """
        from mlstatpy.graph.graph_distance import GraphDistance

        if isinstance(onx1, str):
            onx1 = onnx.load(onx1)
        if isinstance(onx2, str):
            onx2 = onnx.load(onx2)

        def make_hash(init):
            return make_hash_bytes(init.raw_data)

        def build_graph(onx):
            # Vertices are node names and result names,
            # labels are operator types or a hash for initializers.
            edges = []
            labels = {}
            for node in onx.graph.node:
                if len(node.name) == 0:
                    name = str(id(node))
                else:
                    name = node.name
                for i in node.input:
                    edges.append((i, name))
                for p, i in enumerate(node.output):
                    edges.append((name, i))
                    labels[i] = "%s:%d" % (node.op_type, p)
                labels[name] = node.op_type
            for init in onx.graph.initializer:
                labels[init.name] = make_hash(init)

            g = GraphDistance(edges, vertex_label=labels)
            return g

        g1 = build_graph(onx1)
        g2 = build_graph(onx2)

        dist, gdist = g1.distance_matching_graphs_paths(
            g2, verbose=verbose, fLOG=fLOG, use_min=False)
        return dist, gdist

554 

555 

def onnx2bigraph(model_onnx, recursive=False, graph_type='basic'):
    """
    Converts an ONNX graph into a graph representation,
    edges, vertices.

    :param model_onnx: ONNX graph
    :param recursive: dig into subgraphs too
    :param graph_type: kind of graph it creates
    :return: see @cl BiGraph

    About *graph_type*:

    * `'basic'`: basic graph structure, it returns an instance
      of type @see cl BiGraph. The structure keeps the original
      names.
    * `'simplified'`: simplified graph structure, names are removed
      as they could prevent the algorithm from finding any matching.

    .. exref::
        :title: Displays an ONNX graph as text

        The function uses an adjacency matrix of the graph.
        Results are displayed by rows, operators by columns.
        Result kinds are shown on the left,
        their names on the right. Operator types are displayed
        on the top, their names on the bottom.

        .. runpython::
            :showcode:

            import numpy
            from mlprodict.onnx_conv import to_onnx
            from mlprodict import __max_supported_opset__ as opv
            from mlprodict.tools.graphs import onnx2bigraph
            from mlprodict.npy.xop import loadop

            OnnxAdd, OnnxSub = loadop('Add', 'Sub')

            idi = numpy.identity(2).astype(numpy.float32)
            A = OnnxAdd('X', idi, op_version=opv)
            B = OnnxSub(A, 'W', output_names=['Y'], op_version=opv)
            onx = B.to_onnx({'X': idi, 'W': idi})
            bigraph = onnx2bigraph(onx)
            graph = bigraph.display_structure()
            text = graph.to_text()
            print(text)

    .. versionadded:: 0.7
    """
    # Unknown kinds are rejected first, then the call is dispatched.
    if graph_type != 'basic' and graph_type != 'simplified':
        raise ValueError(
            f"Unknown value for graph_type={graph_type!r}.")
    if graph_type == 'basic':
        converter = BiGraph._onnx2bigraph_basic
    else:
        converter = BiGraph._onnx2bigraph_simplified
    return converter(model_onnx, recursive=recursive)

613 

614 

def onnx_graph_distance(onx1, onx2, verbose=0, fLOG=print):
    """
    Computes a distance between two ONNX graphs. They must not
    be too big otherwise this function might take for ever.
    The function relies on package :epkg:`mlstatpy`.
    It only forwards its arguments to the static method
    of the same name in class *BiGraph*.

    :param onx1: first graph (ONNX graph or model file name)
    :param onx2: second graph (ONNX graph or model file name)
    :param verbose: verbosity
    :param fLOG: logging function
    :return: distance and differences

    .. warning::

        This is very experimental and very slow.

    .. versionadded:: 0.7
    """
    result = BiGraph.onnx_graph_distance(
        onx1, onx2, verbose=verbose, fLOG=fLOG)
    return result