Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1# coding: utf-8 

2""" 

3Common functions for :epkg:`SIR` models. 

4""" 

5import numpy 

6from sympy import symbols, Symbol, latex, lambdify 

7import sympy.printing as printing 

8from sympy.parsing.sympy_parser import ( 

9 parse_expr, standard_transformations, implicit_application) 

10from ._sympy_helper import enumerate_traverse 

11from ._base_sir_sim import BaseSIRSimulation 

12from ._base_sir_estimation import BaseSIREstimation 

13 

14 

class BaseSIR(BaseSIRSimulation, BaseSIREstimation):
    """
    Base model for :epkg:`SIR` models.

    :param p: list of `[(name, initial value or None, comment)]` (parameters)
    :param q: list of `[(name, initial value or None, comment)]` (quantities)
    :param c: list of `[(name, initial value or None, comment)]` (constants),
        None means no constant
    :param eq: equations, dictionary `{quantity name: expression}` where
        every expression is a string parsed by :epkg:`sympy`
    """
    # Attributes saved by `__getstate__` and duplicated by `copy`.
    _pickled_atts = [
        '_p', '_q', '_c', '_eq', '_val_p', '_val_q', '_val_c',
        '_val_ind', '_val_len', '_syms']

    def __init__(self, p, q, c=None, eq=None, **kwargs):
        if c is None:
            # The signature advertises ``c=None``: treat it as
            # "no constant" instead of rejecting the default value.
            c = []
        if not isinstance(p, list):
            raise TypeError("p must be a list of tuple.")
        if not isinstance(q, list):
            raise TypeError("q must be a list of tuple.")
        if not isinstance(c, list):
            raise TypeError("c must be a list of tuple.")
        if eq is not None and not isinstance(eq, dict):
            raise TypeError("eq must be a dictionary.")
        self._p = p
        self._q = q
        self._c = c

        # Symbols are always built: ``_syms`` belongs to ``_pickled_atts``,
        # so ``copy`` and ``__getstate__`` need it even without equations.
        locs = {'t': symbols('t', cls=Symbol)}
        for v in self._p + self._c + self._q:
            locs[v[0]] = symbols(v[0], cls=Symbol)
        self._syms = locs

        if eq is not None:
            tr = standard_transformations + (implicit_application, )
            self._eq = {}
            for k, v in eq.items():
                try:
                    self._eq[k] = parse_expr(v, locs, transformations=tr)
                except (TypeError, ValueError) as e:  # pragma: no cover
                    raise RuntimeError(
                        "Unable to parse '{}'.".format(v)) from e
        else:
            self._eq = None
        if len(kwargs) != 0:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented.")
        self._init()

    def copy(self):
        """
        Returns a copy of the model.

        The containers listed in `_pickled_atts` (numpy arrays, lists,
        dictionaries) are duplicated, so that updating the copy
        (``inst['S'] = ...`` or in-place array edits) does not modify
        the original model. Only the containers are duplicated, their
        entries (tuples, sympy expressions and symbols) are shared.
        """
        inst = self.__class__.__new__(self.__class__)
        for k in BaseSIR._pickled_atts:
            value = getattr(self, k)
            if isinstance(value, (numpy.ndarray, list, dict)):
                value = value.copy()
            setattr(inst, k, value)
        if hasattr(inst, '_eq') and inst._eq is not None:
            inst._init_lambda_()
        return inst

    def __getstate__(self):
        '''
        Returns the pickled data (attributes listed in `_pickled_atts`).
        '''
        return {k: getattr(self, k) for k in BaseSIR._pickled_atts}

    def __setstate__(self, state):
        '''
        Sets the pickled data and lambdifies the equations again
        (lambdified functions are not part of the pickled state).
        '''
        for k, v in state.items():
            setattr(self, k, v)
        if hasattr(self, '_eq') and self._eq is not None:
            self._init_lambda_()

    def _init(self):
        """
        Starts from the initial values.

        Builds the value arrays `_val_p`, `_val_q` and `_val_c`. A missing
        initial value (None) becomes 10000 for a variable named `N`, and 0
        otherwise. Also computes `_val_len` (number of values) and
        `_val_ind` (boundaries of quantities, parameters and constants in
        the vector returned by `vect`).
        """
        def _def_(name, v):
            if v is not None:
                return v
            if name == 'N':  # pragma: no cover
                return 10000.
            return 0.  # pragma: no cover

        self._val_p = numpy.array(
            [_def_(v[0], v[1]) for v in self._p], dtype=numpy.float64)
        self._val_q = numpy.array(
            [_def_(v[0], v[1]) for v in self._q], dtype=numpy.float64)
        self._val_c = numpy.array(
            [_def_(v[0], v[1]) for v in self._c], dtype=numpy.float64)
        nq, np_, nc = len(self._val_q), len(self._val_p), len(self._val_c)
        self._val_len = np_ + nq + nc
        # layout of `vect`: quantities, parameters, constants, then t
        self._val_ind = numpy.array([0, nq, nq + np_, nq + np_ + nc])

        if hasattr(self, '_eq') and self._eq is not None:
            self._init_lambda_()

    def _init_lambda_(self):
        """
        Lambdifies every equation (see `_lambdify_`) and checks that the
        lambdified function and :epkg:`sympy` agree on the current values.
        Fills `_leq` (dictionary name -> function) and `_leqa` (the same
        functions in the order of the quantities).

        :raises ValueError: if both evaluations differ by a relative
            error greater than 1e-8
        """
        self._leq = {}
        for k, v in self._eq.items():
            fct = self._lambdify_(k, v)
            eval1 = float(self.evalf_eq(v))
            eval2 = self.evalf_leq(k)
            # Absolute deviation so that errors in both directions are
            # caught. A null denominator means both evaluations are zero
            # (e.g. a derivative equal to zero for the current values),
            # in which case they agree.
            denom = max(abs(eval1), abs(eval2))
            err = abs(eval2 - eval1) / denom if denom > 0 else 0.
            if err > 1e-8:
                raise ValueError(  # pragma: no cover
                    "Lambdification failed for function '{}': {} "
                    "({} ({}) != {} ({}), error={})".format(
                        k, v, eval1, type(eval1), eval2, type(eval2), err))
            self._leq[k] = fct
        self._leqa = [self._leq[_[0]] for _ in self._q]

    def get_index(self, name):
        '''
        Returns the kind and the position of a name.

        :param name: name of a parameter, a quantity or a constant
        :return: tuple `(kind, position)`, *kind* is `'p'` (parameter),
            `'q'` (quantity) or `'c'` (constant)
        :raises ValueError: if the name is unknown
        '''
        for kind, group in (('p', self._p), ('q', self._q), ('c', self._c)):
            for i, v in enumerate(group):
                if v[0] == name:
                    return kind, i
        raise ValueError("Unable to find name '{}'.".format(name))

    def __setitem__(self, name, value):
        """
        Updates a value whether it is a parameter, a quantity
        or a constant.

        :param name: name
        :param value: new value
        """
        kind, pos = self.get_index(name)
        if kind == 'p':
            self._val_p[pos] = value
        elif kind == 'q':
            self._val_q[pos] = value
        elif kind == 'c':
            self._val_c[pos] = value

    def __getitem__(self, name):
        """
        Retrieves a value whether it is a parameter, a quantity
        or a constant.

        :param name: name
        :return: value
        """
        kind, pos = self.get_index(name)
        if kind == 'p':
            return self._val_p[pos]
        if kind == 'q':
            return self._val_q[pos]
        if kind == 'c':
            return self._val_c[pos]

    @property
    def names(self):
        'Returns the sorted list of all names.'
        return sorted(v[0] for v in self._p + self._q + self._c)

    @property
    def quantity_names(self):
        'Returns the list of quantities names (unsorted).'
        return [v[0] for v in self._q]

    @property
    def param_names(self):
        'Returns the list of parameters names (unsorted).'
        return [v[0] for v in self._p]

    @property
    def params_dict(self):
        'Returns a dictionary `{parameter name: value}`.'
        return {k: self[k] for k in self.param_names}

    @property
    def cst_names(self):
        'Returns the list of constants names (unsorted).'
        return [v[0] for v in self._c]

    @property
    def vect_names(self):
        'Returns the names of the values in the order used by `vect`.'
        return ([v[0] for v in self._q] + [v[0] for v in self._p] +
                [v[0] for v in self._c] + ['t'])

    def vect(self, t=0, out=None, derivative=False):
        """
        Returns all values as a vector.

        :param t: time *t*
        :param out: alternative output array in which to place the
            result. It must have the same shape as the expected output.
        :param derivative: returns the derivatives instead of the values,
            the output then contains the values followed by one
            derivative per quantity
        :return: values or derivatives
        """
        if derivative:
            if out is None:
                out = numpy.empty((self._val_len + 1 + self._val_ind[1], ),
                                  dtype=numpy.float64)
            self.vect(t=t, out=out)
            # derivatives go into the last `_val_ind[1]` slots
            for i, v in enumerate(self._leqa):
                out[i - self._val_ind[1]] = v(*out[:self._val_len + 1])
        else:
            if out is None:
                out = numpy.empty((self._val_len + 1, ), dtype=numpy.float64)
            out[:self._val_ind[1]] = self._val_q
            out[self._val_ind[1]:self._val_ind[2]] = self._val_p
            out[self._val_ind[2]:self._val_ind[3]] = self._val_c
            out[self._val_ind[3]] = t
        return out

    @property
    def P(self):
        '''
        Returns the parameters as a list of `(name, value, comment)`.
        '''
        return [(a[0], b, a[2]) for a, b in zip(self._p, self._val_p)]

    @property
    def Q(self):
        '''
        Returns the quantities as a list of `(name, value, comment)`.
        '''
        return [(a[0], b, a[2]) for a, b in zip(self._q, self._val_q)]

    @property
    def C(self):
        '''
        Returns the constants as a list of `(name, value, comment)`.
        '''
        return [(a[0], b, a[2]) for a, b in zip(self._c, self._val_c)]

    def update(self, **values):
        """Updates values."""
        for k, v in values.items():
            self[k] = v

    def get(self):
        """Retrieves all values."""
        return {n: self[n] for n in self.names}

    def to_rst(self):
        '''
        Returns a string formatted in RST.
        '''
        rows = [
            '*{}*'.format(self.__class__.__name__),
            '',
            '*Quantities*',
            ''
        ]
        for name, _, doc in self._q:
            rows.append('* *{}*: {}'.format(name, doc))
        rows.extend(['', '*Constants*', ''])
        for name, _, doc in self._c:
            rows.append('* *{}*: {}'.format(name, doc))
        rows.extend(['', '*Parameters*', ''])
        for name, _, doc in self._p:
            rows.append('* *{}*: {}'.format(name, doc))
        if self._eq is not None:
            rows.extend(['', '*Equations*', '', '.. math::',
                         '', ' \\begin{array}{l}'])
            for i, (k, v) in enumerate(sorted(self._eq.items())):
                line = "".join(
                    [" ", "\\frac{d%s}{dt} = " % k, printing.latex(v)])
                # LaTeX line break between equations, not after the last
                if i < len(self._eq) - 1:
                    line += " \\\\"
                rows.append(line)
            rows.append(" \\end{array}")

        return '\n'.join(rows)

    def _repr_html_(self):
        '''
        Returns a string formatted in HTML.
        '''
        rows = [
            '<p><b>{}</b></p>'.format(self.__class__.__name__),
            '',
            '<p><i>Quantities</i></p>',
            '',
            '<ul>'
        ]
        for name, _, doc in self._q:
            rows.append('<li><i>{}</i>: {}</li>'.format(name, doc))
        rows.extend(['</ul>', '', '<p><i>Constants</i></p>', '', '<ul>'])
        for name, _, doc in self._c:
            rows.append('<li><i>{}</i>: {}</li>'.format(name, doc))
        rows.extend(['</ul>', '', '<p><i>Parameters</i></p>', '', '<ul>'])
        for name, _, doc in self._p:
            rows.append('<li><i>{}</i>: {}</li>'.format(name, doc))
        # Always close the parameters list, even without equations.
        rows.append('</ul>')
        if self._eq is not None:
            rows.extend(['', '<p><i>Equations</i></p>', '', '<ul>'])
            for k, v in sorted(self._eq.items()):
                lats = "\\frac{d%s}{dt} = %s" % (k, printing.latex(v))
                lat = latex(lats, mode='equation')
                line = "".join(["<li>", str(lat), '</li>'])
                rows.append(line)
            rows.append("</ul>")

        return '\n'.join(rows)

    def enumerate_edges(self):
        """
        Enumerates the list of quantities contributing
        to others. It ignores constants.

        Yields tuples `(sign, source quantity, target quantity,
        parameter name)`. The sign comes from `self.eqsign`, which is
        defined outside this class. When an equation has no edge coming
        from another quantity, `(0, '?', target quantity, '?')` is
        yielded. Nothing is yielded if the model has no equation.
        """
        if self._eq is not None:
            params = set(_[0] for _ in self.P)
            quants = set(_[0] for _ in self.Q)
            for k, v in sorted(self._eq.items()):
                n2 = k
                n = []
                for dobj in enumerate_traverse(v):
                    # presumably 'e' is the sub-expression and 'p' its
                    # parent, based on how both keys are used below
                    term = dobj['e']
                    if not hasattr(term, 'name'):
                        continue
                    if term.name not in params:
                        continue
                    parent = dobj['p']
                    others = [_['e'] for _ in enumerate_traverse(parent)]
                    for o in others:
                        if hasattr(o, 'name') and o.name in quants:
                            sign = self.eqsign(n2, o.name)
                            yield (sign, o.name, n2, term.name)
                            if o.name != n2:
                                n.append((sign, o.name, n2, term.name))
                if len(n) == 0:
                    yield (0, '?', n2, '?')

    def to_dot(self, verbose=False, full=False):
        """
        Produces a graph in :epkg:`DOT` format.

        :param verbose: adds comments to nodes and parameter values
            to edges
        :param full: keeps self-loops and edges with a negative sign
        :return: string
        """
        rows = ['digraph{']

        pattern = (' {name} [label="{name}\\n{doc}" shape=record];'
                   if verbose else
                   ' {name} [label="{name}"];')
        for name, _, doc in self._q:
            rows.append(pattern.format(name=name, doc=doc))
        for name, _, doc in self._c:
            rows.append(pattern.format(name=name, doc=doc))

        if self._eq is not None:
            pattern = (
                ' {n1} -> {n2} [label="{sg}{name}\\nvalue={v:1.2g}"];'
                if verbose else ' {n1} -> {n2} [label="{sg}{name}"];')
            # Sorting makes the output deterministic, iterating over
            # a set of strings depends on hash randomization.
            edges = sorted(set(self.enumerate_edges()),
                           key=lambda e: (e[1], e[2], e[3], e[0]))
            for sg, a, b, name in edges:
                if not full and (a == b or sg < 0):
                    continue
                if name == '?':
                    rows.append(  # pragma: no cover
                        pattern.format(n1=a, n2=b, name=name,
                                       v=numpy.nan, sg='0'))
                    continue  # pragma: no cover
                value = self[name]
                stsg = '' if sg > 0 else '-'
                rows.append(
                    pattern.format(n1=a, n2=b, name=name, v=value, sg=stsg))

        rows.append('}')
        return '\n'.join(rows)

    @property
    def cst_param(self):
        '''
        Returns a dictionary with the constants and the parameters
        (a parameter overrides a constant sharing the same name).
        '''
        res = {k[0]: v for k, v in zip(self._c, self._val_c)}
        res.update((k[0], v) for k, v in zip(self._p, self._val_p))
        return res

    def _sympy_substitutions(self, t=0):
        """
        Returns the dictionary `{symbol: value}` used to evaluate
        :epkg:`sympy` expressions with the current constants, parameters,
        quantities and time *t*.
        """
        svalues = self._eval_cache()
        svalues[self._syms['t']] = t
        for k, v in zip(self._q, self._val_q):
            svalues[self._syms[k[0]]] = v
        return svalues

    def evalf_eq(self, eq, t=0):
        """
        Evaluates an :epkg:`sympy` expression with the current values.

        :param eq: sympy expression
        :param t: t value
        :return: result of `eq.evalf`
        """
        return eq.evalf(subs=self._sympy_substitutions(t))

    def evalf_leq(self, name, t=0):
        """
        Evaluates a lambdified expression.

        :param name: name of the lambdified expresion
        :param t: t values
        :return: evaluation
        :raises RuntimeError: if the expression was not lambdified
        """
        leq = self._lambdified_(name)
        if leq is None:
            raise RuntimeError(  # pragma: no cover
                "Equation '{}' was not lambdified.".format(name))
        return leq(*self.vect(t))

    def _eval_cache(self):
        # substitutions {symbol: value} for constants and parameters
        values = self.cst_param
        svalues = {self._syms[k]: v for k, v in values.items()}
        return svalues

    def _lambdify_(self, name, eq, derivative=False):
        '''
        Lambdifies an expression and caches in member `_lambda_`.

        The function takes the arguments in the order of `vect`
        (quantities, parameters, constants, t), followed by one `d<name>`
        symbol per quantity if *derivative* is True. The cache key is
        *name* only, a second call with the same name returns the
        cached function.
        '''
        if not hasattr(self, '_lambda_'):
            self._lambda_ = {}
        if name not in self._lambda_:
            names = (self.quantity_names + self.param_names +
                     self.cst_names + ['t'])
            sym = [Symbol(n) for n in names]
            if derivative:
                sym += [Symbol('d' + n) for n in self.quantity_names]
            self._lambda_[name] = {
                'names': names,
                'symbols': sym,
                'eq': eq,
                'pos': {n: i for i, n in enumerate(names)},
            }
            ll = lambdify(sym, eq, 'numpy')
            self._lambda_[name]['la'] = ll
        return self._lambda_[name]['la']

    def _lambdified_(self, name):
        """
        Returns the lambdified expression of name *name*,
        None if it does not exist.
        """
        if hasattr(self, '_lambda_'):
            r = self._lambda_.get(name, None)
            if r is not None:
                return r['la']
        return None

    def _eval_diff_sympy(self, t=0):
        """
        Evaluates derivatives.
        Returns a dictionary.
        """
        svalues = self._sympy_substitutions(t)
        x = self.vect(t=t)
        res = {k: v.evalf(subs=svalues) for k, v in self._eq.items()}
        # NOTE(review): the loop below overwrites the sympy evaluations
        # for every key present in `_leq` (all of them once
        # `_init_lambda_` ran); kept as is to preserve the current output.
        for k, v in self._leq.items():
            res[k] = v(*x)
        return res

    def eval_diff(self, t=0):
        """
        Evaluates derivatives with the lambdified equations.
        Returns a dictionary `{quantity name: derivative}`.
        """
        x = self.vect(t=t)
        return {k: v(*x) for k, v in self._leq.items()}
237 return [(a[0], b, a[2]) for a, b in zip(self._p, self._val_p)] 

238 

239 @property 

240 def Q(self): 

241 ''' 

242 Returns the quantities 

243 ''' 

244 return [(a[0], b, a[2]) for a, b in zip(self._q, self._val_q)] 

245 

246 @property 

247 def C(self): 

248 ''' 

249 Returns the quantities 

250 ''' 

251 return [(a[0], b, a[2]) for a, b in zip(self._c, self._val_c)] 

252 

253 def update(self, **values): 

254 """Updates values.""" 

255 for k, v in values.items(): 

256 self[k] = v 

257 

258 def get(self): 

259 """Retrieves all values.""" 

260 return {n: self[n] for n in self.names} 

261 

262 def to_rst(self): 

263 ''' 

264 Returns a string formatted in RST. 

265 ''' 

266 rows = [ 

267 '*{}*'.format(self.__class__.__name__), 

268 '', 

269 '*Quantities*', 

270 '' 

271 ] 

272 for name, _, doc in self._q: 

273 rows.append('* *{}*: {}'.format(name, doc)) 

274 rows.extend(['', '*Constants*', '']) 

275 for name, _, doc in self._c: 

276 rows.append('* *{}*: {}'.format(name, doc)) 

277 rows.extend(['', '*Parameters*', '']) 

278 for name, _, doc in self._p: 

279 rows.append('* *{}*: {}'.format(name, doc)) 

280 if self._eq is not None: 

281 rows.extend(['', '*Equations*', '', '.. math::', 

282 '', ' \\begin{array}{l}']) 

283 for i, (k, v) in enumerate(sorted(self._eq.items())): 

284 line = "".join( 

285 [" ", "\\frac{d%s}{dt} = " % k, printing.latex(v)]) 

286 if i < len(self._eq) - 1: 

287 line += " \\\\" 

288 rows.append(line) 

289 rows.append(" \\end{array}") 

290 

291 return '\n'.join(rows) 

292 

293 def _repr_html_(self): 

294 ''' 

295 Returns a string formatted in RST. 

296 ''' 

297 rows = [ 

298 '<p><b>{}</b></p>'.format(self.__class__.__name__), 

299 '', 

300 '<p><i>Quantities</i></p>', 

301 '', 

302 '<ul>' 

303 ] 

304 for name, _, doc in self._q: 

305 rows.append('<li><i>{}</i>: {}</li>'.format(name, doc)) 

306 rows.extend(['</ul>', '', '<p><i>Constants</i></p>', '', '<ul>']) 

307 for name, _, doc in self._c: 

308 rows.append('<li><i>{}</i>: {}</li>'.format(name, doc)) 

309 rows.extend(['</ul>', '', '<p><i>Parameters</i></p>', '', '<ul>']) 

310 for name, _, doc in self._p: 

311 rows.append('<li><i>{}</i>: {}</li>'.format(name, doc)) 

312 if self._eq is not None: 

313 rows.extend(['</ul>', '', '<p><i>Equations</i></p>', '', '<ul>']) 

314 for i, (k, v) in enumerate(sorted(self._eq.items())): 

315 lats = "\\frac{d%s}{dt} = %s" % (k, printing.latex(v)) 

316 lat = latex(lats, mode='equation') 

317 line = "".join(["<li>", str(lat), '</li>']) 

318 rows.append(line) 

319 rows.append("</ul>") 

320 

321 return '\n'.join(rows) 

322 

323 def enumerate_edges(self): 

324 """ 

325 Enumerates the list of quantities contributing 

326 to others. It ignores constants. 

327 """ 

328 if self._eq is not None: 

329 params = set(_[0] for _ in self.P) 

330 quants = set(_[0] for _ in self.Q) 

331 for k, v in sorted(self._eq.items()): 

332 n2 = k 

333 n = [] 

334 for dobj in enumerate_traverse(v): 

335 term = dobj['e'] 

336 if not hasattr(term, 'name'): 

337 continue 

338 if term.name not in params: 

339 continue 

340 parent = dobj['p'] 

341 others = list( 

342 _['e'] for _ in enumerate_traverse(parent)) 

343 for o in others: 

344 if hasattr(o, 'name') and o.name in quants: 

345 sign = self.eqsign(n2, o.name) 

346 yield (sign, o.name, n2, term.name) 

347 if o.name != n2: 

348 n.append((sign, o.name, n2, term.name)) 

349 if len(n) == 0: 

350 yield (0, '?', n2, '?') 

351 

352 def to_dot(self, verbose=False, full=False): 

353 """ 

354 Produces a graph in :epkg:`DOT` format. 

355 """ 

356 rows = ['digraph{'] 

357 

358 pattern = (' {name} [label="{name}\\n{doc}" shape=record];' 

359 if verbose else 

360 ' {name} [label="{name}"];') 

361 for name, _, doc in self._q: 

362 rows.append(pattern.format(name=name, doc=doc)) 

363 for name, _, doc in self._c: 

364 rows.append(pattern.format(name=name, doc=doc)) 

365 

366 if self._eq is not None: 

367 pattern = ( 

368 ' {n1} -> {n2} [label="{sg}{name}\\nvalue={v:1.2g}"];' 

369 if verbose else ' {n1} -> {n2} [label="{sg}{name}"];') 

370 for sg, a, b, name in set(self.enumerate_edges()): 

371 if not full and (a == b or sg < 0): 

372 continue 

373 if name == '?': 

374 rows.append( # pragma: no cover 

375 pattern.format(n1=a, n2=b, name=name, 

376 v=numpy.nan, sg='0')) 

377 continue # pragma: no cover 

378 value = self[name] 

379 stsg = '' if sg > 0 else '-' 

380 rows.append( 

381 pattern.format(n1=a, n2=b, name=name, v=value, sg=stsg)) 

382 

383 rows.append('}') 

384 return '\n'.join(rows) 

385 

386 @property 

387 def cst_param(self): 

388 ''' 

389 Returns a dictionary with the constant and the parameters. 

390 ''' 

391 res = {} 

392 for k, v in zip(self._c, self._val_c): 

393 res[k[0]] = v 

394 for k, v in zip(self._p, self._val_p): 

395 res[k[0]] = v 

396 return res 

397 

398 def evalf_eq(self, eq, t=0): 

399 """ 

400 Evaluates an :epkg:`sympy` expression. 

401 """ 

402 svalues = self._eval_cache() 

403 svalues[self._syms['t']] = t 

404 for k, v in zip(self._q, self._val_q): 

405 svalues[self._syms[k[0]]] = v 

406 return eq.evalf(subs=svalues) 

407 

408 def evalf_leq(self, name, t=0): 

409 """ 

410 Evaluates a lambdified expression. 

411 

412 :param name: name of the lambdified expresion 

413 :param t: t values 

414 :return: evaluation 

415 """ 

416 leq = self._lambdified_(name) 

417 if leq is None: 

418 raise RuntimeError( # pragma: no cover 

419 "Equation '{}' was not lambdified.".format(name)) 

420 return leq(*self.vect(t)) 

421 

422 def _eval_cache(self): 

423 values = self.cst_param 

424 svalues = {self._syms[k]: v for k, v in values.items()} 

425 return svalues 

426 

427 def _lambdify_(self, name, eq, derivative=False): 

428 'Lambdifies an expression and caches in member `_lambda_`.' 

429 if not hasattr(self, '_lambda_'): 

430 self._lambda_ = {} 

431 if name not in self._lambda_: 

432 names = (self.quantity_names + self.param_names + 

433 self.cst_names + ['t']) 

434 sym = [Symbol(n) for n in names] 

435 if derivative: 

436 sym += [Symbol('d' + n) for n in self.quantity_names] 

437 self._lambda_[name] = { 

438 'names': names, 

439 'symbols': sym, 

440 'eq': eq, 

441 'pos': {n: i for i, n in enumerate(names)}, 

442 } 

443 ll = lambdify(sym, eq, 'numpy') 

444 self._lambda_[name]['la'] = ll 

445 return self._lambda_[name]['la'] 

446 

447 def _lambdified_(self, name): 

448 """ 

449 Returns the lambdified expression of name *name*. 

450 """ 

451 if hasattr(self, '_lambda_'): 

452 r = self._lambda_.get(name, None) 

453 if r is not None: 

454 return r['la'] 

455 return None 

456 

457 def _eval_diff_sympy(self, t=0): 

458 """ 

459 Evaluates derivatives. 

460 Returns a dictionary. 

461 """ 

462 svalues = self._eval_cache() 

463 svalues[self._syms['t']] = t 

464 for k, v in zip(self._q, self._val_q): 

465 svalues[self._syms[k[0]]] = v 

466 

467 x = self.vect(t=t) 

468 res = {} 

469 for k, v in self._eq.items(): 

470 res[k] = v.evalf(subs=svalues) 

471 for k, v in self._leq.items(): 

472 res[k] = v(*x) 

473 return res 

474 

475 def eval_diff(self, t=0): 

476 """ 

477 Evaluates derivatives. 

478 Returns a dictionary. 

479 """ 

480 x = self.vect(t=t) 

481 res = {} 

482 for k, v in self._leq.items(): 

483 res[k] = v(*x) 

484 return res