# Source code for onnxmltools.utils.visualize

# SPDX-License-Identifier: Apache-2.0

import os
import shutil
import sys
from webbrowser import open_new_tab


# Characters that would terminate or corrupt a single-quoted JavaScript
# string literal ("<" is included so a name containing "</script>" cannot
# close the inline script element of the generated page).
_JS_STRING_ESCAPES = str.maketrans({
    "\\": "\\\\",
    "'": "\\'",
    "\n": "\\n",
    "\r": "\\r",
    "<": "\\x3c",
    "\u2028": "\\u2028",
    "\u2029": "\\u2029",
})


def _js_escape(text):
    """Escape *text* for embedding inside a single-quoted JavaScript string."""
    return str(text).translate(_JS_STRING_ESCAPES)


def get_set_node(node, i="0"):
    """Return the dagre-d3 statement declaring node ``i`` labelled *node*.

    :param node: label of the node (operator type or tensor name); it is
        also used to build the CSS class ``type-<label>``
    :param i: identifier of the node in the dagre-d3 graph
    :return: a JavaScript ``g.setNode(...)`` statement

    The label is escaped so that names containing quotes, backslashes,
    line breaks or ``</script>`` cannot break the generated page.
    """
    label = _js_escape(node)
    return ("g.setNode(" + str(i) + ", { label: '" + label +
            "', class: 'type-" + label + "' });")


def get_set_edge(start, end):
    """Return the dagre-d3 statement connecting node *start* to node *end*.

    Both endpoints are converted with ``str()`` and emitted verbatim as the
    arguments of ``g.setEdge``.
    """
    return "g.setEdge({0!s}, {1!s});".format(start, end)


def get_nodes(graph):
    """List the nodes to draw as ``(id, label)`` pairs.

    Ids are assigned in this order:

    * operator nodes (``graph.node``), labelled with their ``op_type``,
      starting at 0;
    * graph inputs, labelled with their ``name``, directly after the
      operators;
    * graph outputs, labelled with their ``name``, starting one id past the
      last input (one id is skipped; ``get_edges`` uses the same offset).
    """
    entries = []
    for op in graph.node:
        entries.append((len(entries), op.op_type))
    for tensor in graph.input:
        entries.append((len(entries), tensor.name))

    output_id = len(entries) + 1
    for tensor in graph.output:
        entries.append((output_id, tensor.name))
        output_id += 1
    return entries


def get_nodes_builder(graph_nodes):
    """Turn ``(id, label)`` pairs into dagre-d3 ``g.setNode`` statements.

    :param graph_nodes: sequence of pairs as produced by ``get_nodes``
    :return: list of JavaScript statements, one per entry, in input order
    """
    statements = []
    for entry in graph_nodes:
        node_id, label = entry[0], entry[1]
        statements.append(get_set_node(label, node_id))
    return statements


def get_edges(graph):
    """Compute the edges to draw as ``(source_id, target_id)`` pairs.

    Ids follow the numbering of ``get_nodes``: operator nodes are
    ``0 .. len(graph.node) - 1``, graph inputs come next, and graph outputs
    start at ``len(graph.node) + len(graph.input) + 1``.

    An edge is created from every producer of a tensor (an operator that
    lists it in ``output`` or a graph input with that name) to each operator
    consuming it, and to the graph output of the same name. Operator inputs
    that are neither produced nor initializers are reported on stdout.

    Empty tensor names are ignored: ONNX uses ``""`` to mark an omitted
    optional input or output, so there is nothing to connect.
    """
    nodes = graph.node
    # Set for O(1) membership tests in the per-input loop below.
    initializer_names = {init.name for init in graph.initializer}

    # Tensor name -> ids of the nodes producing it.
    producers = {}
    for node_id, node in enumerate(nodes):
        for out_name in node.output:
            if out_name:
                producers.setdefault(out_name, []).append(node_id)
    # Graph inputs are producers as well; they take precedence over any
    # operator output with the same name.
    for input_id, graph_input in enumerate(graph.input, len(nodes)):
        producers[graph_input.name] = [input_id]

    edge_list = []
    for node_id, node in enumerate(nodes):
        for in_name in node.input:
            if not in_name:
                continue
            if in_name in producers:
                edge_list.extend((src, node_id) for src in producers[in_name])
            elif in_name not in initializer_names:
                print("No corresponding output found for {0}.".format(in_name))

    first_output_id = len(nodes) + len(graph.input) + 1
    for output_id, graph_output in enumerate(graph.output, first_output_id):
        # Outputs with no known producer are left unconnected.
        edge_list.extend(
            (src, output_id) for src in producers.get(graph_output.name, ()))
    return edge_list


def visualize_model(onnx_model, open_browser=True, dest="index.html"):
    """
    Creates a graph visualization of an ONNX protobuf model.
    It creates a SVG graph with *d3.js* and stores it into a file.
    The stylesheet ``styles.css`` and the script ``dagre-d3.min.js``
    referenced by the page are copied into the directory of *dest*.

    :param onnx_model: ONNX model (protobuf object)
    :param open_browser: opens the generated page in a new browser tab
        when True
    :param dest: destination file (HTML)

    Example:

    ::

        from onnxmltools.utils import visualize_model
        visualize_model(model)
    """
    import html
    import pathlib

    graph = onnx_model.graph
    # Producer metadata comes from the model file; escape it before placing
    # it in the HTML heading.
    model_info = html.escape(
        "Model produced by: " + onnx_model.producer_name +
        " version(" + onnx_model.producer_version + ")")

    html_str = """
        <!doctype html>
        <meta charset="utf-8">
        <title>ONNX Visualization</title>
        <script src="https://d3js.org/d3.v3.min.js"></script>
        <link rel="stylesheet" href="styles.css">
        <script src="dagre-d3.min.js"></script>
        <h2>[model_info]</h2>
        <svg id="svg-canvas" width=960 height=600></svg>
        <script id="js">
        var g = new dagreD3.graphlib.Graph()
          .setGraph({})
          .setDefaultEdgeLabel(function() { return {}; });
        [nodes_html]
        g.nodes().forEach(function(v) {
          var node = g.node(v);
          // Round the corners of the nodes
          node.rx = node.ry = 5;
        });
        [edges_html]
        // Create the renderer
        var render = new dagreD3.render();
        // Set up an SVG group so that we can translate the final graph.
        var svg = d3.select("svg"),
            svgGroup = svg.append("g");
        // Run the renderer. This is what draws the final graph.
        render(d3.select("svg g"), g);
        // Center the graph
        svgGroup.attr("transform", "translate(20, 20)");
        svg.attr("height", g.graph().height + 40);
        svg.attr("width", g.graph().width + 40);
        </script>
    """
    html_str = html_str.replace("[nodes_html]", "\n".join(
        get_nodes_builder(get_nodes(graph))))
    html_str = html_str.replace("[edges_html]", "\n".join(
        [get_set_edge(edge[0], edge[1]) for edge in get_edges(graph)]))
    html_str = html_str.replace("[model_info]", model_info)

    dest_path = os.path.realpath(dest)
    # The page loads styles.css and dagre-d3.min.js by relative URL, so they
    # must sit next to the HTML file rather than in the current directory.
    dest_dir = os.path.dirname(dest_path)

    # The page declares <meta charset="utf-8">; write it with that encoding
    # regardless of the platform default.
    with open(dest_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_str)

    pkgdir = sys.modules['onnxmltools'].__path__[0]
    for asset in ("styles.css", "dagre-d3.min.js"):
        shutil.copy(os.path.join(pkgdir, "utils", asset), dest_dir)

    if open_browser:
        # as_uri() builds a valid file:// URL on every platform (including
        # Windows drive-letter paths).
        open_new_tab(pathlib.Path(dest_path).as_uri())