Coverage for mlprodict/grammar/grammar_sklearn/g_sklearn_identify.py: 100%

19 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief Helpers to identify an interpreter. 

5""" 

6import keyword 

7import re 

8from .g_sklearn_linear_model import sklearn_logistic_regression, sklearn_linear_regression 

9from .g_sklearn_preprocessing import sklearn_standard_scaler 

10from .g_sklearn_tree import sklearn_decision_tree_regressor 

11 

12 

def __pep8():  # pragma: no cover
    """
    References every imported interpreter so that linters do not
    report the imports above as unused. The interpreters must remain
    importable from this module because @see fn identify_interpreter
    looks them up by name in the module globals.
    """
    for interpreter in (sklearn_decision_tree_regressor,
                        sklearn_linear_regression,
                        sklearn_logistic_regression,
                        sklearn_standard_scaler):
        assert interpreter

18 

19 

def change_style(name):
    """
    Converts a *CamelCase* identifier (*AaBb*) into its *snake_case*
    form (*aa_bb*).

    @param      name    name to convert
    @return             converted name, suffixed with ``_`` when the
                        result collides with a Python keyword
    """
    # First pass: split before a capitalised word ("XYZWord" -> "XYZ_Word").
    split_words = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Second pass: split a lowercase letter or digit from the capital after it.
    snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', split_words).lower()
    if keyword.iskeyword(snake):
        # A keyword such as "lambda" cannot be used as an identifier.
        return snake + "_"
    return snake

30 

31 

def identify_interpreter(model):
    """
    Finds the interpreter function associated with a *scikit-learn* model.
    The class name of the model is converted to snake case, prefixed with
    ``sklearn_`` and looked up among the names of this module.

    @param      model   model to identify
    @return             interpreter
    @raise      ValueError              if this module exposes no name
                                        starting with ``sklearn``
    @raise      NotImplementedError     if no interpreter matches the
                                        class of the model
    """
    class_name = model.__class__.__name__
    expected = "sklearn_" + change_style(class_name)

    # Snapshot of the module namespace; candidates are the names
    # starting with "sklearn".
    namespace = dict(globals())
    available = {}
    for key, value in namespace.items():
        if key.startswith("sklearn"):
            available[key] = value

    if not available:
        names = "\n".join(sorted(namespace.keys()))
        raise ValueError(  # pragma: no cover
            "No found interpreters, possibilities=\n{0}".format(names))

    if skconv_missing := expected not in available:
        listing = "\n".join(sorted(available.keys()))
        raise NotImplementedError(  # pragma: no cover
            "Model class '{0}' is not yet implemented. Available interpreters:\n{1}".format(
                class_name, listing))
    return available[expected]