Coverage for src/mlstatpy/ml/kppv_laesa.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v6.4.1, created at 2022-06-13 20:42 +0200

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief Implements optimized k-nn. 

5""" 

6import random 

7import numpy 

8from .kppv import NuagePoints 

9 

10 

11class NuagePointsLaesa (NuagePoints): 

12 """ 

13 Implémente l'algorithme des plus proches voisins, 

14 version :ref:`LAESA <space_metric_algo_laesa_prime>`_ 

15 """ 

16 

17 def __init__(self, nb_pivots): 

18 """ 

19 Construit la classe 

20 

21 @param nb_pivots number of pivots 

22 """ 

23 NuagePoints.__init__(self) 

24 self.nb_pivots = nb_pivots 

25 

26 def fit(self, X, y=None): 

27 """ 

28 Follows sklearn API. 

29 

30 @param X training set 

31 @param y labels 

32 """ 

33 self.nuage = X 

34 self.labels = y 

35 self.selection_pivots(self.nb_pivots) 

36 

37 def selection_pivots(self, nb): 

38 """ 

39 Sélectionne *nb* pivots aléatoirements. 

40 

41 @param nb nombre de pivots 

42 """ 

43 nb = min(nb, self.nuage.shape[0]) 

44 if nb == 1: 

45 self.pivots = [2] 

46 else: 

47 self.pivots = set() 

48 while len(self.pivots) < nb: 

49 i = random.randint(0, self.nuage.shape[0] - 1) 

50 if i not in self.pivots: 

51 self.pivots.add(i) 

52 self.pivots = list(sorted(self.pivots)) 

53 

54 # on calcule aussi la distance de chaque éléments au pivots 

55 self.dist = numpy.zeros((self.nuage.shape[0], len(self.pivots))) 

56 for i in range(self.nuage.shape[0]): 

57 for j in range(len(self.pivots)): # pylint: disable=C0200 

58 self.dist[i, j] = self.distance( 

59 self.nuage[i, :], self.nuage[self.pivots[j], :]) 

60 

61 def ppv(self, obj): 

62 """ 

63 Retourne l'élément le plus proche de obj et sa distance avec obj, 

64 utilise la sélection à l'aide pivots 

65 

66 @param obj object 

67 @return ``tuple(distance, index)`` 

68 """ 

69 

70 # initialisation 

71 dp = [(self.distance(obj, self.nuage[p, :]), p, i) 

72 for i, p in enumerate(self.pivots)] 

73 

74 # pivots le plus proche 

75 dm, im, _ = min(dp) 

76 

77 # améliorations 

78 for i in range(0, self.nuage.shape[0]): 

79 

80 # on regarde si un pivot permet d'éliminer l'élément i 

81 calcul = True 

82 for d, p, ip in dp: 

83 delta = abs(d - self.dist[i, ip]) 

84 if delta > dm: 

85 calcul = False 

86 break 

87 

88 # dans le cas contraire on calcule la distance 

89 if calcul: 

90 d = self.distance(obj, self.nuage[i, :]) 

91 if d < dm: 

92 dm = d 

93 im = i 

94 

95 return dm, im