Coverage for mlprodict/onnxrt/ops_cpu/op_batch_normalization.py: 92%
40 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from onnx.defs import onnx_opset_version
9from ._op import OpRun
12def _batchnorm_test_mode(x, s, bias, mean, var, epsilon=1e-5):
13 dims_x = len(x.shape)
14 dim_ones = (1,) * (dims_x - 2)
15 s = s.reshape(-1, *dim_ones)
16 bias = bias.reshape(-1, *dim_ones)
17 mean = mean.reshape(-1, *dim_ones)
18 var = var.reshape(-1, *dim_ones)
19 y = s * (x - mean) / numpy.sqrt(var + epsilon) + bias
20 return y.astype(x.dtype)
def _batchnorm_training_mode(x, s, bias, mean, var, momentum=0.9,
                             epsilon=1e-5):
    """Training-mode batch normalization.

    Batch statistics are computed over every axis except the channel
    axis (axis 1).  The running estimates blend the incoming *mean* /
    *var* with the batch statistics using *momentum*.  Returns the
    tuple ``(y, saved_mean, saved_var, output_mean, output_var)``,
    every element cast to ``x.dtype``.
    """
    reduce_axes = tuple(a for a in range(len(x.shape)) if a != 1)
    batch_mean = x.mean(axis=reduce_axes)
    batch_var = x.var(axis=reduce_axes)
    # Exponentially weighted update of the running statistics.
    running_mean = mean * momentum + batch_mean * (1 - momentum)
    running_var = var * momentum + batch_var * (1 - momentum)
    # Normalize with the *batch* statistics (inlined test-mode formula).
    shape = (-1, ) + (1, ) * (len(x.shape) - 2)
    y = (s.reshape(shape) * (x - batch_mean.reshape(shape))
         / numpy.sqrt(batch_var.reshape(shape) + epsilon)
         + bias.reshape(shape))
    dtype = x.dtype
    return (y.astype(dtype), batch_mean.astype(dtype),
            batch_var.astype(dtype), running_mean.astype(dtype),
            running_var.astype(dtype))
class BatchNormalization_9(OpRun):
    """Runtime for ONNX BatchNormalization before opset 14
    (inference mode only)."""

    atts = {'epsilon': 1e-5, 'momentum': 0.9}

    def __init__(self, onnx_node, desc=None, **options):
        # Use this class' own attribute table.  Referring to the
        # module-level alias ``BatchNormalization`` would pick up the
        # opset-14 table (which adds ``training_mode``) whenever the
        # installed onnx declares opset >= 14.
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=BatchNormalization_9.atts,
                       **options)

    def _run(self, x, scale, bias, mean, var, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """Applies inference-mode batch normalization, returns ``(y, )``."""
        res = _batchnorm_test_mode(
            x, scale, bias, mean, var, epsilon=self.epsilon)
        return (res, )
class BatchNormalization_14(OpRun):
    """Runtime for ONNX BatchNormalization from opset 14 on,
    dispatching on attribute ``training_mode``."""

    atts = {'epsilon': 1e-5, 'momentum': 0.9, 'training_mode': 0}

    def __init__(self, onnx_node, desc=None, **options):
        # Reference this class' own attribute table instead of the
        # module-level alias ``BatchNormalization`` so the expected
        # attributes never depend on which class the alias is bound to.
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=BatchNormalization_14.atts,
                       **options)

    def _run(self, x, scale, bias, mean, var, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """Inference mode (``training_mode == 0``) returns ``(y, )``;
        training mode returns ``(y, running_mean, running_var)``."""
        if self.training_mode == 0:
            res = _batchnorm_test_mode(
                x, scale, bias, mean, var, epsilon=self.epsilon)
            return (res, )
        res, __, _, output_mean, output_var = (
            _batchnorm_training_mode(x, scale, bias, mean, var,
                                     self.momentum, self.epsilon))
        return res, output_mean, output_var
# Bind the module-level alias to the implementation matching the
# installed onnx opset; opset 14 added the ``training_mode`` attribute.
if onnx_opset_version() >= 14:
    BatchNormalization = BatchNormalization_14
else:  # pragma: no cover
    BatchNormalization = BatchNormalization_9