Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Overwrites unit test class with additional testing functions. 

4""" 

5from io import StringIO 

6import os 

7import sys 

8import unittest 

9import warnings 

10import decimal 

11import pprint 

12from logging import getLogger, INFO, StreamHandler 

13from contextlib import redirect_stdout, redirect_stderr 

14from .ci_helper import is_travis_or_appveyor 

15from .profiling import profile 

16from ..texthelper import compare_module_version 

17 

18 

class ExtTestCase(unittest.TestCase):
    """
    Overwrites unit test class with additional testing functions.
    Unless *setUp* is overwritten, warnings *FutureWarning*,
    *PendingDeprecationWarning*, *ImportWarning* and
    *DeprecationWarning* are filtered out.
    """

    # Warning categories silenced by setUp and set back to "default"
    # by tearDown.
    _filtered_warnings = (FutureWarning, PendingDeprecationWarning,
                          ImportWarning, DeprecationWarning)

    def setUp(self):
        """
        Filters out *FutureWarning*, *PendingDeprecationWarning*,
        *ImportWarning* and *DeprecationWarning*.
        """
        # warnings.simplefilter documents *category* as a single class,
        # so one filter is registered per category.
        for category in self._filtered_warnings:
            warnings.simplefilter("ignore", category)

    def tearDown(self):
        """
        Stops filtering out the warnings silenced by :meth:`setUp`:
        the action ``"default"`` is set for each of them.
        """
        for category in self._filtered_warnings:
            warnings.simplefilter("default", category)

    @staticmethod
    def _format_str(s):
        """
        Returns ``s`` or ``'s'`` depending on the type.
        """
        if hasattr(s, "replace"):
            return "'{0}'".format(s)
        return s

    @staticmethod
    def _both_instances(a, b, module_name, class_name):
        """
        Tells if *a* and *b* are both instances of
        ``module_name.class_name``. Returns False if the module
        cannot be imported.
        """
        import importlib
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=ImportWarning)
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                return False
        cls = getattr(module, class_name)
        return isinstance(a, cls) and isinstance(b, cls)

    def assertNotEmpty(self, x):
        """
        Checks that *x* is not empty.

        :param x: object to check, ``None`` is considered empty
        :raises AssertionError: if *x* is None or ``len(x) == 0``
        """
        if x is None or (hasattr(x, "__len__") and len(x) == 0):
            raise AssertionError("x is empty")

    def assertEmpty(self, x, none_allowed=True):
        """
        Checks that *x* is empty.

        :param x: object to check, empty means ``len(x) == 0``
        :param none_allowed: if True, ``None`` is considered empty
        :raises AssertionError: if *x* is not empty; for lists, tuples,
            dictionaries and sets, the first five elements (keys for
            a dictionary) are displayed in the message
        """
        if (none_allowed and x is None) or (hasattr(x, "__len__") and len(x) == 0):
            return
        if isinstance(x, (list, tuple, dict, set)):
            # list() makes dictionaries and sets sliceable.
            disp = "\n" + '\n'.join(map(str, list(x)[:5]))
        else:
            disp = ""
        raise AssertionError("x is not empty{0}".format(disp))

    def assertGreater(self, x, y, strict=False):  # pylint: disable=W0221,W0237
        """
        Checks that ``x >= y`` (``x > y`` if *strict* is True).
        """
        if x < y or (strict and x == y):
            # The message states the relation which actually holds:
            # ``x < y`` for the non strict check, ``x <= y`` for the strict one.
            raise AssertionError("x <{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertLesser(self, x, y, strict=False):
        """
        Checks that ``x <= y`` (``x < y`` if *strict* is True).
        """
        if x > y or (strict and x == y):
            # The message states the relation which actually holds:
            # ``x > y`` for the non strict check, ``x >= y`` for the strict one.
            raise AssertionError("x >{2} y with x={0} and y={1}".format(
                ExtTestCase._format_str(x), ExtTestCase._format_str(y),
                "=" if strict else ""))

    def assertExists(self, name):
        """
        Checks that *name* exists.

        :raises FileNotFoundError: if *name* does not exist
        """
        if not os.path.exists(name):
            raise FileNotFoundError("Unable to find '{0}'.".format(name))

    def assertNotExists(self, name):
        """
        Checks that *name* does not exist.

        :raises FileNotFoundError: if *name* exists (the exception type
            is kept as is for existing callers)
        """
        if os.path.exists(name):
            raise FileNotFoundError(  # pragma: no cover
                "Able to find '{0}'.".format(name))

    def assertEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are equal.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        assert_frame_equal(d1, d2, **kwargs)

    def assertNotEqualDataFrame(self, d1, d2, **kwargs):
        """
        Checks that two dataframes are different.
        Calls :epkg:`pandas:testing:assert_frame_equal`.
        """
        from pandas.testing import assert_frame_equal
        try:
            assert_frame_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two dataframes are identical.")

    def assertEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are equal.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param d1: first array, may be None
        :param d2: second array, may be None
        :param squeeze: applies :func:`numpy.squeeze` on both arrays first
        :param kwargs: forwarded to :epkg:`numpy:testing:assert_almost_equal`
        """
        if d1 is None and d2 is None:
            return
        if d1 is None:
            raise AssertionError("d1 is None, d2 is not")
        if d2 is None:
            raise AssertionError("d1 is not None, d2 is")
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        assert_almost_equal(d1, d2, **kwargs)

    def assertHasNoNan(self, a):  # pylint: disable=W0221
        """
        Checks that there is no NaN in ``a``.

        :param a: array exposing a ``ravel`` method
        :raises AssertionError: if *a* is None or contains NaN
        """
        if a is None:
            raise AssertionError("a is None")
        import numpy
        if any(map(numpy.isnan, a.ravel())):
            raise AssertionError("a has nan:\n{}".format(a))

    def assertEqualSparseArray(self, d1, d2, **kwargs):
        """
        Checks that two sparse matrices are equal. *d1* and *d2* must
        have the same type and expose either the attributes ``data``,
        ``row``, ``col`` (COO layout) or ``data``, ``indices``,
        ``indptr`` (compressed layouts such as CSR or CSC).

        :param d1: first matrix
        :param d2: second matrix
        :param kwargs: forwarded to :meth:`assertEqualArray` when
            comparing the stored values (``data``)
        :raises AssertionError: if the matrices are different
        :raises NotImplementedError: if the layout is not recognized
        """
        if type(d1) != type(d2):  # pylint: disable=C0123
            raise AssertionError("d1 and d2 have difference types {} != {}.".format(
                type(d1), type(d2)))
        if d1 is None and d2 is None:
            return
        if (hasattr(d1, 'data') and hasattr(d1, 'row') and hasattr(d1, 'col') and
                hasattr(d2, 'data') and hasattr(d2, 'row') and hasattr(d2, 'col')):
            # coo_matrix
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data, **kwargs)
            self.assertEqualArray(d1.row, d2.row)
            self.assertEqualArray(d1.col, d2.col)
            return
        if (hasattr(d1, 'data') and hasattr(d1, 'indices') and hasattr(d1, 'indptr') and
                hasattr(d2, 'data') and hasattr(d2, 'indices') and hasattr(d2, 'indptr')):
            # compressed layout (csr_matrix, csc_matrix, ...)
            self.assertEqual(d1.shape, d2.shape)
            self.assertEqualArray(d1.data, d2.data, **kwargs)
            self.assertEqualArray(d1.indices, d2.indices)
            self.assertEqualArray(d1.indptr, d2.indptr)
            return
        raise NotImplementedError(  # pragma: no cover
            "Comparison not implemented for types {} and {}.".format(
                type(d1), type(d2)))

    def assertNotEqualArray(self, d1, d2, squeeze=False, **kwargs):
        """
        Checks that two arrays are different.
        Relies on :epkg:`numpy:testing:assert_almost_equal`.

        :param d1: first array, may be None
        :param d2: second array, may be None
        :param squeeze: applies :func:`numpy.squeeze` on both arrays first
        :param kwargs: forwarded to :epkg:`numpy:testing:assert_almost_equal`
        :raises AssertionError: if both are None or the arrays are equal
        """
        if d1 is None and d2 is None:
            raise AssertionError("d1 and d2 are equal to None")
        if d1 is None or d2 is None:
            return
        from numpy.testing import assert_almost_equal
        import numpy
        if squeeze:
            d1 = numpy.squeeze(d1)
            d2 = numpy.squeeze(d2)
        try:
            assert_almost_equal(d1, d2, **kwargs)
        except AssertionError:
            return
        raise AssertionError("Two arrays are identical.")

    def assertEqualNumber(self, d1, d2, **kwargs):
        """
        Checks that two numbers are equal.

        :param d1: first number (int, float, decimal or numpy number)
        :param d2: second number
        :param kwargs: ``precision`` is the relative tolerance (absolute
            if one of the numbers is null), exact comparison if missing
        :raises TypeError: if one input is not a number
        :raises AssertionError: if the numbers are different
        """
        from numpy import number
        if not isinstance(d1, (int, float, decimal.Decimal, number)):
            raise TypeError('d1 is not a number but {0}'.format(type(d1)))
        if not isinstance(d2, (int, float, decimal.Decimal, number)):
            raise TypeError('d2 is not a number but {0}'.format(type(d2)))
        diff = abs(float(d1 - d2))
        mi = float(min(abs(d1), abs(d2)))
        tol = kwargs.get('precision', None)
        if tol is None:
            if diff != 0:
                raise AssertionError("d1 != d2: {0} != {1}".format(d1, d2))
        else:
            if mi == 0:
                if diff > tol:  # pragma: no cover
                    raise AssertionError(
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))
            else:
                rel = diff / mi
                if rel > tol:
                    raise AssertionError(  # pragma: no cover
                        "d1 != d2: {0} != {1} +/- {2}".format(d1, d2, tol))

    def assertRaise(self, fct, exc=None, msg=None):
        """
        Checks that function *fct* with no parameter
        raises an exception of a given type.

        @param      fct     function to test (no parameter)
        @param      exc     exception type to catch (None for all)
        @param      msg     error message to check (None for no message to check)
        """
        try:
            fct()
        except Exception as e:
            if exc is None:
                return  # pragma: no cover
            elif isinstance(e, exc):
                if msg is None:
                    return
                if msg not in str(e):
                    raise AssertionError(  # pragma: no cover
                        "Function '{0}' raise exception with wrong message '{1}' "
                        "(must contain '{2}').".format(fct, e, msg))
                return
            raise AssertionError(
                "Function '{0}' does not raise exception '{1}' but '{2}' of type "
                "'{3}'.".format(fct, exc, e, type(e)))
        raise AssertionError(  # pragma: no cover
            "Function '{0}' does not raise exception.".format(fct))

    def capture(self, fct):
        """
        Runs a function and capture standard output and error.

        @param      fct     function to run
        @return             result of *fct*, output, error
        """
        sout = StringIO()
        serr = StringIO()
        with redirect_stdout(sout):
            with redirect_stderr(serr):
                res = fct()
        return res, sout.getvalue(), serr.getvalue()

    def assertStartsWith(self, sub, whole):
        """
        Checks that string *whole* starts with *sub*.
        """
        if not whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                "'{1}' does not start with '{0}'".format(sub, whole))

    def assertNotStartsWith(self, sub, whole):
        """
        Checks that string *whole* does not start with *sub*.
        """
        if whole.startswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[:len(sub) * 2]  # pragma: no cover
            raise AssertionError(
                "'{1}' starts with '{0}'".format(sub, whole))

    def assertEndsWith(self, sub, whole):
        """
        Checks that string *whole* ends with *sub*.
        """
        if not whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]  # pragma: no cover
            raise AssertionError(
                "'{1}' does not end with '{0}'".format(sub, whole))

    def assertNotEndsWith(self, sub, whole):
        """
        Checks that string *whole* does not end with *sub*.
        """
        if whole.endswith(sub):
            if len(whole) > len(sub) * 2:
                whole = whole[-len(sub) * 2:]
            raise AssertionError(
                "'{1}' ends with '{0}'".format(sub, whole))

    def assertEqual(self, a, b, msg=None):  # pylint: disable=W0221
        """
        Checks that ``a == b``.
        When the standard comparison raises a :class:`ValueError`
        (ambiguous truth value, incompatible shapes, differently
        labelled objects), :epkg:`numpy` arrays are compared with
        :meth:`assertEqualArray`, :epkg:`pandas` dataframes with
        :meth:`assertEqualDataFrame` and :epkg:`pandas` series
        with ``pandas.testing.assert_series_equal``.

        :param a: first object
        :param b: second object
        :param msg: optional message forwarded to
            :meth:`unittest.TestCase.assertEqual`
        :raises AssertionError: if the objects are different or
            cannot be compared
        """
        if a is None and b is not None:
            raise AssertionError("a is None, b is not")
        if a is not None and b is None:
            raise AssertionError("a is not None, b is")
        try:
            unittest.TestCase.assertEqual(self, a, b, msg)
        except ValueError as e:
            # Every comparison failure must end with a verdict,
            # it must never be silently ignored.
            if ExtTestCase._both_instances(a, b, 'numpy', 'ndarray'):
                self.assertEqualArray(a, b)
                return
            if ExtTestCase._both_instances(a, b, 'pandas', 'DataFrame'):
                self.assertEqualDataFrame(a, b)
                return
            if ExtTestCase._both_instances(a, b, 'pandas', 'Series'):
                from pandas.testing import assert_series_equal
                assert_series_equal(a, b)
                return
            raise AssertionError(  # pragma: no cover
                "Unable to check equality for types {0} and {1}".format(
                    type(a), type(b))) from e

    def assertNotEqual(self, a, b, msg=None):  # pylint: disable=W0221
        """
        Checks that ``a != b``.
        When the standard comparison raises a :class:`ValueError`,
        :epkg:`numpy` arrays are compared with :meth:`assertNotEqualArray`
        and :epkg:`pandas` dataframes with :meth:`assertNotEqualDataFrame`.

        :param a: first object
        :param b: second object
        :param msg: optional message forwarded to
            :meth:`unittest.TestCase.assertNotEqual`
        :raises AssertionError: if the objects are equal
        """
        if a is None and b is None:
            raise AssertionError("a is None, b is too")  # pragma: no cover
        if a is None or b is None:
            return  # pragma: no cover
        try:
            unittest.TestCase.assertNotEqual(self, a, b, msg)
        except ValueError:
            if ExtTestCase._both_instances(a, b, 'numpy', 'ndarray'):
                self.assertNotEqualArray(a, b)
                return
            if ExtTestCase._both_instances(a, b, 'pandas', 'DataFrame'):
                self.assertNotEqualDataFrame(a, b)
                return
            raise  # pragma: no cover

    def assertEqualFloat(self, a, b, precision=1e-5):
        """
        Checks that ``abs(a-b) < precision``. The difference is relative
        to ``min(abs(a), abs(b))`` unless this minimum is null.
        """
        mi = min(abs(a), abs(b))
        if mi == 0:
            d = abs(a - b)
            try:
                self.assertLesser(d, precision)
            except AssertionError:
                raise AssertionError("{} != {} (p={})".format(a, b, precision))
        else:
            r = float(abs(a - b)) / mi
            try:
                self.assertLesser(r, precision)
            except AssertionError:
                raise AssertionError("{} != {} (p={})".format(a, b, precision))

    def assertCallable(self, fct):
        """
        Checks that *fct* is callable.
        """
        if not callable(fct):
            raise AssertionError("fct is not callable: {0}".format(type(fct)))

    def assertEqualDict(self, a, b):
        """
        Checks that ``a == b`` for two dictionaries and lists every
        added, removed or modified key in the error message.

        :raises TypeError: if *a* or *b* is not a dictionary
        :raises AssertionError: if the dictionaries are different
        """
        if not isinstance(a, dict):
            raise TypeError('a is not dict but {0}'.format(type(a)))
        if not isinstance(b, dict):
            raise TypeError('b is not dict but {0}'.format(type(b)))

        def _keys(d):
            # Keys of different types (int and str for example)
            # cannot be sorted, the insertion order is used then.
            try:
                return sorted(d)
            except TypeError:
                return list(d)

        rows = []
        for key in _keys(b):
            if key not in a:
                rows.append("** Added key '{0}' in b".format(key))
            else:
                if a[key] != b[key]:
                    rows.append(
                        "** Value != for key '{0}': != id({1}) != id({2})\n==1 {3}\n==2 {4}".format(
                            key, id(a[key]), id(b[key]), a[key], b[key]))
        for key in _keys(a):
            if key not in b:
                rows.append("** Removed key '{0}' in a".format(key))
        if len(rows) > 0:
            raise AssertionError(
                "Dictionaries are different\n{0}".format('\n'.join(rows)))

    def fLOG(self, *args, **kwargs):
        """
        Prints out some information.
        @see fn fLOG.
        """
        # delayed import
        from ..loghelper import fLOG as _flog  # pragma: no cover
        _flog(*args, **kwargs)  # pragma: no cover

    @staticmethod
    def profile(fct, sort='cumulative', rootrem=None,
                return_results=False):
        """
        Profiles the execution of a function with function
        :func:`profile <pyquickhelper.pycode.profiling.profile>`.

        :param fct: function to profile
        :param sort: see :meth:`pstats.Stats.sort_stats`
        :param rootrem: root to remove in filenames
        :param return_results: return the results as well
        :return: statistics text dump

        .. versionchanged:: 1.11
            Parameter *return_results* was added.
        """
        return profile(fct, sort=sort, rootrem=rootrem,
                       return_results=return_results)

    def read_file(self, filename, mode='r', encoding="utf-8"):
        """
        Returns the content of a file.

        @param      filename    filename
        @param      encoding    encoding, ignored in binary mode
        @param      mode        reading mode
        @return                 content
        """
        self.assertExists(filename)
        # open() refuses an encoding in binary mode.
        if 'b' in mode:
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.read()

    def write_file(self, filename, content, mode='w', encoding='utf-8'):
        """
        Writes the content of a file.

        @param      filename    filename
        @param      content     content to write
        @param      encoding    encoding, ignored in binary mode
        @param      mode        writing mode
        @return                 value returned by the file's ``write``
                                method (number of characters or bytes)
        """
        # open() refuses an encoding in binary mode.
        if 'b' in mode:
            encoding = None
        with open(filename, mode, encoding=encoding) as f:
            return f.write(content)

    def assertIn(self, sub, ensemble, msg=None):  # pylint: disable=W0221,W0237
        """
        Checks that *sub* is in *ensemble*.

        @param      sub         sub set
        @param      ensemble    full set
        @param      msg         error message
        @raises                 AssertionError
        """
        if sub is None:
            return  # pragma: no cover
        if ensemble is None:
            raise AssertionError(msg or "'text' is None")  # pragma: no cover
        if sub not in ensemble:
            raise AssertionError(  # pragma: no cover
                msg or "Unable to find '{}' in\n{}".format(
                    sub, pprint.pformat(ensemble)))

    def assertWarning(self, fct):
        """
        Returns the list of warnings raised while
        executing function *fct*.

        @param      fct         function to run
        @return                 result, list of warnings
        """
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            r = fct()
        return r, list(w)

    def assertLogging(self, fct, logger_name, level=INFO, log_sphinx=False):
        """
        Returns the logged information in a logger defined
        by its name. The logger's handlers are temporarily replaced
        by a capturing handler and restored afterwards, even if
        *fct* raises an exception.

        @param      fct         function to run
        @param      logger_name logger name
        @param      level       level to intercept
        @param      log_sphinx  logging from :epkg:`sphinx`
        @return                 result, logged information
        """
        class MyStream:
            def __init__(self):
                self.rows = []

            def write(self, text):
                self.rows.append(text)

            def getvalue(self):
                return "\n".join(self.rows)

            def __len__(self):
                return len(self.rows)

        if log_sphinx:
            # sphinx is only required when its logging is captured.
            from sphinx.util import logging as logging_sphinx
            logger = logging_sphinx.getLogger(logger_name).logger
        else:
            logger = getLogger(logger_name)

        # Iterate over a copy: removing handlers while iterating over
        # logger.handlers would skip every other handler.
        hs = list(logger.handlers)
        for h in hs:
            logger.removeHandler(h)  # pragma: no cover

        log_capture_string = MyStream()
        ch = StreamHandler(log_capture_string)
        ch.setLevel(level)
        logger.addHandler(ch)

        try:
            res = fct()
        finally:
            logger.removeHandler(ch)
            for h in hs:
                logger.addHandler(h)  # pragma: no cover
        return res, log_capture_string.getvalue()

    @staticmethod
    def abs_path_join(filename, *args):
        """
        Returns an absolute and normalized path from this location.

        :param filename: filename, the folder which contains it
            is used as the base
        :param args: list of subpaths to the previous path
        :return: absolute and normalized path
        """
        dirname = os.path.join(os.path.dirname(filename), *args)
        return os.path.normpath(os.path.abspath(dirname))

557 

558 

def skipif_appveyor(msg):
    """
    Returns a decorator which skips a unit test when it runs
    on :epkg:`appveyor` and leaves it untouched otherwise.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if is_travis_or_appveyor() == 'appveyor':
        reason = 'Test does not work on appveyor due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x

567 

568 

def skipif_travis(msg):
    """
    Returns a decorator which skips a unit test when it runs
    on :epkg:`travis` and leaves it untouched otherwise.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if is_travis_or_appveyor() == 'travis':
        reason = 'Test does not work on travis due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x

577 

578 

def skipif_circleci(msg):
    """
    Returns a decorator which skips a unit test when it runs
    on :epkg:`circleci` and leaves it untouched otherwise.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if is_travis_or_appveyor() == 'circleci':
        reason = 'Test does not work on circleci due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x

587 

588 

def skipif_azure(msg):
    """
    Returns a decorator which skips a unit test when it runs
    on :epkg:`azure pipeline` and leaves it untouched otherwise.

    :param msg: reason appended to the skip message
    :return: decorator
    """
    if is_travis_or_appveyor() == 'azurepipe':
        reason = 'Test does not work on azure pipeline due to: ' + msg  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x  # pragma: no cover

597 

598 

def skipif_azure_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on :epkg:`linux`.

    :param msg: reason appended to the skip message
    :return: decorator, the identity unless both conditions hold
    """
    # The test is skipped only when BOTH conditions hold:
    # a linux machine AND an azure pipeline agent.
    if not sys.platform.startswith('lin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x  # pragma: no cover
    msg = 'Test does not work on azure pipeline (linux) due to: ' + msg
    return unittest.skip(msg)

607 

608 

def skipif_azure_macosx(msg):
    """
    Skips a unit test if it runs on :epkg:`azure pipeline` on macOS.

    :param msg: reason appended to the skip message
    :return: decorator, the identity unless both conditions hold
    """
    # The test is skipped only when BOTH conditions hold:
    # a macOS machine AND an azure pipeline agent.
    if not sys.platform.startswith('darwin') or is_travis_or_appveyor() != 'azurepipe':
        return lambda x: x
    msg = 'Test does not work on azure pipeline (macosx) due to: ' + msg
    return unittest.skip(msg)

617 

618 

def skipif_linux(msg):
    """
    Skips a unit test if it runs on :epkg:`linux`.

    :param msg: reason appended to the skip message
    :return: decorator, the identity if the platform is not linux
    """
    if not sys.platform.startswith('lin'):
        return lambda x: x
    msg = 'Test does not work on linux due to: ' + msg  # pragma: no cover
    return unittest.skip(msg)  # pragma: no cover

627 

628 

def skipif_vless(version, msg):
    """
    Returns a decorator skipping a unit test when the running Python
    version is strictly lower than *version*.

    :param version: minimal version as a tuple, ``(3, 8, 0)`` for example
    :param msg: reason appended to the skip message
    :return: decorator
    """
    current = sys.version_info[:3]
    if current < version:
        reason = 'Python {} < {}: {}'.format(
            current, version, msg)  # pragma: no cover
        return unittest.skip(reason)  # pragma: no cover
    return lambda x: x

639 

640 

def unittest_require_at_least(mod, version, msg=""):
    """
    Returns a decorator skipping a unit test when the version
    of a module is older than the requested one.

    @param      mod         module (the module must have an attribute ``__version__``)
    @param      version     expected version or more recent
    @param      msg         message
    @return                 decorator
    @raises                 RuntimeError if the module has no ``__version__``

    .. versionadded:: 1.9
    """
    found = getattr(mod, '__version__', None)
    if found is None:
        raise RuntimeError(  # pragma: no cover
            "Module '{}' has no version.".format(mod))
    if compare_module_version(found, version) < 0:
        reason = "Module '{}' is older than '{}' (= '{}'). {}".format(
            mod, version, found, msg)
        return unittest.skip(reason)
    return lambda x: x

661 

662 

def ignore_warnings(warns):
    """
    Catches warnings: returns a decorator which runs a test method
    with some warnings ignored.

    @param      warns   warning category (or tuple of categories) to ignore,
                        forwarded to :func:`warnings.simplefilter`
    @return             decorator
    @raises             AssertionError if *warns* is None (raised when decorating)
    """
    import functools

    def wrapper(fct):
        if warns is None:
            raise AssertionError(  # pragma: no cover
                "warns cannot be None for '{}'.".format(fct))

        # functools.wraps keeps the name and the docstring of the test,
        # unittest uses the docstring to describe the test.
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", warns)
                return fct(self, *args, **kwargs)
        return call_f
    return wrapper

680 

681 

def testlog(logtype="print"):
    """
    Logs before and after a function is called.

    :param logtype: kind of logging, only `'print'` is implemented
        and None to disable it
    :return: decorator
    :raises ValueError: if *logtype* is not supported
    """
    import functools

    if logtype is None:
        def nothing(arg):
            pass

        logfct = nothing
    elif logtype == 'print':
        logfct = print
    else:
        raise ValueError("Unexpected logtype %r." % logtype)

    def wrapper(fct):
        # functools.wraps keeps the name and the docstring of the test.
        @functools.wraps(fct)
        def call_f(self, *args, **kwargs):
            logfct('START %r' % fct.__name__)
            res = fct(self, *args, **kwargs)
            # Only logged when the function did not raise.
            logfct('DONE- %r' % fct.__name__)
            return res
        return call_f
    return wrapper

706 

707 

def assert_almost_equal_detailed(expected, value, **kwargs):
    """
    Calls :epkg:`numpy:testing:assert_almost_equal`.
    Adds more information in the exception message: if both inputs
    have the same number of rows, every row is compared separately
    and the first failing rows are reported.

    :param expected: expected value
    :param value: value
    :param kwargs: additional parameters for
        :epkg:`numpy:testing:assert_almost_equal`
    :raises: AssertionError
    """
    import numpy
    from numpy.testing import assert_almost_equal
    try:
        assert_almost_equal(expected, value, **kwargs)
    except AssertionError as e:
        # numpy.shape also accepts lists and scalars, unlike ``.shape``.
        shape_expected = numpy.shape(expected)
        shape_value = numpy.shape(value)
        if (len(shape_expected) == 0 or len(shape_value) == 0 or
                shape_expected[0] != shape_value[0]):
            # No row by row comparison is possible.
            raise
        rows = ['INNER EXCEPTION:', str(e), '------', 'ROWS BY ROWS']
        for i, (r1, r2) in enumerate(zip(expected, value)):
            try:
                assert_almost_equal(r1, r2, **kwargs)
            except AssertionError as row_error:
                # A distinct name keeps ``e`` alive: the name bound by
                # ``except ... as`` is deleted when the clause ends.
                rows.append('----------------------')
                rows.append("ISSUE WITH ROW {}/{}:0 {}".format(
                    i, shape_expected[0], str(row_error)))
                if len(rows) > 10:
                    break  # pragma: no cover
        raise AssertionError("\n".join(rows)) from e