Coverage for src/code_beatrix/algorithm/classroom.py: 79%

159 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-04-29 13:45 +0200

1""" 

2@file 

3@brief Positions in a classroom 

4""" 

5 

6import random 

7from pyquickhelper.loghelper import noLOG 

8from .data import load_prenoms_w 

9 

10 

def plot_positions(positions, edges=None, ax=None, **options):
    """
    Draws positions and first names into a graph.

    @param      positions   list of 3-uple ``(name, x, y)`` or dictionary
                            ``{ name: (x, y) }``
    @param      edges       None or list of affinities ``(name1, name2)``
                            drawn as arrows from *name1* to *name2*
    @param      ax          axis, a new figure is created when None
    @param      options     keys read: ``figsize`` (default ``(5, 5)``, only
                            used when *ax* is None), ``fontsize``
                            (default 15), ``color_text`` (default black)
    @return                 ax
    @raises     TypeError   if *edges* is neither None nor a list

    First position: 0. An arrow is yellow when both ends are closer
    than 1.1, blue when closer than 1.9, red otherwise.
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle

    if ax is None:
        _, ax = plt.subplots(
            nrows=1, ncols=1, figsize=options.get('figsize', (5, 5)))

    if isinstance(positions, dict):
        # tuple() so that values stored as lists are accepted too
        positions = [(k,) + tuple(v) for k, v in positions.items()]
    else:
        # positions is iterated twice (patches, then edges): materialize it
        positions = list(positions)

    maxx = -1
    maxy = -1
    for name, x, y in positions:
        # ``fill`` is a boolean flag in matplotlib; the face color is the
        # default patch color, made translucent by alpha
        r = Rectangle((x - 0.45, y - 0.45), 0.9, 0.9, fill=True, alpha=0.5)
        ax.add_patch(r)
        ax.text(float(x), float(y), name,
                verticalalignment='center', horizontalalignment='center',
                fontsize=options.get('fontsize', 15),
                color=options.get('color_text', (0, 0, 0)))
        maxx = max(x, maxx)
        maxy = max(y, maxy)

    if edges is not None:
        posdict = {k: (x, y) for k, x, y in positions}
        if not isinstance(edges, list):
            raise TypeError("edges should be list")
        for e1, e2 in edges:
            p1 = posdict[e1]
            p2 = posdict[e2]
            if p1 == p2:
                # nothing to draw between two names at the same spot
                continue
            d0 = p2[0] - p1[0]
            d1 = p2[1] - p1[1]
            # start the arrow 0.1 away from p1 toward p2 on each axis
            dx = (0.1 if d0 > 0 else -0.1) if d0 != 0 else 0.0
            dy = (0.1 if d1 > 0 else -0.1) if d1 != 0 else 0.0
            d = distance(p1, p2)
            if d < 1.1:
                color = "y"
            elif d < 1.9:
                color = "b"
            else:
                color = "r"
            ax.arrow(p1[0] + dx, p1[1] + dy,
                     p2[0] - p1[0] - dx, p2[1] - p1[1] - dy,
                     color=color, shape="full",
                     head_width=0.05, head_length=0.1, lw=3)

    ax.set_xlim([-1, maxx + 1])
    ax.set_ylim([-1, maxy + 1])
    return ax

71 

72 

def random_positions(nb, names=None):
    """
    Draws random position for some person in a classroom.

    The persons are shuffled then laid out on a square grid of side
    ``ceil(sqrt(nb))`` in the order ``(0, 0), (0, 1), ..., (1, 0), ...``.

    @param      nb          number of persons
    @param      names       sequence of names (None to use
                            ``load_prenoms_w()``), only the first *nb* are
                            used and the sequence itself is not modified
    @return                 list of *nb* 3-uple ``(name, x, y)``
    @raises     ValueError  if fewer than *nb* names are available
    """
    if names is None:
        names = load_prenoms_w()
    # Truncate whatever the origin of the names so the result always holds
    # nb persons and fits into the grid sized from nb below; list() also
    # accepts tuples and keeps the caller's sequence untouched by shuffle.
    names = list(names)[:nb]

    if nb > len(names):
        raise ValueError("nb={} > len(names)={}".format(nb, len(names)))
    random.shuffle(names)

    # smallest side such that side * side >= nb
    nbs = int(nb ** 0.5)
    if nbs != nb ** 0.5:
        nbs += 1

    # divmod(i, nbs) == (i // nbs, i % nbs): the second coordinate moves fastest
    return [(name,) + divmod(i, nbs) for i, name in enumerate(names)]

104 

105 

def distance(p1, p2):
    """
    Returns the Euclidean distance between two positions.

    @param      p1      first position ``(x, y)``
    @param      p2      second position ``(x, y)``
    @return             distance as a float
    """
    gap_x = p1[0] - p2[0]
    gap_y = p1[1] - p2[1]
    return (gap_x ** 2 + gap_y ** 2) ** 0.5

115 

116 

def measure_positions(positions, edges):
    """
    Returns the sum of the distances between every pair of names linked
    by an affinity.

    @param      positions   dictionary ``{ name : (x, y) }``
    @param      edges       list of affinities ``(name1, name2)`` or any other
                            object treated as a dictionary ``{ name: [names] }``;
                            in the latter case each affinity is expected to be
                            listed from both ends and the sum is divided by 2
    @return                 distance
    """
    if not isinstance(edges, list):
        # adjacency dictionary: every link is counted twice
        acc = 0
        for origin, neighbours in edges.items():
            acc += sum(distance(positions[origin], positions[other])
                       for other in neighbours)
        return acc / 2.0

    acc = 0
    for first, second in edges:
        acc += distance(positions[first], positions[second])
    return acc

135 

136 

def find_best_positions_greedy(positions, edges, name):
    """
    Evaluates every swap between *name* and the other names: for each
    position currently occupied, computes how much the total measure
    (see ``measure_positions``) would change if *name* and the occupant of
    that position exchanged places.

    @param      positions   dictionary ``{ name : (x, y) }``, not modified
    @param      edges       affinities as a dictionary ``{ name: [names] }``,
                            expected to be symmetric as the one built by
                            ``optimize_positions``
    @param      name        name to optimize
    @return                 None if *name* has no affinity, otherwise the
                            list of ``(delta, position)`` sorted by increasing
                            delta (its own position gives a delta of 0)
    @raises     TypeError   if *edges* is not a dictionary
    """
    if not isinstance(edges, dict):
        raise TypeError("edges must be dict")
    if name not in edges:
        # nothing to do
        return None

    baseline = measure_positions(positions, edges)
    own_spot = positions[name]
    candidates = []
    for occupant, spot in positions.items():
        # exchange the two names on a copy and measure the difference
        trial = positions.copy()
        trial[occupant] = own_spot
        trial[name] = spot
        candidates.append((measure_positions(trial, edges) - baseline, spot))
    return sorted(candidates)

165 

166 

def _symmetric_affinities(edges):
    """
    Converts a list of affinities ``(name1, name2)`` into a dictionary
    ``{ name: set(names) }`` where every pair is registered in both directions.
    """
    adjacency = {}
    for name1, name2 in edges:
        adjacency.setdefault(name1, set()).add(name2)
        adjacency.setdefault(name2, set()).add(name1)
    return adjacency


def _duplicated_positions(positions):
    """
    Returns ``{ position: count }`` for every position held by more than one
    name in *positions* (dictionary ``{ name : (x, y) }``).
    """
    counts = {}
    for pos in positions.values():
        counts[pos] = counts.get(pos, 0) + 1
    return {pos: nb for pos, nb in counts.items() if nb > 1}


def optimize_positions(positions, edges, max_iter=100, fLOG=noLOG,
                       plot_folder=None):
    """
    Optimizes the positions with a greedy local search. At each iteration,
    the current name is swapped with the occupant of the position which
    decreases the most the sum of distances between affinities, and the
    search continues from the name which was moved away. When no swap
    improves the score, another name is drawn at random among the names
    having at least one affinity.

    @param      positions   dictionary ``{ name : (x, y) }`` or list of
                            3-uple ``(name, x, y)``, left unmodified
    @param      edges       list of affinities ``(name1, name2)``
    @param      max_iter    number of iterations (there is no early stop)
    @param      fLOG        logging function
    @param      plot_folder if not None, saves one image per iteration into
                            this folder
    @return                 positions, memo: the new dictionary
                            ``{ name : (x, y) }`` and the list of
                            ``(total, name, position)`` recorded before the
                            first iteration and after each one; when no
                            name of *positions* has an affinity, positions
                            are returned unchanged with
                            ``memo == [(total, None, None)]``
    @raises     ValueError  if two names share the same position
    """
    edges_dict = _symmetric_affinities(edges)

    fLOG("[optimize_positions] #edges=%d #edges_dict=%d" %
         (len(edges), len(edges_dict)))

    if isinstance(positions, list):
        positions = {k: (x, y) for k, x, y in positions}
    else:
        # work on a copy: the caller's dictionary must not be modified
        positions = dict(positions)

    dup = _duplicated_positions(positions)
    if dup:
        raise ValueError("duplicated position:\n{0}".format(str(dup)))

    # Only names involved in at least one affinity can improve the score.
    # Without any such name there is nothing to optimize (and drawing one
    # at random would never terminate).
    candidates = [k for k in sorted(positions) if k in edges_dict]
    if not candidates:
        fLOG("[optimize_positions] no name with affinities, nothing to do")
        return positions, [(measure_positions(positions, edges), None, None)]

    if plot_folder is not None:
        import os
        import matplotlib.pyplot as plt

    name = random.choice(candidates)
    fLOG("[optimize_positions] name='%s' pos=%s" %
         (name, str(positions[name])))
    total = measure_positions(positions, edges)
    memo = [(total, name, positions[name])]

    for iteration in range(max_iter):

        if plot_folder is not None:
            fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))
            plot_positions(positions, edges=edges, ax=ax)
            img = os.path.join(plot_folder, "classroom_%04d.png" % iteration)
            fig.savefig(img)
            # close only this figure, not the ones the caller may have open
            plt.close(fig)

        # name is always in edges_dict here, so deltas is never None
        deltas = find_best_positions_greedy(positions, edges_dict, name)
        delta, new_pos = deltas[0]
        if new_pos == positions[name] or delta >= 0:
            # no improving swap: restart from a random name
            name = None
        else:
            rev = {v: k for k, v in positions.items() if k != name}
            current_name = rev[new_pos]
            fLOG("[optimize_positions] iter=%d total=%1.3f name='%s' <--> '%s' "
                 "delta=%1.3f new_pos=%s" %
                 (iteration, total, name, current_name, delta, str(new_pos)))

            # we switch
            old_pos = positions[name]
            positions[name] = new_pos
            positions[current_name] = old_pos

            # continue with the name which was moved away if it matters
            name = current_name if current_name in edges_dict else None

        if name is None:
            name = random.choice(candidates)

        total = measure_positions(positions, edges)
        memo.append((total, name, positions[name]))

    # final check: swaps must never put two names at the same position
    sup = _duplicated_positions(positions)
    if sup:
        raise ValueError(
            "Too many first names at the same positions: {0}".format(sup))
    return positions, memo