Coverage for mlprodict/asv_benchmark/create_asv.py: 98%
294 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file Functions to creates a benchmark based on :epkg:`asv`
3for many regressors and classifiers.
4"""
5import os
6import sys
7import json
8import textwrap
9import warnings
10import re
11from pyquickhelper.pycode.code_helper import remove_extra_spaces_and_pep8
12try:
13 from ._create_asv_helper import (
14 default_asv_conf,
15 flask_helper,
16 pyspy_template,
17 _handle_init_files,
18 _asv_class_name,
19 _read_patterns,
20 _select_pattern_problem,
21 _display_code_lines,
22 add_model_import_init,
23 find_missing_sklearn_imports)
24except ImportError: # pragma: no cover
25 from mlprodict.asv_benchmark._create_asv_helper import (
26 default_asv_conf,
27 flask_helper,
28 pyspy_template,
29 _handle_init_files,
30 _asv_class_name,
31 _read_patterns,
32 _select_pattern_problem,
33 _display_code_lines,
34 add_model_import_init,
35 find_missing_sklearn_imports)
37try:
38 from .. import __max_supported_opset__
39 from ..tools.asv_options_helper import (
40 shorten_onnx_options)
41 from ..onnxrt.validate.validate_helper import sklearn_operators
42 from ..onnxrt.validate.validate import (
43 _retrieve_problems_extra, _get_problem_data, _merge_options)
44except (ValueError, ImportError): # pragma: no cover
45 from mlprodict import __max_supported_opset__
46 from mlprodict.onnxrt.validate.validate_helper import sklearn_operators
47 from mlprodict.onnxrt.validate.validate import (
48 _retrieve_problems_extra, _get_problem_data, _merge_options)
49 from mlprodict.tools.asv_options_helper import shorten_onnx_options
50try:
51 from ..testing.verify_code import verify_code
52except (ValueError, ImportError): # pragma: no cover
53 from mlprodict.testing.verify_code import verify_code
55# exec function does not import models but potentially
56# requires all specific models used to define scenarios
57try:
58 from ..onnxrt.validate.validate_scenarios import * # pylint: disable=W0614,W0401
59except (ValueError, ImportError): # pragma: no cover
60 # Skips this step if used in a benchmark.
61 pass
def create_asv_benchmark(
        location, opset_min=-1, opset_max=None,
        runtime=('scikit-learn', 'python_compiled'), models=None,
        skip_models=None, extended_list=True,
        dims=(1, 10, 100, 10000),
        n_features=(4, 20), dtype=None,
        verbose=0, fLOG=print, clean=True,
        conf_params=None, filter_exp=None,
        filter_scenario=None, flat=False,
        exc=False, build=None, execute=False,
        add_pyspy=False, env=None,
        matrix=None):
    """
    Creates an :epkg:`asv` benchmark in a folder
    but does not run it.

    :param location: folder receiving the benchmark, it is created if
        it does not exist, benchmark scripts go into subfolder ``benches``
    :param opset_min: tries every conversion from this minimum opset,
        -1 to use the highest opset supported by this package
        (``__max_supported_opset__``)
    :param opset_max: tries every conversion up to maximum opset,
        -1 (or None) to use the highest opset supported by this package
        (``__max_supported_opset__``)
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *python_compiled* compiles the graph structure
        and is more efficient when the number of observations is
        small, *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param dims: number of observations to try
    :param n_features: number of features to try, it can also be changed
        for a specific problem
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
        (NOTE(review): forwarded but currently not used by the enumeration)
    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param fLOG: logging function
    :param clean: removes the files (not the subfolders) found in
        folder ``benches`` first, otherwise overwrites the content
    :param conf_params: to overwrite some of the configuration parameters
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param filter_scenario: second function which tells if the experiment
        must be run, None to run all, takes *model, problem, scenario, extra,
        conversion options* as an input
    :param flat: one folder for all files or subfolders
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param build: where to put the outputs, configuration entries
        ``env_dir``, ``results_dir``, ``html_dir`` are moved into this folder
    :param execute: execute each script to make sure
        imports are correct
    :param add_pyspy: add an extra folder with code to profile
        each configuration
    :param env: None to use the default configuration or ``same`` to use
        the current one
    :param matrix: specifies versions for a module,
        example: ``{'onnxruntime': ['1.1.1', '1.1.2']}``,
        if a package name starts with `'~'`, the package is removed
        from the default matrix, example: ``{'~xgboost': []}``
    :return: created files
    :raises ValueError: if *matrix* is given with ``env='same'`` or
        if *env* is neither None nor ``'same'``

    The default configuration is the following:

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        import pprint
        from mlprodict.asv_benchmark.create_asv import default_asv_conf

        pprint.pprint(default_asv_conf)

    The benchmark does not seem to work well with setting
    ``-environment existing:same``. The publishing fails.
    """
    if opset_min == -1:
        opset_min = __max_supported_opset__
    if opset_max == -1:
        opset_max = __max_supported_opset__  # pragma: no cover
    if verbose > 0 and fLOG is not None:  # pragma: no cover
        fLOG(f"[create_asv_benchmark] opset in [{opset_min}, {opset_max}].")

    # creates the folder if it does not exist.
    if not os.path.exists(location):
        if verbose > 0 and fLOG is not None:  # pragma: no cover
            fLOG(f"[create_asv_benchmark] create folder '{location}'.")
        os.makedirs(location)  # pragma: no cover

    location_test = os.path.join(location, 'benches')
    if not os.path.exists(location_test):
        if verbose > 0 and fLOG is not None:
            fLOG(f"[create_asv_benchmark] create folder '{location_test}'.")
        os.mkdir(location_test)

    # Cleans the content of the folder (files only, subfolders are kept).
    created = []
    if clean:
        for name in os.listdir(location_test):
            full_name = os.path.join(location_test, name)  # pragma: no cover
            if os.path.isfile(full_name):  # pragma: no cover
                os.remove(full_name)

    # configuration, the shallow copy is enough as nested values
    # are replaced, never modified in place
    conf = default_asv_conf.copy()
    if conf_params is not None:
        for k, v in conf_params.items():
            conf[k] = v
    if build is not None:
        for fi in ['env_dir', 'results_dir', 'html_dir']:  # pragma: no cover
            conf[fi] = os.path.join(build, conf[fi])
    if env == 'same':
        if matrix is not None:
            raise ValueError(  # pragma: no cover
                "Parameter matrix must be None if env is 'same'.")
        conf['pythons'] = ['same']
        conf['matrix'] = {}
    elif matrix is not None:
        # A key '~name' removes package 'name' from the default matrix,
        # the tilde must be stripped to match the keys of the configuration.
        drop_keys = {p[1:] for p in matrix if p.startswith('~')}
        kept = {k: v for k, v in matrix.items() if not k.startswith('~')}
        conf['matrix'] = {k: v for k, v in conf['matrix'].items()
                          if k not in drop_keys}
        conf['matrix'].update(kept)
    elif env is not None:
        raise ValueError(  # pragma: no cover
            f"Unable to handle env='{env}'.")
    dest = os.path.join(location, "asv.conf.json")
    created.append(dest)
    with open(dest, "w", encoding='utf-8') as f:
        json.dump(conf, f, indent=4)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'asv.conf.json'.")

    # __init__.py (empty files)
    dest = os.path.join(location, "__init__.py")
    with open(dest, "w", encoding='utf-8'):
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create '__init__.py'.")
    dest = os.path.join(location_test, '__init__.py')
    with open(dest, "w", encoding='utf-8'):
        pass
    created.append(dest)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'benches/__init__.py'.")

    # flask_server
    tool_dir = os.path.join(location, 'tools')
    if not os.path.exists(tool_dir):
        os.mkdir(tool_dir)
    fl = os.path.join(tool_dir, 'flask_serve.py')
    with open(fl, "w", encoding='utf-8') as f:
        f.write(flask_helper)
    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create 'flask_serve.py'.")

    # command line
    if sys.platform.startswith("win"):
        run_bash = os.path.join(tool_dir, 'run_asv.bat')  # pragma: no cover
    else:
        run_bash = os.path.join(tool_dir, 'run_asv.sh')
    with open(run_bash, 'w', encoding='utf-8') as f:
        f.write(textwrap.dedent("""
            echo --BENCHRUN--
            python -m asv run --show-stderr --config ./asv.conf.json
            echo --PUBLISH--
            python -m asv publish --config ./asv.conf.json -o ./html
            echo --CSV--
            python -m mlprodict asv2csv -f ./results -o ./data_bench.csv
            """))

    # pyspy
    if add_pyspy:
        dest_pyspy = os.path.join(location, 'pyspy')
        if not os.path.exists(dest_pyspy):
            os.mkdir(dest_pyspy)
    else:
        dest_pyspy = None

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] create all tests.")

    created.extend(list(_enumerate_asv_benchmark_all_models(
        location_test, opset_min=opset_min, opset_max=opset_max,
        runtime=runtime, models=models,
        skip_models=skip_models, extended_list=extended_list,
        n_features=n_features, dtype=dtype,
        verbose=verbose, filter_exp=filter_exp,
        filter_scenario=filter_scenario,
        dims=dims, exc=exc, flat=flat,
        fLOG=fLOG, execute=execute,
        dest_pyspy=dest_pyspy)))

    if verbose > 0 and fLOG is not None:
        fLOG("[create_asv_benchmark] done.")
    return created
def _enumerate_asv_benchmark_all_models(  # pylint: disable=R0914
        location, opset_min=10, opset_max=None,
        runtime=('scikit-learn', 'python'), models=None,
        skip_models=None, extended_list=True,
        n_features=None, dtype=None,
        verbose=0, filter_exp=None,
        dims=None, filter_scenario=None,
        exc=True, flat=False, execute=False,
        dest_pyspy=None, fLOG=print):
    """
    Loops over all possible models and fills a folder
    with benchmarks following :epkg:`asv` concepts.
    It yields the name of every created file.

    :param location: folder receiving the benchmarks
    :param opset_min: tries every conversion from this minimum opset
    :param opset_max: tries every conversion up to maximum opset,
        None for ``__max_supported_opset__``
    :param runtime: runtime to check, *scikit-learn*, *python*,
        *onnxruntime1* to check :epkg:`onnxruntime`,
        *onnxruntime2* to check every ONNX node independently
        with onnxruntime, many runtime can be checked at the same time
        if the value is a comma separated list
    :param models: list of models to test or empty
        string to test them all, a non empty string is
        considered as a comma separated list of model names
    :param skip_models: models to skip
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param n_features: number of features to try
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number types
        (not used by this function)
    :param verbose: integer from 0 (None) to 2 (full verbose),
        a value >= 11 logs every model instead of showing a progress bar
    :param filter_exp: function which tells if the experiment must be run,
        None to run all, takes *model, problem* as an input
    :param dims: number of observations to try
    :param filter_scenario: second function which tells if the experiment
        must be run, None to run all, takes *model, problem, scenario, extra,
        conversion options* as an input
    :param exc: if False, raises warnings instead of exceptions
        whenever possible
    :param flat: one folder for all files or subfolders
    :param execute: execute each script to make sure
        imports are correct
    :param dest_pyspy: add a file to profile the prediction
        function with :epkg:`pyspy`
    :param fLOG: logging function
    :raises ValueError: if *models* contains something else than strings
        or if none of the requested models is available
    :raises RuntimeError: if the same file is created twice
    """
    ops = list(sklearn_operators(extended=extended_list))
    patterns = _read_patterns()

    if isinstance(models, str):
        # exact model names, an empty string means every model
        models = [m.strip() for m in models.split(',') if m.strip()] or None
    if models is not None:
        # materialized once, an iterator would be exhausted by the check
        models = list(models)
        if not all(isinstance(m, str) for m in models):
            raise ValueError(
                "models must be a set of strings.")  # pragma: no cover
        wanted = set(models)
        selected = [op for op in ops if op['name'] in wanted]
        if not selected:
            # shows one available entry to help fixing the request
            raise ValueError("Parameter models is wrong: {}\n{}".format(
                models, ops[0] if ops else None))
        ops = selected
    if skip_models is not None:
        ops = [op for op in ops if op['name'] not in skip_models]

    if verbose > 0:

        def iterate():
            for i, row in enumerate(ops):  # pragma: no cover
                fLOG(f"{i + 1}/{len(ops)} - {row}")
                yield row

        if verbose >= 11:
            verbose -= 10  # pragma: no cover
            loop = iterate()  # pragma: no cover
        else:
            try:
                from tqdm import trange

                def iterate_tqdm():
                    with trange(len(ops)) as t:
                        for i in t:
                            row = ops[i]
                            # padded to keep the progress bar aligned
                            t.set_description(row['name'].ljust(28))
                            yield row

                loop = iterate_tqdm()

            except ImportError:  # pragma: no cover
                loop = iterate()
    else:
        loop = ops

    if opset_max is None:
        opset_max = __max_supported_opset__
    opsets = list(range(opset_min, opset_max + 1))
    # every benchmark file must be created only once
    all_created = set()

    # loop on all models
    for row in loop:

        model = row['cl']

        problems, extras = _retrieve_problems_extra(
            model, verbose, fLOG, extended_list)
        if extras is None or problems is None:
            # Not tested yet.
            continue  # pragma: no cover

        # flat or not flat
        created, location_model, prefix_import, dest_pyspy_model = _handle_init_files(
            model, flat, location, verbose, dest_pyspy, fLOG)
        yield from created

        # loops on problems
        for prob in problems:
            if filter_exp is not None and not filter_exp(model, prob):
                continue

            (X_train, X_test, y_train,
             y_test, Xort_test,
             init_types, conv_options, method_name,
             output_index, dofit, predict_kwargs) = _get_problem_data(prob, None)

            for scenario_extra in extras:
                subset_problems = None
                optimisations = None
                new_conv_options = None

                # optional third element: a dictionary of options
                # or directly the list of problems the scenario applies to
                if len(scenario_extra) > 2:
                    options = scenario_extra[2]
                    if isinstance(options, dict):
                        subset_problems = options.get('subset_problems', None)
                        optimisations = options.get('optim', None)
                        new_conv_options = options.get('conv_options', None)
                    else:
                        subset_problems = options

                if subset_problems and isinstance(subset_problems, (list, set)):
                    if prob not in subset_problems:
                        # Skips unrelated problem for a specific configuration.
                        continue
                elif subset_problems is not None:
                    raise RuntimeError(  # pragma: no cover
                        "subset_problems must be a set or a list not {}.".format(
                            subset_problems))

                scenario, extra = scenario_extra[:2]
                if optimisations is None:
                    optimisations = [None]
                if new_conv_options is None:
                    new_conv_options = [{}]

                if (filter_scenario is not None and
                        not filter_scenario(model, prob, scenario,
                                            extra, new_conv_options)):
                    continue  # pragma: no cover

                if verbose >= 3 and fLOG is not None:
                    fLOG("[create_asv_benchmark] model={} scenario={} optim={} extra={} dofit={} (problem={} method_name='{}')".format(
                        model.__name__, scenario, optimisations, extra, dofit, prob, method_name))
                created = _create_asv_benchmark_file(
                    location_model, opsets=opsets,
                    model=model, scenario=scenario, optimisations=optimisations,
                    extra=extra, dofit=dofit, problem=prob,
                    runtime=runtime, new_conv_options=new_conv_options,
                    X_train=X_train, X_test=X_test, y_train=y_train,
                    y_test=y_test, Xort_test=Xort_test,
                    init_types=init_types, conv_options=conv_options,
                    method_name=method_name, dims=dims, n_features=n_features,
                    output_index=output_index, predict_kwargs=predict_kwargs,
                    exc=exc, prefix_import=prefix_import,
                    execute=execute, location_pyspy=dest_pyspy_model,
                    patterns=patterns)
                for cr in created:
                    if cr in all_created:
                        raise RuntimeError(  # pragma: no cover
                            f"File '{cr}' was already created.")
                    all_created.add(cr)
                    if verbose > 1 and fLOG is not None:
                        fLOG(f"[create_asv_benchmark] add '{cr}'.")
                    yield cr
def _create_asv_benchmark_file(  # pylint: disable=R0914
        location, model, scenario, optimisations, new_conv_options,
        extra, dofit, problem, runtime, X_train, X_test, y_train,
        y_test, Xort_test, init_types, conv_options,
        method_name, n_features, dims, opsets,
        output_index, predict_kwargs, prefix_import,
        exc, execute=False, location_pyspy=None, patterns=None):
    """
    Creates a benchmark file based in the information received
    through the argument. It uses one of the templates
    like @see cl TemplateBenchmarkClassifier or
    @see cl TemplateBenchmarkRegressor.
    One file is written for every value in *optimisations*.

    :param location: folder receiving the benchmark scripts
    :param model: model class, its ``__name__`` is used in error messages
    :param scenario: scenario name
    :param optimisations: list of optimisations, one file per value
    :param new_conv_options: list of conversion options, each of them
        is merged with *conv_options*
    :param extra: passed to ``_asv_class_name`` and ``add_model_import_init``
    :param dofit: adds ``par_dofit = False`` to the benchmark when False
    :param problem: problem name, a name containing ``'-64'`` selects
        type double in the benchmark
    :param runtime: runtimes to benchmark, a list or a comma separated
        string of names among *scikit-learn*, *onnxruntime1*,
        *onnxruntime2*, *python*, *python_compiled*
    :param conv_options: conversion options attached to the problem
    :param method_name: name of the method to benchmark
    :param n_features: numbers of features to try
    :param dims: numbers of observations to try
    :param opsets: opsets to try
    :param prefix_import: prefix for relative imports, ``'.'`` leaves
        the template imports unchanged
    :param exc: raises exceptions if True, warnings otherwise
        (whenever possible)
    :param execute: executes every generated script to check imports
    :param location_pyspy: folder receiving scripts to profile every
        configuration with :epkg:`pyspy`, None to skip
    :param patterns: templates, see ``_read_patterns``
    :return: list of created files
    :raises ValueError: if *patterns* is None

    Parameters *X_train*, *X_test*, *y_train*, *y_test*, *Xort_test*,
    *init_types*, *output_index*, *predict_kwargs* are not used
    by this function.
    """
    if patterns is None:
        raise ValueError("Patterns list is empty.")  # pragma: no cover

    def format_conv_options(d_options, class_name):
        # Replaces a class key by the class name, keeps the other keys.
        if d_options is None:
            return None
        res = {}
        for k, v in d_options.items():
            if isinstance(k, type):
                if "." + class_name + "'" in str(k):
                    res[class_name] = v
                    continue
                raise ValueError(  # pragma: no cover
                    f"Class '{class_name}', unable to format options {d_options}")
            res[k] = v
        return res

    def _nick_name_options(model, opts):
        # Shorten common onnx options, see _CommonAsvSklBenchmark._to_onnx.
        if opts is None:
            return opts  # pragma: no cover
        short_opts = shorten_onnx_options(model, opts)
        if short_opts is not None:
            return short_opts
        res = {}
        for k, v in opts.items():
            if hasattr(k, '__name__'):
                # the marks '####' are removed from the script later
                res["####" + k.__name__ + "####"] = v
            else:
                res[k] = v  # pragma: no cover
        return res

    def _make_simple_name(name):
        # Short name used for attributes benchmark_name and pretty_name.
        simple_name = name.replace("bench_", "").replace("_bench", "")
        simple_name = simple_name.replace("bench.", "").replace(".bench", "")
        simple_name = simple_name.replace(".", "-")
        repl = {'_': '', 'solverliblinear': 'liblinear'}
        for k, v in repl.items():
            simple_name = simple_name.replace(k, v)
        return simple_name

    def _optdict2string(opt):
        # Converts options into a short string used in a filename.
        if isinstance(opt, str):
            return opt
        if isinstance(opt, list):
            raise TypeError(
                f"Unable to process type {type(opt)!r}.")
        reps = {True: 1, False: 0, 'zipmap': 'zm',
                'optim': 'opt'}
        info = []
        for k, v in sorted(opt.items()):
            if isinstance(v, dict):
                v = _optdict2string(v)
            if k.startswith('####'):
                k = ''
            i = f'{reps.get(k, k)}{reps.get(v, v)}'
            info.append(i)
        return "-".join(info)

    runtimes_abb = {
        'scikit-learn': 'skl',
        'onnxruntime1': 'ort',
        'onnxruntime2': 'ort2',
        'python': 'pyrt',
        'python_compiled': 'pyrtc',
    }
    if isinstance(runtime, str):
        # a comma separated string of runtimes is allowed
        runtime = [rt.strip() for rt in runtime.split(',') if rt.strip()]
    runtime = [runtimes_abb[k] for k in runtime]

    # Looping over configuration.
    names = []
    for optimisation in optimisations:
        merged_options = [_merge_options(nconv_options, conv_options)
                          for nconv_options in new_conv_options]

        nck_opts = [_nick_name_options(model, opts)
                    for opts in merged_options]
        try:
            name = _asv_class_name(
                model, scenario, optimisation, extra,
                dofit, conv_options, problem,
                shorten=True)
        except ValueError as e:  # pragma: no cover
            if exc:
                raise e
            warnings.warn(str(e))
            continue
        filename = name.replace(".", "_") + ".py"
        try:
            class_content = _select_pattern_problem(problem, patterns)
        except ValueError as e:
            if exc:
                raise e  # pragma: no cover
            warnings.warn(str(e))
            continue
        full_class_name = _asv_class_name(
            model, scenario, optimisation, extra,
            dofit, conv_options, problem,
            shorten=False)
        class_name = name.replace(
            "bench.", "").replace(".", "_") + "_bench"

        # n_features, N, runtimes: replaces the default values of the template
        rep = {
            "['skl', 'pyrtc', 'ort'],  # values for runtime": str(runtime),
            "[1, 10, 100, 1000, 10000],  # values for N": str(dims),
            "[4, 20],  # values for nf": str(n_features),
            "[__max_supported_opset__],  # values for opset": str(opsets),
            "['float', 'double'],  # values for dtype":
                "['float']" if '-64' not in problem else "['double']",
            "[None],  # values for optim": f"{nck_opts!r}",
        }
        for k, v in rep.items():
            if k not in class_content:
                raise ValueError("Unable to find '{}'\n{}.".format(  # pragma: no cover
                    k, class_content))
            class_content = class_content.replace(k, v + ',')
        class_content = class_content.split(
            "def _create_model(self):")[0].strip("\n ")
        if "####" in class_content:
            class_content = class_content.replace(
                "'####", "").replace("####'", "")
        if "####" in class_content:
            raise RuntimeError(  # pragma: no cover
                "Substring '####' should not be part of the script for '{}'\n{}".format(
                    model.__name__, class_content))

        # Model setup
        class_content, atts = add_model_import_init(
            class_content, model, optimisation,
            extra, merged_options)
        class_content = class_content.replace(
            "class TemplateBenchmark",
            f"class {class_name}")

        # dtype, dofit
        atts.append(f"chk_method_name = {method_name!r}")
        atts.append(f"par_scenario = {scenario!r}")
        atts.append(f"par_problem = {problem!r}")
        atts.append(f"par_optimisation = {optimisation!r}")
        if not dofit:
            atts.append("par_dofit = False")
        if merged_options:
            atts.append("par_convopts = %r" % format_conv_options(
                conv_options, model.__name__))
        atts.append(f"par_full_test_name = {full_class_name!r}")

        simple_name = _make_simple_name(name)
        atts.append(f"benchmark_name = {simple_name!r}")
        atts.append(f"pretty_name = {simple_name!r}")

        if atts:
            # the attributes are class attributes, indented by 4 spaces
            class_content = class_content.replace(
                "# additional parameters",
                "\n    ".join(atts))
        if prefix_import != '.':
            class_content = class_content.replace(
                " from .", f"from .{prefix_import}")

        # Check compilation
        try:
            compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__, class_content)) from e

        # Verifies missing imports.
        to_import, _ = verify_code(class_content, exc=False)
        try:
            miss = find_missing_sklearn_imports(to_import)
        except ValueError as e:  # pragma: no cover
            raise ValueError(
                f"Unable to check import in script\n{class_content}") from e
        class_content = class_content.replace(
            "# __IMPORTS__", "\n".join(miss))
        verify_code(class_content, exc=True)
        class_content = class_content.replace(
            "par_extra = {", "par_extra = {\n")
        class_content = remove_extra_spaces_and_pep8(
            class_content, aggressive=True)

        # Check compilation again
        try:
            obj = compile(class_content, filename, 'exec')
        except SyntaxError as e:  # pragma: no cover
            raise SyntaxError("Unable to compile model '{}'\n{}".format(
                model.__name__,
                _display_code_lines(class_content))) from e

        # executes to check import
        if execute:
            try:
                exec(obj, globals(), locals())  # pylint: disable=W0122
            except Exception as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to process class '{}' ('{}') a script due to '{}'\n{}".format(
                        model.__name__, filename, str(e),
                        _display_code_lines(class_content))) from e

        # Saves
        fullname = os.path.join(location, filename)
        names.append(fullname)
        with open(fullname, "w", encoding='utf-8') as f:
            f.write(class_content)

        if location_pyspy is not None:
            # adding configuration for pyspy
            from itertools import product

            pyspy_class = re.compile(
                'class ([A-Za-z_0-9]+)[(]').findall(class_content)[0]
            fullname_pyspy = os.path.splitext(
                os.path.join(location_pyspy, filename))[0]
            pyfold = os.path.splitext(os.path.split(fullname)[-1])[0]
            # common part of every profiling script of this file
            base_tmpl = pyspy_template.replace('__PATH__', location)
            base_tmpl = base_tmpl.replace('__CLASSNAME__', pyspy_class)
            base_tmpl = base_tmpl.replace('__PYFOLD__', pyfold)
            ext = '.bat' if sys.platform.startswith('win') else '.sh'

            # NOTE(review): the benchmark only declares 'double' for problems
            # containing '-64' but both types are profiled here, confirm.
            dtypes = ['float', 'double'] if '-64' in problem else ['float']
            # same order as nested loops dims > n_features > opsets > ...
            for dim, nf, opset, dtype, opt in product(
                    dims, n_features, opsets, dtypes, nck_opts):
                opt = "" if opt == {} else opt
                tmpl = base_tmpl
                for index, rt in enumerate(runtime):
                    if index == 0:
                        # the first runtime also initializes variable iter
                        tmpl += textwrap.dedent("""

                            def profile0_{rt}(iter, cl, N, nf, opset, dtype, optim):
                                return setup_profile0(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                            iter = profile0_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                            print(datetime.now(), "iter", iter)

                            """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                        dtype=dtype, opt=f"{opt!r}")

                    tmpl += textwrap.dedent("""

                        def profile_{rt}(iter, cl, N, nf, opset, dtype, optim):
                            return setup_profile(iter, cl, '{rt}', N, nf, opset, dtype, optim)
                        profile_{rt}(iter, cl, {dim}, {nf}, {opset}, '{dtype}', {opt})
                        print(datetime.now(), "iter", iter)

                        """).format(rt=rt, dim=dim, nf=nf, opset=opset,
                                    dtype=dtype, opt=f"{opt!r}")

                thename = "{n}_{dim}_{nf}_{opset}_{dtype}_{opt}.py".format(
                    n=fullname_pyspy, dim=dim, nf=nf,
                    opset=opset, dtype=dtype, opt=_optdict2string(opt))
                with open(thename, 'w', encoding='utf-8') as f:
                    f.write(tmpl)
                names.append(thename)

                # command lines running py-spy on the script above
                script = os.path.splitext(thename)[0] + ext
                short = os.path.splitext(
                    os.path.split(thename)[-1])[0]
                with open(script, 'w', encoding='utf-8') as f:
                    f.write('py-spy record --native --function --rate=10 -o {n}_fct.svg -- {py} {n}.py\n'.format(
                        py=sys.executable, n=short))
                    f.write('py-spy record --native --rate=10 -o {n}_line.svg -- {py} {n}.py\n'.format(
                        py=sys.executable, n=short))

    return names