Coverage for mlprodict/asv_benchmark/create_asv.py: 98%

294 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

@file Functions to create a benchmark based on :epkg:`asv`

3for many regressors and classifiers. 

4""" 

5import os 

6import sys 

7import json 

8import textwrap 

9import warnings 

10import re 

11from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8 

12try: 

13 from ._create_asv_helper import ( 

14 default_asv_conf, 

15 flask_helper, 

16 pyspy_template, 

17 _handle_init_files, 

18 _asv_class_name, 

19 _read_patterns, 

20 _select_pattern_problem, 

21 _display_code_lines, 

22 add_model_import_init, 

23 find_missing_sklearn_imports) 

24except ImportError: # pragma: no cover 

25 from mlprodict.asv_benchmark._create_asv_helper import ( 

26 default_asv_conf, 

27 flask_helper, 

28 pyspy_template, 

29 _handle_init_files, 

30 _asv_class_name, 

31 _read_patterns, 

32 _select_pattern_problem, 

33 _display_code_lines, 

34 add_model_import_init, 

35 find_missing_sklearn_imports) 

36 

37try: 

38 from .. import __max_supported_opset__ 

39 from ..tools.asv_options_helper import ( 

40 shorten_onnx_options) 

41 from ..onnxrt.validate.validate_helper import sklearn_operators 

42 from ..onnxrt.validate.validate import ( 

43 _retrieve_problems_extra, _get_problem_data, _merge_options) 

44except (ValueError, ImportError): # pragma: no cover 

45 from mlprodict import __max_supported_opset__ 

46 from mlprodict.onnxrt.validate.validate_helper import sklearn_operators 

47 from mlprodict.onnxrt.validate.validate import ( 

48 _retrieve_problems_extra, _get_problem_data, _merge_options) 

49 from mlprodict.tools.asv_options_helper import shorten_onnx_options 

50try: 

51 from ..testing.verify_code import verify_code 

52except (ValueError, ImportError): # pragma: no cover 

53 from mlprodict.testing.verify_code import verify_code 

54 

55# exec function does not import models but potentially 

56# requires all specific models used to define scenarios 

57try: 

58 from ..onnxrt.validate.validate_scenarios import * # pylint: disable=W0614,W0401 

59except (ValueError, ImportError): # pragma: no cover 

60 # Skips this step if used in a benchmark. 

61 pass 

62 

63 

def create_asv_benchmark(
        location, opset_min=-1, opset_max=None,
        runtime=('scikit-learn', 'python_compiled'), models=None,
        skip_models=None, extended_list=True,
        dims=(1, 10, 100, 10000),
        n_features=(4, 20), dtype=None,
        verbose=0, fLOG=print, clean=True,
        conf_params=None, filter_exp=None,
        filter_scenario=None, flat=False,
        exc=False, build=None, execute=False,
        add_pyspy=False, env=None,
        matrix=None):
    """
    Creates an :epkg:`asv` benchmark in a folder
    but does not run it.

    :param location: destination folder of the generated benchmark
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset defined by module :epkg:`onnx`
    :param opset_max: tries every conversion up to maximum opset,
        -1 to get the current opset defined by module :epkg:`onnx`
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *python_compiled* compiles the graph structure
        and is more efficient when the number of observations is
        small, *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param dims: number of observations to try
    :param n_features: number of features to try, overrides the default
        number of features for a specific problem,
        it can also be a comma separated list
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param fLOG: logging function
    :param clean: clean the folder first, otherwise overwrites the content
    :param conf_params: to overwrite some of the configuration parameters
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra*
        as an input
    :param flat: one folder for all files or subfolders
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param build: where to put the outputs
    :param execute: execute each script to make sure
        imports are correct
    :param add_pyspy: add an extra folder with code to profile
        each configuration
    :param env: None to use the default configuration or ``same`` to use
        the current one
    :param matrix: specifies versions for a module,
        example: ``{'onnxruntime': ['1.1.1', '1.1.2']}``,
        if a package name starts with `'~'`, the package is removed
    :return: created files

    The default configuration is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.asv_benchmark.create_asv import default_asv_conf

        pprint.pprint(default_asv_conf)

    The benchmark does not seem to work well with setting
    ``-environment existing:same``. The publishing fails.
    """
    # -1 means "use the most recent opset this module supports".
    if opset_min == -1:
        opset_min = __max_supported_opset__
    if opset_max == -1:
        opset_max = __max_supported_opset__  # pragma: no cover
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG(f"[create_asv_benchmark] opset in [{opset_min}, {opset_max}].")

    # creates the folder if it does not exist.
    if not os.path.exists(location):
        if verbose > 0 and fLOG is not None:  # pragma: no cover
            fLOG(f"[create_asv_benchmark] create folder '{location}'.")
        os.makedirs(location)  # pragma: no cover

    location_test = os.path.join(location, 'benches')
    if not os.path.exists(location_test):
        if verbose > 0 and fLOG is not None:
            fLOG(f"[create_asv_benchmark] create folder '{location_test}'.")
        os.mkdir(location_test)

    # Cleans the content of the folder (files only, subfolders are kept).
    created = []
    if clean:
        for name in os.listdir(location_test):
            full_name = os.path.join(location_test, name)  # pragma: no cover
            if os.path.isfile(full_name):  # pragma: no cover
                os.remove(full_name)

    # configuration: start from the default asv configuration and
    # overlay user-provided parameters.
    conf = default_asv_conf.copy()
    if conf_params is not None:
        for k, v in conf_params.items():
            conf[k] = v
    if build is not None:
        for fi in ['env_dir', 'results_dir', 'html_dir']:  # pragma: no cover
            conf[fi] = os.path.join(build, conf[fi])
    if env == 'same':
        if matrix is not None:
            raise ValueError(  # pragma: no cover
                "Parameter matrix must be None if env is 'same'.")
        conf['pythons'] = ['same']
        conf['matrix'] = {}
    elif matrix is not None:
        # A leading '~' in a package name removes that package
        # from the version matrix.
        drop_keys = {p for p in matrix if p.startswith('~')}
        matrix = {k: v for k, v in matrix.items() if k not in drop_keys}
        conf['matrix'] = {k: v for k,
                          v in conf['matrix'].items() if k not in drop_keys}
        conf['matrix'].update(matrix)
    elif env is not None:
        raise ValueError(  # pragma: no cover
            f"Unable to handle env='{env}'.")
    dest = os.path.join(location, "asv.conf.json")
    created.append(dest)
    with open(dest, "w", encoding='utf-8') as f:
        json.dump(conf, f, indent=4)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'asv.conf.json'.")

    # __init__.py: empty package markers for asv's discovery.
    dest = os.path.join(location, "__init__.py")
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create '__init__.py'.")
    dest = os.path.join(location_test, '__init__.py')
    with open(dest, "w", encoding='utf-8') as f:
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'benches/__init__.py'.")

    # flask_server: helper script to serve the published results.
    tool_dir = os.path.join(location, 'tools')
    if not os.path.exists(tool_dir):
        os.mkdir(tool_dir)
    fl = os.path.join(tool_dir, 'flask_serve.py')
    with open(fl, "w", encoding='utf-8') as f:
        f.write(flask_helper)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'flask_serve.py'.")

    # command line script to run + publish the benchmark.
    if sys.platform.startswith("win"):
        run_bash = os.path.join(tool_dir, 'run_asv.bat')  # pragma: no cover
    else:
        run_bash = os.path.join(tool_dir, 'run_asv.sh')
    # encoding='utf-8' for consistency with every other file written here
    # (previously relied on the platform default encoding).
    with open(run_bash, 'w', encoding='utf-8') as f:
        f.write(textwrap.dedent("""
            echo --BENCHRUN--
            python -m asv run --show-stderr --config ./asv.conf.json
            echo --PUBLISH--
            python -m asv publish --config ./asv.conf.json -o ./html
            echo --CSV--
            python -m mlprodict asv2csv -f ./results -o ./data_bench.csv
            """))

    # pyspy: optional extra folder with profiling scripts.
    if add_pyspy:
        dest_pyspy = os.path.join(location, 'pyspy')
        if not os.path.exists(dest_pyspy):
            os.mkdir(dest_pyspy)
    else:
        dest_pyspy = None

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create all tests.")

    created.extend(list(_enumerate_asv_benchmark_all_models(
        location_test, opset_min=opset_min, opset_max=opset_max,
        runtime=runtime, models=models,
        skip_models=skip_models, extended_list=extended_list,
        n_features=n_features, dtype=dtype,
        verbose=verbose, filter_exp=filter_exp,
        filter_scenario=filter_scenario,
        dims=dims, exc=exc, flat=flat,
        fLOG=fLOG, execute=execute,
        dest_pyspy=dest_pyspy)))

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] done.")
    return created

261 

262 

def _enumerate_asv_benchmark_all_models(  # pylint: disable=R0914
        location, opset_min=10, opset_max=None,
        runtime=('scikit-learn', 'python'), models=None,
        skip_models=None, extended_list=True,
        n_features=None, dtype=None,
        verbose=0, filter_exp=None,
        dims=None, filter_scenario=None,
        exc=True, flat=False, execute=False,
        dest_pyspy=None, fLOG=print):
    """
    Loops over all possible models and fills a folder
    with benchmarks following :epkg:`asv` concepts.

    :param location: destination folder of the benchmark files
    :param opset_min: tries every conversion from this minimum opset
    :param opset_max: tries every conversion up to maximum opset
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param n_features: number of features to try, overrides the default
        number of features for a specific problem,
        it can also be a comma separated list
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment must be run,
        None to run all, takes *model, problem, scenario, extra*
        as an input
    :param dims: number of observations to try
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param flat: one folder for all files or subfolders
    :param execute: execute each script to make sure
        imports are correct
    :param dest_pyspy: add a file to profile the prediction
        function with :epkg:`pyspy`
    :param fLOG: logging function
    :return: generator on the created file names
    """
    ops = [_ for _ in sklearn_operators(extended=extended_list)]
    patterns = _read_patterns()

    if models is not None:
        if not all(map(lambda m: isinstance(m, str), models)):
            raise ValueError(
                "models must be a set of strings.")  # pragma: no cover
        ops_ = [_ for _ in ops if _['name'] in models]
        # BUG FIX: the emptiness check must apply to the *filtered* list
        # ``ops_`` (the previous code tested ``ops``, which is never empty
        # at this point, so an unknown model name went undetected).
        if len(ops_) == 0:
            raise ValueError("Parameter models is wrong: {}\n{}".format(  # pragma: no cover
                models, ops[0]))
        ops = ops_
    if skip_models is not None:
        ops = [m for m in ops if m['name'] not in skip_models]

    if verbose > 0:

        # Plain logging iterator, fallback when tqdm is unavailable
        # or when full verbosity is requested.
        def iterate():
            for i, row in enumerate(ops):  # pragma: no cover
                fLOG(f"{i + 1}/{len(ops)} - {row}")
                yield row

        if verbose >= 11:
            verbose -= 10  # pragma: no cover
            loop = iterate()  # pragma: no cover
        else:
            try:
                from tqdm import trange

                # Progress-bar iterator showing the current model name.
                def iterate_tqdm():
                    with trange(len(ops)) as t:
                        for i in t:
                            row = ops[i]
                            disp = row['name'] + " " * (28 - len(row['name']))
                            t.set_description(f"{disp}")
                            yield row

                loop = iterate_tqdm()

            except ImportError:  # pragma: no cover
                loop = iterate()
    else:
        loop = ops

    if opset_max is None:
        opset_max = __max_supported_opset__
    opsets = list(range(opset_min, opset_max + 1))
    all_created = set()

    # loop on all models
    for row in loop:

        model = row['cl']

        problems, extras = _retrieve_problems_extra(
            model, verbose, fLOG, extended_list)
        if extras is None or problems is None:
            # Not tested yet.
            continue  # pragma: no cover

        # flat or not flat
        created, location_model, prefix_import, dest_pyspy_model = _handle_init_files(
            model, flat, location, verbose, dest_pyspy, fLOG)
        for init in created:
            yield init

        # loops on problems
        for prob in problems:
            if filter_exp is not None and not filter_exp(model, prob):
                continue

            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, None)

            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None

                # A third element in the scenario tuple carries options:
                # either a dict of named options or directly a subset of
                # problems the scenario applies to.
                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options

                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))

                scenario, extra = scenario_extra[:2]
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]

                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue  # pragma: no cover

                if verbose >= 3 and fLOG is not None:
                    fLOG("[create_asv_benchmark] model={} scenario={} optim={} extra={} dofit={} (problem={} method_name='{}')".format(
                        model.__name__, scenario, optimisations, extra, dofit, prob, method_name))
                created = _create_asv_benchmark_file(
                    location_model, opsets=opsets,
                    model=model, scenario=scenario, optimisations=optimisations,
                    extra=extra, dofit=dofit, problem=prob,
                    runtime=runtime, new_conv_options=new_conv_options,
                    X_train=X_train, X_test=X_test, y_train=y_train,
                    y_test=y_test, Xort_test=Xort_test,
                    init_types=init_types, conv_options=conv_options,
                    method_name=method_name, dims=dims, n_features=n_features,
                    output_index=output_index, predict_kwargs=predict_kwargs,
                    exc=exc, prefix_import=prefix_import,
                    execute=execute, location_pyspy=dest_pyspy_model,
                    patterns=patterns)
                for cr in created:
                    if cr in all_created:
                        raise RuntimeError(  # pragma: no cover
                            f"File '{cr}' was already created.")
                    all_created.add(cr)
                    if verbose > 1 and fLOG is not None:
                        fLOG(f"[create_asv_benchmark] add '{cr}'.")
                    yield cr

445 

def _create_asv_benchmark_file(  # pylint: disable=R0914
        location, model, scenario, optimisations, new_conv_options,
        extra, dofit, problem, runtime, X_train, X_test, y_train,
        y_test, Xort_test, init_types, conv_options,
        method_name, n_features, dims, opsets,
        output_index, predict_kwargs, prefix_import,
        exc, execute=False, location_pyspy=None, patterns=None):
    """
    Creates a benchmark file based in the information received
    through the argument. It uses one of the templates
    like @see cl TemplateBenchmarkClassifier or
    @see cl TemplateBenchmarkRegressor.

    Returns the list of created file names (benchmark scripts and,
    when *location_pyspy* is given, the pyspy profiling scripts).
    """
    if patterns is None:
        raise ValueError("Patterns list is empty.")  # pragma: no cover

    # Converts conversion options keyed by a class object into options
    # keyed by the class name (string), so they can be serialized into
    # the generated script.
    def format_conv_options(d_options, class_name):
        if d_options is None:
            return None
        res = {}
        for k, v in d_options.items():
            if isinstance(k, type):
                # Only the option keyed by *this* model's class is kept.
                if "." + class_name + "'" in str(k):
                    res[class_name] = v
                    continue
                raise ValueError(  # pragma: no cover
                    f"Class '{class_name}', unable to format options {d_options}")
            res[k] = v
        return res

    # Produces a short, serializable version of the conversion options.
    # Class keys are wrapped in '####' markers which are stripped later
    # when injected into the generated source.
    def _nick_name_options(model, opts):
        # Shorten common onnx options, see _CommonAsvSklBenchmark._to_onnx.
        if opts is None:
            return opts  # pragma: no cover
        short_opts = shorten_onnx_options(model, opts)
        if short_opts is not None:
            return short_opts
        res = {}
        for k, v in opts.items():
            if hasattr(k, '__name__'):
                res["####" + k.__name__ + "####"] = v
            else:
                res[k] = v  # pragma: no cover
        return res

    # Builds a compact display name for asv (no 'bench' markers,
    # dots replaced so the name is a valid identifier fragment).
    def _make_simple_name(name):
        simple_name = name.replace("bench_", "").replace("_bench", "")
        simple_name = simple_name.replace("bench.", "").replace(".bench", "")
        simple_name = simple_name.replace(".", "-")
        repl = {'_': '', 'solverliblinear': 'liblinear'}
        for k, v in repl.items():
            simple_name = simple_name.replace(k, v)
        return simple_name

    # Flattens an option dictionary into a short string usable in a
    # file name (e.g. for the pyspy scripts). Lists are rejected.
    def _optdict2string(opt):
        if isinstance(opt, str):
            return opt
        if isinstance(opt, list):
            raise TypeError(
                f"Unable to process type {type(opt)!r}.")
        reps = {True: 1, False: 0, 'zipmap': 'zm',
                'optim': 'opt'}
        info = []
        for k, v in sorted(opt.items()):
            if isinstance(v, dict):
                v = _optdict2string(v)
            if k.startswith('####'):
                # class-name markers are dropped from the file name
                k = ''
            i = f'{reps.get(k, k)}{reps.get(v, v)}'
            info.append(i)
        return "-".join(info)

    # Abbreviations used in generated file/class names for each runtime.
    runtimes_abb = {
        'scikit-learn': 'skl',
        'onnxruntime1': 'ort',
        'onnxruntime2': 'ort2',
        'python': 'pyrt',
        'python_compiled': 'pyrtc',
    }
    runtime = [runtimes_abb[k] for k in runtime]

    # Looping over configuration.
    names = []
    for optimisation in optimisations:
        # NOTE(review): _merge_options presumably overlays scenario-specific
        # conversion options on top of the problem defaults — confirm in
        # mlprodict.onnxrt.validate.validate.
        merged_options = [_merge_options(nconv_options, conv_options)
                          for nconv_options in new_conv_options]

        nck_opts = [_nick_name_options(model, opts)
                    for opts in merged_options]
        try:
            name = _asv_class_name(
                model, scenario, optimisation, extra,
                dofit, conv_options, problem,
                shorten=True)
        except ValueError as e:  # pragma: no cover
            if exc:
                raise e
            warnings.warn(str(e))
            continue
        filename = name.replace(".", "_") + ".py"
        try:
            class_content = _select_pattern_problem(problem, patterns)
        except ValueError as e:
            if exc:
                raise e  # pragma: no cover
            warnings.warn(str(e))
            continue
        full_class_name = _asv_class_name(
            model, scenario, optimisation, extra,
            dofit, conv_options, problem,
            shorten=False)
        class_name = name.replace(
            "bench.", "").replace(".", "_") + "_bench"

        # n_features, N, runtimes: the template contains these exact
        # placeholder lines (comment included); each must be present,
        # otherwise the template is considered broken.
        rep = {
            "['skl', 'pyrtc', 'ort'], # values for runtime": str(runtime),
            "[1, 10, 100, 1000, 10000], # values for N": str(dims),
            "[4, 20], # values for nf": str(n_features),
            "[__max_supported_opset__], # values for opset": str(opsets),
            "['float', 'double'], # values for dtype":
                "['float']" if '-64' not in problem else "['double']",
            "[None], # values for optim": f"{nck_opts!r}",
        }
        for k, v in rep.items():
            if k not in class_content:
                raise ValueError("Unable to find '{}'\n{}.".format(  # pragma: no cover
                    k, class_content))
            class_content = class_content.replace(k, v + ',')
        # Keep only the part of the template before _create_model.
        class_content = class_content.split(
            "def _create_model(self):")[0].strip("\n ")
        # Strip the '####' markers introduced by _nick_name_options;
        # any marker left afterwards means the substitution failed.
        if "####" in class_content:
            class_content = class_content.replace(
                "'####", "").replace("####'", "")
        if "####" in class_content:
            raise RuntimeError(  # pragma: no cover
                "Substring '####' should not be part of the script for '{}'\n{}".format(
                    model.__name__, class_content))

        # Model setup: inject imports, init parameters and rename the
        # template class to the generated class name.
        class_content, atts = add_model_import_init(
            class_content, model, optimisation,
            extra, merged_options)
        class_content = class_content.replace(
            "class TemplateBenchmark",
            f"class {class_name}")

        # dtype, dofit: class attributes describing the configuration.
        atts.append(f"chk_method_name = {method_name!r}")
        atts.append(f"par_scenario = {scenario!r}")
        atts.append(f"par_problem = {problem!r}")
        atts.append(f"par_optimisation = {optimisation!r}")
        if not dofit:
            atts.append("par_dofit = False")
        if merged_options is not None and len(merged_options) > 0:
            atts.append("par_convopts = %r" % format_conv_options(
                conv_options, model.__name__))
        atts.append(f"par_full_test_name = {full_class_name!r}")

        simple_name = _make_simple_name(name)
        atts.append(f"benchmark_name = {simple_name!r}")
        atts.append(f"pretty_name = {simple_name!r}")

        if atts:
            class_content = class_content.replace(
                "# additional parameters",
                "\n    ".join(atts))
        if prefix_import != '.':
            class_content = class_content.replace(
                " from .", f"from .{prefix_import}")

        # Check compilation
        try:
            compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__, class_content)) from e

        # Verifies missing imports.
        to_import, _ = verify_code(class_content, exc=False)
        try:
            miss = find_missing_sklearn_imports(to_import)
        except ValueError as e:  # pragma: no cover
            raise ValueError(
                f"Unable to check import in script\n{class_content}") from e
        class_content = class_content.replace(
            "# __IMPORTS__", "\n".join(miss))
        verify_code(class_content, exc=True)
        class_content = class_content.replace(
            "par_extra = {", "par_extra = {\n")
        class_content = remove_extra_spaces_and_pep8(
            class_content, aggressive=True)

        # Check compilation again (pep8 cleanup may have altered the code).
        try:
            obj = compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__,
                _display_code_lines(class_content))) from e

        # executes to check import
        if execute:
            try:
                exec(obj, globals(), locals())  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to process class '{}' ('{}') a script due to '{}'\n{}".format(
                        model.__name__, filename, str(e),
                        _display_code_lines(class_content))) from e

        # Saves
        fullname = os.path.join(location, filename)
        names.append(fullname)
        with open(fullname, "w", encoding='utf-8') as f:
            f.write(class_content)

        if location_pyspy is not None:
            # adding configuration for pyspy: one script per
            # (N, n_features, opset, dtype, options) combination.
            class_name = re.compile(
                'class ([A-Za-z_0-9]+)[(]').findall(class_content)[0]
            fullname_pyspy = os.path.splitext(
                os.path.join(location_pyspy, filename))[0]
            pyfold = os.path.splitext(os.path.split(fullname)[-1])[0]

            dtypes = ['float', 'double'] if '-64' in problem else ['float']
            for dim in dims:
                for nf in n_features:
                    for opset in opsets:
                        for dtype in dtypes:
                            for opt in nck_opts:
                                tmpl = pyspy_template.replace(
                                    '__PATH__', location)
                                tmpl = tmpl.replace(
                                    '__CLASSNAME__', class_name)
                                tmpl = tmpl.replace('__PYFOLD__', pyfold)
                                opt = "" if opt == {} else opt

                                # The first runtime also emits a warm-up
                                # call (profile0) before the profiled one.
                                first = True
                                for rt in runtime:
                                    if first:
                                        tmpl += textwrap.dedent("""

                                        def profile0_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                            return setup_profile0(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                        iter = profile0_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                        print(datetime.now(), "iter", iter)

                                        """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                    dtype=dtype, opt=f"{opt!r}")
                                        first = False

                                    tmpl += textwrap.dedent("""

                                    def profile_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                        return setup_profile(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                                    profile_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                                    print(datetime.now(), "iter", iter)

                                    """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                                dtype=dtype, opt=f"{opt!r}")

                                thename = "{n}_{dim}_{nf}_{opset}_{dtype}_{opt}.py".format(
                                    n=fullname_pyspy, dim=dim, nf=nf,
                                    opset=opset, dtype=dtype, opt=_optdict2string(opt))
                                with open(thename, 'w', encoding='utf-8') as f:
                                    f.write(tmpl)
                                names.append(thename)

                                # Companion shell script running py-spy twice:
                                # once per-function, once per-line.
                                ext = '.bat' if sys.platform.startswith(
                                    'win') else '.sh'
                                script = os.path.splitext(thename)[0] + ext
                                short = os.path.splitext(
                                    os.path.split(thename)[-1])[0]
                                with open(script, 'w', encoding='utf-8') as f:
                                    f.write('py-spy record --native --function --rate=10 -o {n}_fct.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))
                                    f.write('py-spy record --native --rate=10 -o {n}_line.svg -- {py} {n}.py\n'.format(
                                        py=sys.executable, n=short))

    return names