Coverage for mlprodict/onnx_tools/optim/onnx_optimisation.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Optimisations of :epkg:`ONNX` graphs. 

4""" 

5from ..model_checker import check_onnx 

6from ._onnx_optimisation_common import _apply_optimisation_on_graph 

7from .onnx_optimisation_identity import onnx_remove_node_identity 

8from .onnx_optimisation_redundant import onnx_remove_node_redundant 

9from .onnx_optimisation_unused import onnx_remove_node_unused 

10 

11 

def onnx_remove_node(onnx_model, recursive=True, debug_info=None, **options):
    """
    Removes as many nodes as possible without changing the outcome.

    Three passes are chained in this order:
    @see fn onnx_remove_node_unused, @see fn onnx_remove_node_identity,
    then @see fn onnx_remove_node_redundant. The input and the result of
    every pass are validated with *check_onnx*.

    When *onnx_model* exposes a ``graph`` attribute, the work is delegated
    to *_apply_optimisation_on_graph*, which receives this function as the
    optimisation to apply.

    @param      onnx_model      onnx model (or graph)
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private), a list of
                                names extended with the short class name
                                of *onnx_model*; the caller's list is not
                                modified
    @param      options         additional options forwarded to every pass
    @return                     new onnx model
    """
    # Short class name taken from the type's repr,
    # e.g. "<class 'pkg.mod.Name'>" -> "Name".
    short_name = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    # Build a new list so the caller's debug_info is left untouched.
    if debug_info is None:
        debug_info = [short_name]
    else:
        debug_info = debug_info + [short_name]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node, onnx_model,
            recursive=recursive, debug_info=debug_info,
            **options)

    check_onnx(onnx_model)
    current = onnx_model
    passes = (
        onnx_remove_node_unused,
        onnx_remove_node_identity,
        onnx_remove_node_redundant,
    )
    for optimisation in passes:
        current = optimisation(
            current, recursive=recursive, debug_info=debug_info, **options)
        # Validate after every pass so a faulty rewrite is caught at once.
        check_onnx(current)
    return current