Coverage for mlprodict/cli/validate.py: 95%

252 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Command line about validation of prediction runtime. 

4""" 

5import os 

6from io import StringIO 

7from logging import getLogger 

8import warnings 

9import json 

10from multiprocessing import Pool 

11from pandas import DataFrame, read_csv, concat 

12from sklearn.exceptions import ConvergenceWarning 

13 

14 

def benchmark_doc(runtime, black_list=None, white_list=None,
                  out_raw='bench_raw.xlsx', out_summary="bench_summary.xlsx",
                  dump_dir='dump', fLOG=print, verbose=0):
    """
    Runs the benchmark published into the documentation
    (see :ref:`l-onnx-bench-onnxruntime1` and
    :ref:`l-onnx-bench-python_compiled`).
    Every model is benchmarked through its own command line
    ``python -m mlprodict validate_runtime ...`` and the per-model
    results are then concatenated.

    :param runtime: runtime (python, python_compiled,
        onnxruntime1, onnxruntime2)
    :param black_list: models to skip, None for none
        (comma separated list, a list of names is accepted as well)
    :param white_list: models to benchmark, None for all
        (comma separated list, a list of names is accepted as well)
    :param out_raw: all results are saved in that file
        (extension must be ``.xlsx`` or ``.csv``)
    :param out_summary: all results are summarized in that file
        (extension must be ``.xlsx`` or ``.csv``)
    :param dump_dir: folder where to dump intermediate results
    :param fLOG: logging function
    :param verbose: verbosity
    :return: list of created files, one pair
        ``(raw file, summary file)`` per benchmarked model
    """
    def _save(df, name):
        # Writes *df* in the format given by the extension of *name*.
        ext = os.path.splitext(name)[-1]
        if ext == '.xlsx':
            df.to_excel(name, index=False)
        elif ext == '.csv':
            df.to_csv(name, index=False)
        else:
            raise ValueError(  # pragma: no cover
                f"Unexpected extension in {name!r}.")
        if verbose > 1:
            fLOG(  # pragma: no cover
                f"[mlprodict] wrote '{name}'")

    def _as_list(value):
        # Accepts None, a comma separated string or any iterable of names.
        if value is None:
            return []
        if isinstance(value, str):
            return value.split(',')
        return list(value)

    from pyquickhelper.loghelper import run_cmd
    from pyquickhelper.loghelper.run_cmd import get_interpreter_path
    from tqdm import tqdm
    from ..onnxrt.validate.validate_helper import sklearn_operators
    from ..onnx_conv import (
        register_converters, register_rewritten_operators, register_new_operators)
    register_converters()
    try:
        register_rewritten_operators()
        register_new_operators()
    except KeyError:  # pragma: no cover
        warnings.warn("converter for HistGradientBoosting* does not exist. "
                      "Upgrade sklearn-onnx")

    black_list = _as_list(black_list)
    white_list = _as_list(white_list)

    filenames = []
    skls = [_['name'] for _ in sklearn_operators(extended=True)]
    if white_list:
        skls = [_ for _ in skls if _ in white_list]
    skls.sort()
    pbar = tqdm(skls) if verbose > 0 else skls

    # The crash fallback below writes a summary into dump_dir,
    # the folder must exist even if the command line failed before
    # creating it.
    if dump_dir:
        os.makedirs(dump_dir, exist_ok=True)

    for op in pbar:
        if op in black_list:
            continue
        if verbose > 0:
            pbar.set_description(  # pragma: no cover
                f"[{op + ' ' * (25 - len(op))}]")

        loop_out_raw = os.path.join(
            dump_dir, f"bench_raw_{runtime}_{op}.csv")
        loop_out_sum = os.path.join(
            dump_dir, f"bench_sum_{runtime}_{op}.csv")
        cmd = ('{0} -m mlprodict validate_runtime --verbose=0 --out_raw={1} --out_summary={2} '
               '--benchmark=1 --dump_folder={3} --runtime={4} --models={5}'.format(
                   get_interpreter_path(), loop_out_raw, loop_out_sum, dump_dir, runtime, op))
        if verbose > 1:
            fLOG(f"[mlprodict] cmd '{cmd}'.")  # pragma: no cover
        out, err = run_cmd(cmd, wait=True, fLOG=None)
        if not os.path.exists(loop_out_sum):  # pragma: no cover
            # The command line did not produce any summary: a CRASH row
            # is stored instead so that the model still shows up
            # in the final report.
            if verbose > 2:
                fLOG(f"[mlprodict] unable to find '{loop_out_sum}'.")
            if verbose > 1:
                fLOG(f"[mlprodict] cmd '{cmd}'")
                fLOG(f"[mlprodict] unable to find '{loop_out_sum}'")
            msg = "Unable to find '{}'\n--CMD--\n{}\n--OUT--\n{}\n--ERR--\n{}".format(
                loop_out_sum, cmd, out, err)
            if verbose > 1:
                fLOG(msg)
            rows = [{'name': op, 'scenario': 'CRASH',
                     'ERROR-msg': msg.replace("\n", " -- ")}]
            df = DataFrame(rows)
            df.to_csv(loop_out_sum, index=False)
        filenames.append((loop_out_raw, loop_out_sum))

    # concatenate summaries
    dfs_raw = [read_csv(name[0])
               for name in filenames if os.path.exists(name[0])]
    dfs_sum = [read_csv(name[1])
               for name in filenames if os.path.exists(name[1])]
    # concat raises ValueError on an empty list, which happens when
    # no raw file was produced (every model crashed) or when
    # no model was selected.
    df_raw = concat(dfs_raw, sort=False) if dfs_raw else DataFrame()
    piv = concat(dfs_sum, sort=False) if dfs_sum else DataFrame()

    # Reorders the columns: most recent opset first, then the
    # descriptive columns, the benchmark ratios, the other opsets,
    # the min/max ratios and finally the skl_*/onx_* columns.
    opset_cols = [(int(oc.replace("opset", "")), oc)
                  for oc in piv.columns if 'opset' in oc]
    opset_cols.sort(reverse=True)
    opset_cols = [oc[1] for oc in opset_cols]
    new_cols = opset_cols[:1]
    bench_cols = ["RT/SKL-N=1", "N=10", "N=100",
                  "N=1000", "N=10000"]
    new_cols.extend(["ERROR-msg", "name", "problem", "scenario", 'optim'])
    new_cols.extend(bench_cols)
    new_cols.extend(opset_cols[1:])
    for c in bench_cols:
        new_cols.append(c + '-min')
        new_cols.append(c + '-max')
    for c in piv.columns:
        if c.startswith(("skl_", "onx_")):
            new_cols.append(c)
    # Only keeps the columns the summaries really contain.
    new_cols = [_ for _ in new_cols if _ in piv.columns]
    piv = piv[new_cols]

    _save(piv, out_summary)
    _save(df_raw, out_raw)
    return filenames

145 

146 

def validate_runtime(verbose=1, opset_min=-1, opset_max="",
                     check_runtime=True, runtime='python', debug=False,
                     models=None, out_raw="model_onnx_raw.xlsx",
                     out_summary="model_onnx_summary.xlsx",
                     dump_folder=None, dump_all=False, benchmark=False,
                     catch_warnings=True, assume_finite=True,
                     versions=False, skip_models=None,
                     extended_list=True, separate_process=False,
                     time_kwargs=None, n_features=None, fLOG=print,
                     out_graph=None, force_return=False,
                     dtype=None, skip_long_test=False,
                     number=1, repeat=1, time_kwargs_fact='lin',
                     time_limit=4, n_jobs=0):
    """
    Walks through most of :epkg:`scikit-learn` operators
    or model or predictor or transformer, tries to convert
    them into :epkg:`ONNX` and computes the predictions
    with a specific runtime.

    :param verbose: integer from 0 (None) to 2 (full verbose)
    :param opset_min: tries every conversion from this minimum opset,
        -1 to get the current opset
    :param opset_max: tries every conversion up to maximum opset,
        -1 to get the current opset
    :param check_runtime: to check the runtime
        and not only the conversion
    :param runtime: runtime to check, python,
        onnxruntime1 to check :epkg:`onnxruntime`,
        onnxruntime2 to check every *ONNX* node independently
        with onnxruntime, many runtimes can be checked at the same time
        if the value is a comma separated list
    :param models: comma separated list of models to test or empty
        string to test them all
    :param skip_models: models to skip
    :param debug: stops whenever an exception is raised,
        only if *separate_process* is False
    :param out_raw: output raw results into this file (excel format)
    :param out_summary: output an aggregated view into this file (excel format)
    :param dump_folder: folder where to dump information (pickle)
        in case of mismatch
    :param dump_all: dumps all models, not only the failing ones
    :param benchmark: run benchmark
    :param catch_warnings: catch warnings
    :param assume_finite: See `config_context
        <https://scikit-learn.org/stable/modules/generated/sklearn.config_context.html>`_,
        If True, validation for finiteness will be skipped, saving time, but leading
        to potential crashes. If False, validation for finiteness will be performed,
        avoiding error.
    :param versions: add columns with versions of used packages,
        :epkg:`numpy`, :epkg:`scikit-learn`, :epkg:`onnx`, :epkg:`onnxruntime`,
        :epkg:`sklearn-onnx`
    :param extended_list: extends the list of :epkg:`scikit-learn` converters
        with converters implemented in this module
    :param separate_process: run every model in a separate process,
        this option must be used to run all models in one row
        even if one of them is crashing
    :param time_kwargs: a dictionary which defines the number of rows and
        the parameter *number* and *repeat* when benchmarking a model,
        the value must follow :epkg:`json` format
    :param n_features: change the default number of features for
        a specific problem, it can also be a comma separated list
    :param force_return: forces the function to return the results,
        used when the results are produced through a separate process
    :param out_graph: image name, to output a graph which summarizes
        a benchmark in case it was run
    :param dtype: '32' or '64' or None for both,
        limits the test to one specific number type
    :param skip_long_test: skips tests for high values of N if
        they seem too long
    :param number: to multiply number values in *time_kwargs*
    :param repeat: to multiply repeat values in *time_kwargs*
    :param time_kwargs_fact: to multiply number and repeat in
        *time_kwargs* depending on the model
        (see :func:`_multiply_time_kwargs <mlprodict.onnxrt.validate.validate_helper._multiply_time_kwargs>`)
    :param time_limit: to stop benchmarking after this limit of time
    :param n_jobs: force the number of jobs to have this value,
        by default, it is equal to the number of CPU
    :param fLOG: logging function
    :return: the list of rows if *force_return* is True,
        *verbose* is at least 2 or *separate_process* is True,
        None otherwise

    .. cmdref::
        :title: Validates a runtime against scikit-learn
        :cmd: -m mlprodict validate_runtime --help
        :lid: l-cmd-validate_runtime

        The command walks through all scikit-learn operators,
        tries to convert them, checks the predictions,
        and produces a report.

        Example::

            python -m mlprodict validate_runtime --models LogisticRegression,LinearRegression

        Following example benchmarks models
        :epkg:`sklearn:ensemble:RandomForestRegressor`,
        :epkg:`sklearn:tree:DecisionTreeRegressor`, it compares
        :epkg:`onnxruntime` against :epkg:`scikit-learn` for opset 10.

        ::

            python -m mlprodict validate_runtime -v 1 -o 10 -op 10 -c 1 -r onnxruntime1
                   -m RandomForestRegressor,DecisionTreeRegressor -out bench_onnxruntime.xlsx -b 1

        Parameter ``--time_kwargs`` may be used to reduce or increase
        benchmark precision. The following value tells the function
        to run a benchmark with datasets of 1 or 10 rows, to repeat
        a given number of times *number* predictions in one row.
        The total time is divided by :math:`number \\times repeat`.
        Parameter ``--time_kwargs_fact`` may be used to increase these
        numbers for some specific models. ``'lin'`` multiplies
        *number* by 10 when the model is linear.

        ::

            -t "{\\"1\\":{\\"number\\":10,\\"repeat\\":10},\\"10\\":{\\"number\\":5,\\"repeat\\":5}}"

        The following example dumps every model in the list:

        ::

            python -m mlprodict validate_runtime --out_raw raw.csv --out_summary sum.csv
                   --models LinearRegression,LogisticRegression,DecisionTreeRegressor,DecisionTreeClassifier
                   -r python,onnxruntime1 -o 10 -op 10 -v 1 -b 1 -dum 1
                   -du model_dump -n 20,100,500 --out_graph benchmark.png --dtype 32

        The command line generates a graph produced by function
        :func:`plot_validate_benchmark
        <mlprodict.onnxrt.validate.validate_graph.plot_validate_benchmark>`.
    """
    # The verbosity is compared to integers below (and in the separate
    # process implementation), it must be converted first when it comes
    # from the command line as a string.
    if isinstance(verbose, str):
        verbose = int(verbose)  # pragma: no cover

    if separate_process:
        # NOTE(review): out_graph is deliberately replaced by None here,
        # as in the original code, so no graph is produced in that mode;
        # confirm this is intended.
        return _validate_runtime_separate_process(
            verbose=verbose, opset_min=opset_min, opset_max=opset_max,
            check_runtime=check_runtime, runtime=runtime, debug=debug,
            models=models, out_raw=out_raw,
            out_summary=out_summary, dump_all=dump_all,
            dump_folder=dump_folder, benchmark=benchmark,
            catch_warnings=catch_warnings, assume_finite=assume_finite,
            versions=versions, skip_models=skip_models,
            extended_list=extended_list, time_kwargs=time_kwargs,
            n_features=n_features, fLOG=fLOG, force_return=True,
            out_graph=None, dtype=dtype, skip_long_test=skip_long_test,
            number=number, repeat=repeat,
            time_kwargs_fact=time_kwargs_fact, time_limit=time_limit,
            n_jobs=n_jobs)

    if not isinstance(models, list):
        models = (None if models in (None, "")
                  else models.strip().split(','))
    if not isinstance(skip_models, list):
        skip_models = ([] if skip_models in (None, "")
                       else skip_models.strip().split(','))
    if verbose <= 1:
        logger = getLogger('skl2onnx')
        logger.disabled = True
    if not dump_folder:
        dump_folder = None
    if dump_folder and not os.path.exists(dump_folder):
        os.mkdir(dump_folder)  # pragma: no cover
    if dump_folder and not os.path.exists(dump_folder):
        raise FileNotFoundError(  # pragma: no cover
            f"Cannot find dump_folder '{dump_folder}'.")

    # handling parameters
    if opset_max == "":
        opset_max = None  # pragma: no cover
    if isinstance(opset_min, str):
        opset_min = int(opset_min)  # pragma: no cover
    if isinstance(opset_max, str):
        opset_max = int(opset_max)
    if isinstance(extended_list, str):
        extended_list = extended_list in (
            '1', 'True', 'true')  # pragma: no cover
    if time_kwargs in (None, ''):
        time_kwargs = None
    if isinstance(time_kwargs, str):
        time_kwargs = json.loads(time_kwargs)
        # json only allows string as keys
        time_kwargs = {int(k): v for k, v in time_kwargs.items()}
    if isinstance(n_jobs, str):
        n_jobs = int(n_jobs)
    if n_jobs == 0:
        n_jobs = None
    if time_kwargs is not None and not isinstance(time_kwargs, dict):
        raise ValueError(  # pragma: no cover
            f"time_kwargs must be a dictionary not {type(time_kwargs)}\n{time_kwargs}")
    if not isinstance(n_features, list):
        if n_features in (None, ""):
            n_features = None
        elif isinstance(n_features, str) and ',' in n_features:
            n_features = list(map(int, n_features.split(',')))
        else:
            # a single value, either a string or already a number
            n_features = int(n_features)
    if not isinstance(runtime, list) and ',' in runtime:
        runtime = runtime.split(',')

    def fct_filter_exp(m, s):
        # Rejects a model or a pair model[scenario] listed in skip_models.
        cl = m.__name__
        if cl in skip_models:
            return False
        pair = f"{cl}[{s}]"
        if pair in skip_models:
            return False
        return True

    if dtype in ('', None):
        fct_filter = fct_filter_exp
    elif dtype == '32':
        def fct_filter_exp2(m, p):
            return fct_filter_exp(m, p) and '64' not in p
        fct_filter = fct_filter_exp2
    elif dtype == '64':  # pragma: no cover
        def fct_filter_exp3(m, p):
            return fct_filter_exp(m, p) and '64' in p
        fct_filter = fct_filter_exp3
    else:
        raise ValueError(  # pragma: no cover
            f"dtype must be empty, 32, 64 not '{dtype}'.")

    # time_kwargs

    if benchmark:
        if time_kwargs is None:
            from ..onnxrt.validate.validate_helper import default_time_kwargs  # pylint: disable=E0402
            time_kwargs = default_time_kwargs()
        else:
            # works on a copy, the dictionary given by the caller
            # must not be modified by the multiplication below
            time_kwargs = {k: dict(v) for k, v in time_kwargs.items()}
        for v in time_kwargs.values():
            v['number'] *= number
            v['repeat'] *= repeat
        if verbose > 0:
            fLOG(f"time_kwargs={time_kwargs!r}")

    # body

    # Imported once every argument is validated so that a wrong
    # argument fails fast.
    from ..onnxrt.validate import enumerate_validated_operator_opsets  # pylint: disable=E0402

    def build_rows(models_):
        # NOTE(review): check_runtime is not forwarded here (same as the
        # original code); confirm whether the callee should receive it.
        rows = list(enumerate_validated_operator_opsets(
            verbose, models=models_, fLOG=fLOG, runtime=runtime, debug=debug,
            dump_folder=dump_folder, opset_min=opset_min, opset_max=opset_max,
            benchmark=benchmark, assume_finite=assume_finite, versions=versions,
            extended_list=extended_list, time_kwargs=time_kwargs, dump_all=dump_all,
            n_features=n_features, filter_exp=fct_filter,
            skip_long_test=skip_long_test, time_limit=time_limit,
            time_kwargs_fact=time_kwargs_fact, n_jobs=n_jobs))
        return rows

    def catch_build_rows(models_):
        if catch_warnings:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore",
                                      (UserWarning, ConvergenceWarning,
                                       RuntimeWarning, FutureWarning))
                rows = build_rows(models_)
        else:
            rows = build_rows(models_)  # pragma: no cover
        return rows

    rows = catch_build_rows(models)
    res = _finalize(rows, out_raw, out_summary,
                    verbose, models, out_graph, fLOG)
    return res if (force_return or verbose >= 2) else None

407 

408 

def _finalize(rows, out_raw, out_summary, verbose, models, out_graph, fLOG):
    """
    Saves the raw results, builds and saves the summary,
    optionally draws a graph.

    :param rows: list of dictionaries, keys containing ``'lambda'``
        are removed in place because they cannot be serialized
    :param out_raw: file receiving the raw results, skipped if empty
    :param out_summary: file receiving the summary, skipped if empty
    :param verbose: verbosity
    :param models: the summary is displayed only if this is not None
        and *verbose* is above 1
    :param out_graph: image name, no graph if None
    :param fLOG: logging function
    :return: *rows*
    :raises RuntimeError: if there is no row or if the summary
        has no column ``'optim'``
    """
    from ..onnxrt.validate import summary_report  # pylint: disable=E0402
    from ..tools.cleaning import clean_error_msg  # pylint: disable=E0402

    def _write(frame, filename):
        # excel for .xlsx, csv (with cleaned error messages) otherwise
        if os.path.splitext(filename)[-1] == ".xlsx":
            frame.to_excel(filename, index=False)
        else:
            clean_error_msg(frame).to_csv(filename, index=False)

    # Drops data which cannot be serialized.
    for row in rows:
        for key in [k for k in row if 'lambda' in k]:
            del row[key]

    df = DataFrame(rows)

    if out_raw:
        if verbose > 0:
            fLOG(f"Saving raw_data into '{out_raw}'.")
        _write(df, out_raw)

    if df.shape[0] == 0:
        raise RuntimeError("No result produced by the benchmark.")
    piv = summary_report(df)
    if 'optim' not in piv:
        raise RuntimeError(  # pragma: no cover
            f"Unable to produce a summary. Missing column in \n{piv.columns}")

    if out_summary:
        if verbose > 0:
            fLOG(f"Saving summary into '{out_summary}'.")
        _write(piv, out_summary)

    if models is not None and verbose > 1:
        fLOG(piv.T)
    if out_graph is not None:
        if verbose > 0:
            fLOG(f"Saving graph into '{out_graph}'.")
        from ..plotting.plotting import plot_validate_benchmark
        figure = plot_validate_benchmark(piv)[0]
        figure.savefig(out_graph)

    return rows

457 

458 

def _validate_runtime_dict(kwargs):
    """
    Calls :func:`validate_runtime` with the parameters stored
    in dictionary *kwargs*, this module-level function with a single
    argument is the one submitted to the process pool.

    :param kwargs: named parameters of :func:`validate_runtime`
    :return: what :func:`validate_runtime` returns
    """
    return validate_runtime(**kwargs)

461 

462 

def _validate_runtime_separate_process(**kwargs):
    """
    Runs :func:`validate_runtime` for every model in its own process
    so that a crashing model does not stop the others, then merges
    all the results.

    :param kwargs: named parameters of :func:`validate_runtime`,
        keys ``models``, ``skip_models``, ``verbose``, ``fLOG``,
        ``out_raw``, ``out_summary`` are required,
        *models* and *skip_models* can be None, an empty string,
        a comma separated list or a list of names
    :return: list of rows, a model which fails or exceeds the timeout
        is represented by a single row with scenario ``'CRASH'``
    """
    models = kwargs['models']
    if models in (None, ""):
        from ..onnxrt.validate.validate_helper import sklearn_operators  # pragma: no cover
        models = [_['name']
                  for _ in sklearn_operators(extended=True)]  # pragma: no cover
    elif not isinstance(models, list):
        models = models.strip().split(',')

    skip_models = kwargs['skip_models']
    # a list is accepted as it is, as in validate_runtime
    if not isinstance(skip_models, list):
        skip_models = ([] if skip_models in (None, "")
                       else skip_models.strip().split(','))

    verbose = kwargs['verbose']
    fLOG = kwargs['fLOG']
    all_rows = []
    skls = sorted(m for m in models if m not in skip_models)

    if verbose > 0:
        from tqdm import tqdm
        pbar = tqdm(skls)
    else:
        pbar = skls  # pragma: no cover

    for op in pbar:
        if not isinstance(pbar, list):
            pbar.set_description(f"[{op + ' ' * (25 - len(op))}]")

        # every model writes its own files: <name>_<model><ext>
        if kwargs['out_raw']:
            root, ext = os.path.splitext(kwargs['out_raw'])
            out_raw = "".join([root, "_", op, ext])
        else:
            out_raw = None  # pragma: no cover

        if kwargs['out_summary']:
            root, ext = os.path.splitext(kwargs['out_summary'])
            out_summary = "".join([root, "_", op, ext])
        else:
            out_summary = None  # pragma: no cover

        new_kwargs = kwargs.copy()
        # fLOG is not sent to the other process,
        # validate_runtime then uses its default logging function
        new_kwargs.pop('fLOG', None)
        new_kwargs['out_raw'] = out_raw
        new_kwargs['out_summary'] = out_summary
        new_kwargs['models'] = op
        new_kwargs['verbose'] = 0  # tqdm fails
        new_kwargs['out_graph'] = None

        with Pool(1) as p:
            try:
                result = p.apply_async(_validate_runtime_dict, [new_kwargs])
                lrows = result.get(timeout=150)  # timeout fixed to 150s
                all_rows.extend(lrows)
            except Exception as e:  # pylint: disable=W0703
                # Any failure (exception in the process, timeout) is
                # recorded and the loop moves on to the next model.
                all_rows.append({  # pragma: no cover
                    'name': op, 'scenario': 'CRASH',
                    'ERROR-msg': str(e).replace("\n", " -- ")
                })

    return _finalize(all_rows, kwargs['out_raw'], kwargs['out_summary'],
                     verbose, models, kwargs.get('out_graph', None), fLOG)

526 

527 

def latency(model, law='normal', size=1, number=10, repeat=10, max_time=0,
            runtime="onnxruntime", device='cpu', fmt=None,
            profiling=None, profile_output='profiling.csv'):
    """
    Measures the latency of a model (python API).

    :param model: ONNX graph (path to an existing file)
    :param law: random law used to generate fake inputs,
        only ``'normal'`` is supported
    :param size: batch size, it replaces the first dimension
        of every input if it is left unknown
    :param number: number of calls to measure
    :param repeat: number of times to repeat the experiment
    :param max_time: if it is > 0, it runs as many times as possible
        during that period of time
    :param runtime: available runtime
    :param device: device, `cpu`, `cuda:0` or a list of providers
        `CPUExecutionProvider, CUDAExecutionProvider`
    :param fmt: None or `csv`, it then
        returns a string formatted like a csv file
    :param profiling: if specified, profiles the execution of every
        node, it can be sorted by name or type,
        the value for this parameter should be in `(None, 'name', 'type')`
    :param profile_output: output name for the profiling
        if profiling is specified, the extension must be
        ``.csv`` or ``.xlsx``
    :return: the measures, or a csv string if *fmt* is ``'csv'``
    :raises FileNotFoundError: if *model* does not exist
    :raises ValueError: if *profiling*, *law*, *profile_output*
        or *fmt* has an unexpected value

    .. cmdref::
        :title: Measures model latency
        :cmd: -m mlprodict latency --help
        :lid: l-cmd-latency

        The command generates random inputs and calls the
        model many times on these inputs. It returns the processing
        time for one iteration.

        Example::

            python -m mlprodict latency --model "model.onnx"
    """
    if not os.path.exists(model):
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find model {model!r}.")
    if profiling not in (None, '', 'name', 'type'):
        raise ValueError(  # pragma: no cover
            f"Unexpected value for profiling: {profiling!r}.")
    size = int(size)
    number = int(number)
    repeat = int(repeat)
    if max_time in (None, 0, ""):
        max_time = None
    else:
        max_time = float(max_time)
        if max_time <= 0:
            max_time = None

    if law != "normal":
        raise ValueError(  # pragma: no cover
            f"Only law='normal' is supported, not {law!r}.")

    # Every output option is checked before the measure starts,
    # a wrong value should not be discovered after the benchmark ran.
    do_profile = profiling in ('name', 'type')
    ext = None
    if do_profile:
        if profile_output in (None, ''):
            raise ValueError(  # pragma: no cover
                f'profiling is enabled but profile_output is wrong ({profile_output!r}).')
        ext = os.path.splitext(profile_output)[-1]
        if ext not in ('.csv', '.xlsx'):
            raise ValueError(  # pragma: no cover
                f"Unexpected extension for profile_output={profile_output!r}.")
    if fmt not in (None, '', 'csv'):
        raise ValueError(  # pragma: no cover
            f"Unexpected value for fmt: {fmt!r}.")

    from ..onnxrt.validate.validate_latency import latency as _latency  # pylint: disable=E0402

    res = _latency(
        model, law=law, size=size, number=number, repeat=repeat,
        max_time=max_time, runtime=runtime, device=device,
        profiling=profiling)

    if do_profile:
        # with profiling, the measure comes with the profiling table
        res, gr = res
        gr = gr.reset_index(drop=False)
        if ext == '.csv':
            gr.to_csv(profile_output, index=False)
        else:  # '.xlsx', checked above  # pragma: no cover
            gr.to_excel(profile_output, index=False)

    if fmt == 'csv':
        st = StringIO()
        df = DataFrame([res])
        df.to_csv(st, index=False)
        return st.getvalue()
    return res