Coverage for mlprodict/onnxrt/ops_cpu/op_softmax_cross_entropy_loss.py: 96%
56 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Runtime operator.
4"""
5import numpy
6from ._op import OpRun
def softmaxcrossentropy(x, target, weight=None, reduction='mean',
                        ignore_index=None, get_log_prob=None):
    """
    Modified version of `softmaxcrossentropy.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/softmaxcrossentropy.py>`_ to handle other type
    than float32.

    :param x: logits, shape *(N, C)* or *(N, C, D1, D2, ...)*
    :param target: class indices, shape *(N,)* or *(N, D1, D2, ...)*
    :param weight: optional per-class weights, shape *(C,)*
    :param reduction: ``'none'``, ``'mean'`` or ``'sum'``
        (accepted as *str* or *bytes*)
    :param ignore_index: target value which does not contribute to the loss
    :param get_log_prob: if True, also returns the log-probabilities
    :return: ``(loss, )`` or ``(loss, log_prob)``
    """
    # ONNX attributes arrive as bytes (b'mean') while the Python default is
    # a str ('mean'); normalize once so both spellings trigger reduction.
    # (Previously the str default never matched the b'mean'/b'sum' tests and
    # the loss was silently returned unreduced.)
    if isinstance(reduction, bytes):
        reduction = reduction.decode('ascii')

    input_shape = x.shape
    if len(input_shape) == 1:
        raise RuntimeError(f"Unsupported shape {input_shape!r}.")

    target_shape = target.shape
    N = input_shape[0]
    C = input_shape[1]

    # compute log_softmax in a numerically stable way (shift by the max)
    max_x = numpy.max(x, axis=1, keepdims=True)
    exp_x = numpy.exp(x - max_x)
    p = exp_x / numpy.sum(exp_x, axis=1, keepdims=True)
    inp = numpy.log(p)
    log_prob = None
    if get_log_prob is True:
        log_prob = numpy.copy(inp)

    # initialize the positional weights when required
    gather_weight = None
    if weight is not None:
        gather_weight = numpy.take(
            weight, numpy.array(target, dtype=numpy.int32), mode='clip')
        if ignore_index is not None:
            # ignored positions must not contribute to the weighted loss
            gather_weight = numpy.where(
                target == ignore_index, 0, gather_weight).astype(dtype=x.dtype)
    elif ignore_index is not None:
        gather_weight = numpy.where(
            target == ignore_index, 0, 1).astype(dtype=x.dtype)

    # if input is 4-d and above, make it 3-d
    if len(input_shape) != 3:
        inp = inp.reshape((N, C, -1))
        target = target.reshape((N, -1))

    # Get a dimension from the reshaped input.
    # If the original input shape is [N, C, H, W],
    # the D here should be H * W because we reshape
    # [N, C, H, W] to [N, C, H * W].
    D = inp.shape[2]
    neg_gather_element_input = numpy.zeros((N, D), dtype=x.dtype)
    for i in range(N):
        for d in range(D):
            if target[i, d] != ignore_index:
                neg_gather_element_input[i, d] = -inp[i, target[i, d], d]

    loss = neg_gather_element_input

    # if the input was 4-d or above reshape to the right shape
    if len(input_shape) != 3:
        loss = loss.reshape(target_shape)

    # apply the weights when required
    if gather_weight is not None:
        loss = gather_weight * loss
        if reduction == 'mean':
            # weighted mean: normalize by the sum of the applied weights,
            # not by the element count
            loss = loss.sum() / gather_weight.sum()
            if get_log_prob is True:
                return loss, log_prob
            return (loss, )

    if reduction == 'mean':
        loss = numpy.mean(loss)
    elif reduction == 'sum':
        loss = numpy.sum(loss)

    if get_log_prob is True:
        return loss, log_prob
    return (loss, )
class SoftmaxCrossEntropyLoss(OpRun):
    """
    Python runtime for function *SoftmaxCrossEntropyLoss*.
    """

    # default ONNX attribute values for this operator
    atts = {'reduction': b'mean', 'ignore_index': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=SoftmaxCrossEntropyLoss.atts,
            **options)

    def _run(self, x, target, weight=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # A second declared output means the caller also wants the
        # log-probabilities alongside the loss.
        wants_log_prob = len(self.onnx_node.output) == 2
        return softmaxcrossentropy(
            x, target, weight=weight,
            reduction=self.reduction,  # pylint: disable=E1101
            ignore_index=self.ignore_index,  # pylint: disable=E1101
            get_log_prob=wants_log_prob)