# Source code for onnxmltools.convert.xgboost.convert

# SPDX-License-Identifier: Apache-2.0

from uuid import uuid4
import xgboost
import onnx
from ..common.onnx_ex import get_maximum_opset_supported
from ..common._topology import convert_topology
from ._parse import parse_xgboost, WrappedBooster

# Invoke the registration of all our converters and shape calculators
# from . import shape_calculators
from . import operator_converters, shape_calculators


def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
            targeted_onnx=onnx.__version__, custom_conversion_functions=None,
            custom_shape_calculators=None):
    '''
    This function produces an equivalent ONNX model of the given xgboost model.

    :param model: A xgboost model. A raw ``xgboost.Booster`` is wrapped in
        ``WrappedBooster`` before parsing.
    :param initial_types: a python list. Each element is a tuple of a variable name
        and a type defined in data_types.py. Required.
    :param name: The name of the graph (type: GraphProto) in the produced ONNX model
        (type: ModelProto). When omitted, a random UUID hex string is used.
    :param doc_string: A string attached onto the produced ONNX model
    :param target_opset: number, for example, 7 for ONNX 1.2, and 8 for ONNX 1.3.
        When falsy (e.g. None), the maximum opset supported by this package
        (``get_maximum_opset_supported()``) is used.
    :param targeted_onnx: A string (for example, '1.1.2' and '1.2') used to specify
        the targeted ONNX version of the produced model. If ONNXMLTools cannot find
        a compatible ONNX python package, an error may be thrown.
    :param custom_conversion_functions: a dictionary for specifying the user
        customized conversion function
    :param custom_shape_calculators: a dictionary for specifying the user customized
        shape calculator
    :return: An ONNX model (type: ModelProto) which is equivalent to the input
        xgboost model
    :raises ValueError: if ``initial_types`` is None
    '''
    if initial_types is None:
        # Adjacent literals are concatenated at compile time, so the message
        # carries no stray backslash or source indentation.
        raise ValueError('Initial types are required. See usage of convert(...) in '
                         'onnxmltools.convert.xgboost.convert for details')

    if name is None:
        # UUID.hex is already a str; a random name keeps graphs distinguishable.
        name = uuid4().hex

    if isinstance(model, xgboost.Booster):
        # Raw boosters are adapted to the interface expected by parse_xgboost.
        model = WrappedBooster(model)

    # Any falsy value (None, 0) falls back to the newest supported opset.
    if not target_opset:
        target_opset = get_maximum_opset_supported()

    topology = parse_xgboost(model, initial_types, target_opset,
                             custom_conversion_functions, custom_shape_calculators)
    topology.compile()
    onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
    return onnx_model