Coverage for mlprodict/testing/test_utils/__init__.py: 100%
12 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Inspired from sklearn-onnx, handles two backends.
4"""
5import numpy
6from .utils_backend_onnxruntime import _capture_output
9from .tests_helper import ( # noqa
10 binary_array_to_string,
11 dump_data_and_model,
12 dump_one_class_classification,
13 dump_binary_classification,
14 dump_multilabel_classification,
15 dump_multiple_classification,
16 dump_multiple_regression,
17 dump_single_regression,
18 convert_model,
19 fit_classification_model,
20 fit_classification_model_simple,
21 fit_multilabel_classification_model,
22 fit_regression_model)
def create_tensor(N, C, H=None, W=None):
    """
    Creates a random float32 tensor with values drawn by
    ``numpy.random.rand``.

    Either both *H* and *W* are omitted, giving a 2-D tensor of shape
    ``(N, C)``, or both are given, giving a 4-D tensor of shape
    ``(N, C, H, W)``.

    :param N: size of the first dimension
    :param C: size of the second dimension
    :param H: size of the third dimension, or None
    :param W: size of the fourth dimension, or None
    :return: float32 numpy array
    :raises ValueError: if exactly one of *H* and *W* is None
    """
    trailing = (H, W)
    if all(dim is None for dim in trailing):
        shape = (N, C)
    elif all(dim is not None for dim in trailing):
        shape = (N, C, H, W)
    else:
        raise ValueError(  # pragma no cover
            'This function only produce 2-D or 4-D tensor.')
    values = numpy.random.rand(*shape)  # pylint: disable=E1101
    return values.astype(numpy.float32, copy=False)
def ort_version_greater(ver):
    """
    Tells if the installed onnxruntime version is greater than
    or equal to *ver*. Despite the function name, an identical
    version also returns True.

    :param ver: version as a string
    :return: boolean, True when ``onnxruntime.__version__ >= ver``
    """
    # Imported inside the function, so importing this module does not
    # itself require onnxruntime or pyquickhelper
    # (presumably intentional; confirm).
    from onnxruntime import __version__
    from pyquickhelper.texthelper.version_helper import compare_module_version
    # NOTE(review): compare_module_version is assumed to return a
    # cmp-style value (<0, 0, >0); ``>= 0`` therefore accepts equality.
    return compare_module_version(__version__, ver) >= 0