Coverage for pyquickhelper/pycode/unittestclass.py: 96%

367 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-03 02:21 +0200

1""" 

2@file 

3@brief Overwrites unit test class with additional testing functions. 

4""" 

5from io import StringIO 

6import os 

7import sys 

8import logging 

9import unittest 

10import warnings 

11import decimal 

12import pprint 

13from logging import getLogger, INFO, StreamHandler 

14from contextlib import redirect_stdout, redirect_stderr 

15from .ci_helper import is_travis_or_appveyor 

16from .profiling import profile 

17from ..texthelper import compare_module_version 

18 

19 

class ExtTestCase(unittest.TestCase):
    """
    Overwrites unit test class with additional testing functions.
    Unless *setUpClass* is overwritten, warnings *FutureWarning*,
    *PendingDeprecationWarning*, *ImportWarning* and *DeprecationWarning*
    are filtered out and logger ``sphinx.util`` only reports errors.
    """
    @classmethod
    def setUpClass(cls):
        """
        Filters out *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning*. Sets the level of logger
        ``sphinx.util`` to ERROR and disables its propagation; the previous
        level and propagation flag are saved for :meth:`tearDownClass`.
        """
        warnings.simplefilter("ignore",
                              (FutureWarning,
                               PendingDeprecationWarning,
                               ImportWarning,
                               DeprecationWarning))
        logger = logging.getLogger('sphinx.util')
        # logger.level (not getEffectiveLevel()) is saved: restoring the
        # effective level would pin an explicit level on a logger which
        # was previously inheriting it from its parent (NOTSET).
        cls._log_info = (logger.level, logger.propagate)
        logger.setLevel(logging.ERROR)
        logger.propagate = False

    @classmethod
    def tearDownClass(cls):
        """
        Switches *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning* back to the ``default``
        action and restores the level and propagation flag of logger
        ``sphinx.util`` saved by :meth:`setUpClass`.

        :raises AssertionError: if :meth:`setUpClass` was not called
        """
        warnings.simplefilter("default",
                              (FutureWarning,
                               PendingDeprecationWarning,
                               ImportWarning,
                               DeprecationWarning))
        if hasattr(cls, '_log_info'):
            logger = logging.getLogger('sphinx.util')
            logger.setLevel(cls._log_info[0])
            logger.propagate = cls._log_info[1]
        else:
            raise AssertionError(  # pragma: no cover
                "ExtTestCase must be called.")

    @staticmethod
    def _format_str(s):
        """
        Returns ``s`` or ``'s'`` depending on the type
        (objects with a ``replace`` method, such as strings, are quoted).
        """
        if hasattr(s, "replace"):
            return f"'{s}'"
        return s

    def assertNotEmpty(self, x):
        """
        Checks that *x* is not empty.

        :raises AssertionError: if *x* is None or has a null length
        """
        if x is None or (hasattr(x, "__len__") and len(x) == 0):
            raise AssertionError("x is empty")

    def assertEmpty(self, x, none_allowed=True):
        """
        Checks that *x* is empty.

        :param x: object to check
        :param none_allowed: if True, *None* is considered as empty
        :raises AssertionError: if *x* is not empty; for lists, tuples,
            dictionaries and sets the message shows the first five elements
        """
        if (none_allowed and x is None) or (hasattr(x, "__len__") and len(x) == 0):
            return
        if isinstance(x, (list, tuple, dict, set)):
            # dict and set do not support slicing: list() materializes
            # the elements (the keys for a dictionary) in iteration order.
            disp = "\n" + '\n'.join(map(str, list(x)[:5]))
        else:
            disp = ""
        raise AssertionError(f"x is not empty{disp}")

    def assertGreater(self, x, y, strict=False):  # pylint: disable=W0221,W0237
        """
        Checks that ``x >= y`` (``x > y`` if *strict* is True).
        """
        if x < y or (strict and x == y):
            raise AssertionError("x <{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "" if strict else "="))

    def assertLesser(self, x, y, strict=False):
        """
        Checks that ``x <= y`` (``x < y`` if *strict* is True).
        """
        if x > y or (strict and x == y):
            raise AssertionError("x >{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "" if strict else "="))

    def assertExists(self, name):
        """
        Checks that *name* exists.

        :raises FileNotFoundError: if the path does not exist
        """
        if not os.path.exists(name):
            raise FileNotFoundError(f"Unable to find '{name}'.")

    def assertNotExists(self, name):
        """
        Checks that *name* does not exist.

        :raises FileNotFoundError: if the path exists
        """
        if os.path.exists(name):
            raise FileNotFoundError(  # pragma: no cover
                f"Able to find '{name}'.")

    def assertEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are equal.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        assert_frame_equal(d1, d2, **kwargs)

    def assertNotEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are different.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        try:
            assert_frame_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two dataframes are identical.")

    def assertEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are equal.
        Both arrays are squeezed with :func:`numpy.squeeze` before the
        comparison. Relies on :epkg:`numpy:testing:assert_almost_equal`
        if *decimal* is in *kwargs*, on :func:`numpy.testing.assert_allclose`
        otherwise.

        :param d1: first array (or None)
        :param d2: second array (or None)
        :param squeeze: NOTE(review): ignored, both arrays are always
            squeezed; existing callers may rely on that (e.g. comparing
            shapes ``(n, 1)`` and ``(n,)``)
        :param kwargs: forwarded to the numpy comparison function
        """
        if d1 is None and d2 is None:
            return
        if d1 is None:
            raise AssertionError("d1 is None, d2 is not")
        if d2 is None:
            raise AssertionError("d1 is not None, d2 is")
        import numpy
        from numpy.testing import assert_almost_equal, assert_allclose
        d1 = numpy.squeeze(d1)
        d2 = numpy.squeeze(d2)
        if 'decimal' in kwargs:
            assert_almost_equal(d1, d2, **kwargs)
        else:
            assert_allclose(d1, d2, **kwargs)

    def assertHasNoNan(self, a):  # pylint: disable=W0221
        """
        Checks that there is no NaN in ``a``
        (``a`` must have a ``ravel`` method, such as a numpy array).
        """
        if a is None:
            raise AssertionError("a is None")
        import numpy
        if any(map(numpy.isnan, a.ravel())):
            raise AssertionError(f"a has nan:\n{a}")

    def assertEqualSparseArray(self, d1, d2, **kwargs):
        """
        Checks that two sparse matrices are equal. Both must have the same
        type and expose either ``data``, ``row``, ``col`` (COO format) or
        ``data``, ``indices``, ``indptr`` (CSR/CSC formats).
        *kwargs* is currently unused.
        """
        if type(d1) != type(d2):  # pylint: disable=C0123
            raise AssertionError("d1 and d2 have difference types {} != {}.".format(
                type(d1), type(d2)))
        if d1 is None and d2 is None:
            return
        if (hasattr(d1, 'data') and hasattr(d1, 'row') and hasattr(d1, 'col') and
                hasattr(d2, 'data') and hasattr(d2, 'row') and hasattr(d2, 'col')):
            # coo_matrix
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data)
            self.assertEqualArray(d1.row, d2.row)
            self.assertEqualArray(d1.col, d2.col)
            return
        if (hasattr(d1, 'data') and hasattr(d1, 'indices') and hasattr(d1, 'indptr') and
                hasattr(d2, 'data') and hasattr(d2, 'indices') and hasattr(d2, 'indptr')):
            # csr_matrix or csc_matrix
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data)
            self.assertEqualArray(d1.indices, d2.indices)
            self.assertEqualArray(d1.indptr, d2.indptr)
            return
        raise NotImplementedError(  # pragma: no cover
            f"Comparison not implemented for types {type(d1)} and {type(d2)}.")

    def assertNotEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are different.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param squeeze: squeezes both arrays before comparing them
        :raises AssertionError: if both are None or if both arrays are
            almost equal
        """
        if d1 is None and d2 is None:
            raise AssertionError("d1 and d2 are equal to None")
        if d1 is None or d2 is None:
            return
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        try:
            assert_almost_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two arrays are identical.")

    def assertEqualNumber(self, d1, d2, **kwargs):
        """
        Checks that two numbers are equal. If *precision* is given in
        *kwargs*, the relative difference must not exceed it (the absolute
        difference when one of the numbers is null), otherwise both numbers
        must be exactly equal.

        :raises TypeError: if one argument is not a number
        """
        from numpy import number
        if not isinstance(d1, (int, float, decimal.Decimal, number)):
            raise TypeError(f'd1 is not a number but {type(d1)}')
        if not isinstance(d2, (int, float, decimal.Decimal, number)):
            raise TypeError(f'd2 is not a number but {type(d2)}')
        diff = abs(float(d1 - d2))
        mi = float(min(abs(d1), abs(d2)))
        tol = kwargs.get('precision', None)
        if tol is None:
            if diff != 0:
                raise AssertionError(f"d1 != d2: {d1} != {d2}")
        else:
            if mi == 0:
                if diff > tol:  # pragma: no cover
                    raise AssertionError(
                        f"d1 != d2: {d1} != {d2} +/- {tol}")
            else:
                rel = diff / mi
                if rel > tol:
                    raise AssertionError(  # pragma: no cover
                        f"d1 != d2: {d1} != {d2} +/- {tol}")

    def assertRaise(self, fct, exc=None, msg=None):
        """
        Checks that function *fct* with no parameter
        raises an exception of a given type.

        @param fct function to test (no parameter)
        @param exc exception type to catch (None for all)
        @param msg error message to check (None for no message to check)
        """
        try:
            fct()
        except Exception as e:
            if exc is None:
                return  # pragma: no cover
            elif isinstance(e, exc):
                if msg is None:
                    return
                if msg not in str(e):
                    raise AssertionError(  # pragma: no cover
                        "Function '{0}' raise exception with wrong message '{1}' "
                        "(must contain '{2}').".format(fct, e, msg))
                return
            raise AssertionError(
                "Function '{0}' does not raise exception '{1}' but '{2}' of type "
                "'{3}'.".format(fct, exc, e, type(e)))
        raise AssertionError(  # pragma: no cover
            f"Function '{fct}' does not raise exception.")

    def capture(self, fct):
        """
        Runs a function and capture standard output and error.

        @param fct function to run
        @return result of *fct*, output, error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout):
            with redirect_stderr(serr):
                res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def assertStartsWith(self, sub, whole):
        """
        Checks that string *whole* starts with *sub*.
        """
        if not whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                f"'{whole}' does not start with '{sub}'")

    def assertNotStartsWith(self, sub, whole):
        """
        Checks that string *whole* does not start with *sub*.
        """
        if whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                f"'{whole}' starts with '{sub}'")

    def assertEndsWith(self, sub, whole):
        """
        Checks that string *whole* ends with *sub*.
        """
        if not whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]  # pragma: no cover
            raise AssertionError(
                f"'{whole}' does not end with '{sub}'")

    def assertNotEndsWith(self, sub, whole):
        """
        Checks that string *whole* does not end with *sub*.
        """
        if whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]
            raise AssertionError(
                f"'{whole}' ends with '{sub}'")

    def assertEqual(self, a, b):  # pylint: disable=W0221
        """
        Checks that ``a == b``. Falls back to :meth:`assertEqualDataFrame`
        or :meth:`assertEqualArray` when the comparison is ambiguous for
        dataframes or numpy arrays.
        """
        if a is None and b is not None:
            raise AssertionError("a is None, b is not")
        if a is not None and b is None:
            raise AssertionError("a is not None, b is")
        try:
            unittest.TestCase.assertEqual(self, a, b)
        except ValueError as e:
            if "The truth value of a DataFrame is ambiguous" in str(e) or \
                    "The truth value of an array with more than one element is ambiguous." in str(e):
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=ImportWarning)
                    import pandas
                if isinstance(a, pandas.DataFrame) and isinstance(b, pandas.DataFrame):
                    self.assertEqualDataFrame(a, b)
                    return
                import numpy
                if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
                    self.assertEqualArray(a, b)
                    return
            raise AssertionError(  # pragma: no cover
                f"Unable to check equality for types {type(a)} and {type(b)}") from e

    def assertNotEqual(self, a, b):  # pylint: disable=W0221
        """
        Checks that ``a != b``. Falls back to :meth:`assertNotEqualDataFrame`
        or :meth:`assertNotEqualArray` when the comparison is ambiguous for
        dataframes or numpy arrays.
        """
        if a is None and b is None:
            raise AssertionError("a is None, b is too")  # pragma: no cover
        if a is None and b is not None:
            return  # pragma: no cover
        if a is not None and b is None:
            return  # pragma: no cover
        try:
            unittest.TestCase.assertNotEqual(self, a, b)
        except ValueError as e:
            se = str(e)
            if ("Can only compare identically-labeled" in se or
                    "The truth value of a DataFrame is ambiguous." in se or
                    ("The truth value of an array with more "
                     "than one element is ambiguous.") in se):
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=ImportWarning)
                    import pandas
                if isinstance(a, pandas.DataFrame) and isinstance(b, pandas.DataFrame):
                    self.assertNotEqualDataFrame(a, b)
                    return
                import numpy
                if isinstance(a, numpy.ndarray) and isinstance(b, numpy.ndarray):
                    self.assertNotEqualArray(a, b)
                    return
            raise  # pragma: no cover

    def assertEqualFloat(self, a, b, precision=1e-5):
        """
        Checks that ``abs(a-b) / min(abs(a), abs(b)) <= precision``,
        or ``abs(a-b) <= precision`` when ``min(abs(a), abs(b)) == 0``.

        :raises AssertionError: if the difference exceeds *precision*
        """
        mi = min(abs(a), abs(b))
        if mi == 0:
            err = abs(a - b)
        else:
            err = float(abs(a - b)) / mi
        if err > precision:
            raise AssertionError(f"{a} != {b} (p={precision})")

    def assertCallable(self, fct):
        """
        Checks that *fct* is callable.
        """
        if not callable(fct):
            raise AssertionError(f"fct is not callable: {type(fct)}")

    def assertEqualDict(self, a, b):
        """
        Checks that ``a == b`` for two dictionaries. The error message
        lists added keys, removed keys and keys with different values.

        :raises TypeError: if *a* or *b* is not a dictionary
        """
        if not isinstance(a, dict):
            raise TypeError(f'a is not dict but {type(a)}')
        if not isinstance(b, dict):
            raise TypeError(f'b is not dict but {type(b)}')
        rows = []
        for key in sorted(b):
            if key not in a:
                rows.append(f"** Added key '{key}' in b")
            else:
                if a[key] != b[key]:
                    rows.append(
                        "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                            key, id(a[key]), id(b[key]), a[key], b[key]))
        for key in sorted(a):
            if key not in b:
                rows.append(f"** Removed key '{key}' in a")
        if len(rows) > 0:
            raise AssertionError(
                "Dictionaries are different\n{0}".format('\n'.join(rows)))

    def fLOG(self, *args, **kwargs):
        """
        Prints out some information.
        @see fn fLOG.
        """
        # delayed import
        from ..loghelper import fLOG as _flog  # pragma: no cover
        _flog(*args, **kwargs)  # pragma: no cover

    @staticmethod
    def profile(fct, sort='cumulative', rootrem=None,
                return_results=False):
        """
        Profiles the execution of a function with function
        :func:`profile <pyquickhelper.pycode.profiling.profile>`.

        :param fct: function to profile
        :param sort: see :meth:`pstats.Stats.sort_stats`
        :param rootrem: root to remove in filenames
        :param return_results: return the results as well
        :return: statistics text dump
        """
        return profile(fct, sort=sort, rootrem=rootrem,
                       return_results=return_results)

    def read_file(self, filename, mode='r', encoding="utf-8"):
        """
        Returns the content of a file.

        @param filename filename
        @param encoding encoding, ignored in binary mode
        @param mode reading mode (a mode containing ``'b'`` returns bytes)
        @return content
        """
        self.assertExists(filename)
        # open() raises ValueError if an encoding is given in binary mode
        if 'b' in mode:
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.read()

    def write_file(self, filename, content, mode='w', encoding='utf-8'):
        """
        Writes the content of a file.

        @param filename filename
        @param content content to write (bytes in binary mode)
        @param encoding encoding, ignored in binary mode
        @param mode writing mode
        @return number of characters (or bytes) written
        """
        # open() raises ValueError if an encoding is given in binary mode
        if 'b' in mode:
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.write(content)

    def assertIn(self, sub, ensemble, msg=None):  # pylint: disable=W0221,W0237
        """
        Checks that *sub* is in *ensemble*.

        @param sub sub set (nothing is checked if None)
        @param ensemble full set
        @param msg error message
        @raises AssertionError
        """
        if sub is None:
            return  # pragma: no cover
        if ensemble is None:
            raise AssertionError(msg or "'text' is None")  # pragma: no cover
        if sub not in ensemble:
            raise AssertionError(  # pragma: no cover
                msg or f"Unable to find '{sub}' in\n{pprint.pformat(ensemble)}")

    def assertWarning(self, fct):
        """
        Returns the list of warnings raised while
        executing function *fct*.

        @param fct function to run
        @return result, list of warnings
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            r = fct()
        return r, list(w)

    def assertLogging(self, fct, logger_name, level=INFO, log_sphinx=False,
                      console=False):
        """
        Returns the logged information in a logger defined
        by its name. The existing handlers of the logger are detached and
        its propagation disabled while *fct* runs; both are restored
        afterwards, even if *fct* raises an exception.

        :param fct: function to run
        :param logger_name: logger name
        :param level: level to intercept (also set on the logger)
        :param log_sphinx: logging from :epkg:`sphinx`
        :param console: shows the log on console
        :return: tuple(result, logged information)
        """
        if log_sphinx:
            from sphinx.util import logging as logging_sphinx

        class MyStream:
            "Stream storing every string written into it."

            def __init__(self):
                self.rows = []

            def write(self, text):
                self.rows.append(text)

            def getvalue(self):
                return "\n".join(self.rows)

            def __len__(self):
                return len(self.rows)

        logger = (logging_sphinx.getLogger(logger_name).logger
                  if log_sphinx else getLogger(logger_name))

        hs = list(logger.handlers)
        for h in hs:
            logger.removeHandler(h)  # pragma: no cover

        log_capture_string = MyStream()
        ch = StreamHandler(log_capture_string)
        ch.setLevel(level)
        logger.addHandler(ch)
        logger.setLevel(level)

        chc = None
        if console:
            chc = StreamHandler()
            chc.setLevel(level)
            logger.addHandler(chc)

        prop = logger.propagate
        try:
            if not logger.hasHandlers():
                raise AssertionError(  # pragma: no cover
                    f"Logger {logger_name!r} has no handlers.")
            logger.propagate = False
            res = fct()
        finally:
            # restores the logger state whether or not fct succeeded
            logger.propagate = prop
            logger.removeHandler(ch)
            if chc is not None:
                logger.removeHandler(chc)
            for h in hs:
                logger.addHandler(h)  # pragma: no cover

        logs = log_capture_string.getvalue()
        return res, logs

    @staticmethod
    def abs_path_join(filename, *args):
        """
        Returns an absolute and normalized path from this location.

        :param filename: filename, the folder which contains it
            is used as the base
        :param args: list of subpaths to the previous path
        :return: absolute and normalized path
        """
        dirname = os.path.join(os.path.dirname(filename), *args)
        return os.path.normpath(os.path.abspath(dirname))

585 

586 

def skipif_appveyor(msg):
    """
    Returns a decorator which skips a unit test when it runs on
    :epkg:`appveyor` and leaves it unchanged anywhere else.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if is_travis_or_appveyor() == 'appveyor':
        reason = 'Test does not work on appveyor due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x

595 

596 

def skipif_travis(msg):
    """
    Returns a decorator which skips a unit test when it runs on
    :epkg:`travis` and leaves it unchanged anywhere else.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    running_on_travis = is_travis_or_appveyor() == 'travis'
    if not running_on_travis:
        return lambda x: x
    return unittest.skip(  # pragma: no cover
        'Test does not work on travis due to: ' + msg)

605 

606 

def skipif_circleci(msg):
    """
    Returns a decorator which skips a unit test when it runs on
    :epkg:`circleci` and leaves it unchanged anywhere else.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if is_travis_or_appveyor() == 'circleci':
        return unittest.skip(  # pragma: no cover
            'Test does not work on circleci due to: ' + msg)
    return lambda x: x

615 

616 

def skipif_azure(msg):
    """
    Returns a decorator which skips a unit test when it runs on
    :epkg:`azure pipeline` and leaves it unchanged anywhere else.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    running_on_azure = is_travis_or_appveyor() == 'azurepipe'
    if running_on_azure:
        reason = 'Test does not work on azure pipeline due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x  # pragma: no cover

625 

626 

def skipif_azure_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`linux`.
    The test is skipped only when both conditions hold; it runs normally
    on a linux machine outside azure and on azure with another platform.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if sys.platform.startswith('lin') and is_travis_or_appveyor() == 'azurepipe':
        msg = 'Test does not work on azure pipeline (linux) due to: ' + msg
        return unittest.skip(msg)
    return lambda x: x  # pragma: no cover

635 

636 

def skipif_azure_macosx(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS
    (``sys.platform`` starting with ``'darwin'``). The test is skipped
    only when both conditions hold.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if sys.platform.startswith('darwin') and is_travis_or_appveyor() == 'azurepipe':
        msg = 'Test does not work on azure pipeline (macosx) due to: ' + msg
        return unittest.skip(msg)
    return lambda x: x

645 

646 

def skipif_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`linux`
    (``sys.platform`` starting with ``'lin'``).

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if not sys.platform.startswith('lin'):
        return lambda x: x
    msg = 'Test does not work on linux due to: ' + msg  # pragma: no cover
    return unittest.skip(msg)  # pragma: no cover

655 

656 

def skipif_vless(version, msg):
    """
    Returns a decorator skipping a unit test when the running Python
    version ``sys.version_info[:3]`` is strictly lower than *version*.

    :param version: minimal version as a tuple, e.g. ``(3, 8, 0)``
    :param msg: reason shown when the test is skipped
    :return: decorator
    """
    current = sys.version_info[:3]
    if current < version:
        reason = f'Python {current} < {version}: {msg}'  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x

666 

667 

def unittest_require_at_least(mod, version, msg=""):
    """
    Returns a decorator skipping a unit test when the version of a module
    is older than the requested one, the identity otherwise.

    @param mod module (the module must have an attribute ``__version__``)
    @param version expected version or more recent
    @param msg message appended to the skip reason
    @return decorator
    @raises RuntimeError if *mod* has no attribute ``__version__``
    """
    current = getattr(mod, '__version__', None)
    if current is None:
        raise RuntimeError(  # pragma: no cover
            f"Module '{mod}' has no version.")
    if compare_module_version(current, version) < 0:
        return unittest.skip(
            f"Module '{mod}' is older than '{version}' (= '{current}'). {msg}")
    return lambda x: x

685 

686 

def ignore_warnings(warns):
    """
    Catches warnings: the decorated test method runs with warnings of
    category *warns* ignored, the warning filters are restored afterwards.

    @param warns warnings to ignore (a category or a tuple of categories)
    @return decorator
    @raises AssertionError at decoration time if *warns* is None
    """
    import functools

    def wrapper(fct):
        if warns is None:
            raise AssertionError(  # pragma: no cover
                f"warns cannot be None for '{fct}'.")

        # functools.wraps keeps the name and docstring of the test method
        # so that test runners still report them
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", warns)
                return fct(self, *args, **kwargs)
        return call_f
    return wrapper

704 

705 

def testlog(logtype="print"):
    """
    Logs before and after a function is called.

    :param logtype: kind of logging, only `'print'` is implemented
        and None to disable it
    :return: decorator for a test method taking only *self*; the
        decorated method keeps its name and docstring and returns
        the result of the original method
    :raises ValueError: if *logtype* is not supported
    """
    import functools

    if logtype is None:
        def nothing(arg):
            pass

        logfct = nothing
    elif logtype == 'print':
        logfct = print
    else:
        raise ValueError(f"Unexpected logtype {logtype!r}.")

    def wrapper(fct):
        @functools.wraps(fct)
        def call_f(self):
            logfct(f'START {fct.__name__!r}')
            res = fct(self)
            logfct(f'DONE- {fct.__name__!r}')
            return res
        return call_f
    return wrapper

730 

731 

def assert_almost_equal_detailed(expected, value, **kwargs):
    """
    Calls :epkg:`numpy:testing:assert_almost_equal`.
    Adds more information to the exception message: when both inputs
    have at least one dimension and the same number of rows, every row
    is compared separately and the first failing rows are reported.

    :param expected: expected value (array or anything numpy can convert)
    :param value: value
    :param kwargs: additional parameters for
        :epkg:`numpy:testing:assert_almost_equal` (such as *decimal*)
    :raises: AssertionError
    """
    import numpy
    from numpy.testing import assert_almost_equal
    try:
        assert_almost_equal(expected, value, **kwargs)
    except AssertionError as e:
        a_expected = numpy.asarray(expected)
        a_value = numpy.asarray(value)
        # A row-by-row report needs at least one dimension and the same
        # number of rows; otherwise the original error is re-raised.
        if (a_expected.ndim == 0 or a_value.ndim == 0 or
                a_expected.shape[0] != a_value.shape[0]):
            raise  # pragma: no cover
        rows = ['INNER EXCEPTION:', str(e), '------', 'ROWS BY ROWS']
        for i, (r1, r2) in enumerate(zip(a_expected, a_value)):
            try:
                assert_almost_equal(r1, r2, **kwargs)
            except AssertionError as ee:
                rows.append('----------------------')
                rows.append(
                    f"ISSUE WITH ROW {i}/{a_expected.shape[0]}:0 {str(ee)}")
                if len(rows) > 10:
                    break  # pragma: no cover
        raise AssertionError("\n".join(rows)) from e