Coverage for mlprodict/testing/test_utils/__init__.py: 100%

12 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Inspired from sklearn-onnx, handles two backends. 

4""" 

5import numpy 

6from .utils_backend_onnxruntime import _capture_output 

7 

8 

9from .tests_helper import ( # noqa 

10 binary_array_to_string, 

11 dump_data_and_model, 

12 dump_one_class_classification, 

13 dump_binary_classification, 

14 dump_multilabel_classification, 

15 dump_multiple_classification, 

16 dump_multiple_regression, 

17 dump_single_regression, 

18 convert_model, 

19 fit_classification_model, 

20 fit_classification_model_simple, 

21 fit_multilabel_classification_model, 

22 fit_regression_model) 

23 

24 

def create_tensor(N, C, H=None, W=None):
    """
    Returns a float32 array of random values drawn uniformly from [0, 1).

    :param N: size of the first dimension
    :param C: size of the second dimension
    :param H: optional size of the third dimension, must be given
        together with *W*
    :param W: optional size of the fourth dimension, must be given
        together with *H*
    :return: array of shape ``(N, C)`` when *H* and *W* are both None,
        ``(N, C, H, W)`` when both are specified
    :raises ValueError: if exactly one of *H* and *W* is specified
    """
    without_hw = H is None and W is None
    with_hw = H is not None and W is not None
    if not (without_hw or with_hw):
        # Mixing a given H with a missing W (or the reverse) is ambiguous.
        raise ValueError(  # pragma no cover
            'This function only produce 2-D or 4-D tensor.')
    dims = (N, C) if without_hw else (N, C, H, W)
    values = numpy.random.rand(*dims)  # pylint: disable=E1101
    return values.astype(numpy.float32, copy=False)

33 

34 

def ort_version_greater(ver):
    """
    Tells if the installed onnxruntime version is greater than
    or equal to *ver*.

    Despite the function name, the comparison is inclusive: the
    function returns True when both versions are equal.

    :param ver: version as a string
    :return: boolean, True when
        ``compare_module_version(onnxruntime.__version__, ver) >= 0``
    """
    # Both imports live inside the function, so onnxruntime and
    # pyquickhelper are only loaded when this helper is actually called.
    from onnxruntime import __version__
    from pyquickhelper.texthelper.version_helper import compare_module_version
    # ">= 0" (not "> 0") keeps the historical inclusive behavior callers rely on.
    return compare_module_version(__version__, ver) >= 0