Coverage for mlprodict/onnxrt/ops_cpu/op_negative_log_likelihood_loss.py: 98%
44 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Runtime operator.
4"""
5import numpy
6from ._op import OpRun
def _compute_negative_log_likelihood_loss(x, target, weight=None,
                                          reduction=b'mean', ignore_index=None):
    """
    Computes the negative log likelihood loss.
    Modified version of `negativeloglikelihoodloss.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/negativeloglikelihoodloss.py>`_ to handle other type
    than float32.

    @param      x               array with at least two dimensions,
                                *(N, C)* or *(N, C, d1, ..., dk)*, the loss
                                of one position is the negated entry of *x*
                                selected along axis 1 by *target*
    @param      target          array of integer class indices, one per
                                position, *(N,)* or *(N, d1, ..., dk)*
    @param      weight          optional 1-D array of weights indexed by
                                class, out-of-range targets are clipped
                                when the weights are gathered
    @param      reduction       ``b'mean'`` or ``b'sum'`` (the equivalent
                                *str* is accepted too), any other value
                                returns the unreduced loss
    @param      ignore_index    target value whose positions get a null
                                loss (and a null weight whenever a weight
                                tensor is built), ``None`` disables it,
                                ``-1`` without *weight* builds no weight
                                tensor: the mean is then taken over every
                                position
    @return                     tuple with one element, the loss
    @raises     RuntimeError    if *x* has less than two dimensions
    """
    input_shape = x.shape
    if len(input_shape) < 2:
        # 0-d and 1-d inputs have no class axis to gather from
        raise RuntimeError(f"Unsupported shape {input_shape!r}.")

    # attributes usually come as bytes, a str is tolerated as well
    # instead of silently skipping the reduction
    if isinstance(reduction, str):
        reduction = reduction.encode('utf-8')

    target_shape = target.shape
    N = input_shape[0]
    C = input_shape[1]

    # mask of the positions to ignore, computed once and reused below,
    # None means nothing is ignored (no comparison against None)
    if ignore_index is None:
        ignored = numpy.zeros(target_shape, dtype=numpy.bool_)
    else:
        ignored = target == ignore_index

    # initialize the positional weights when required
    gather_weight = None
    if weight is not None:
        # setting mode='clip' to deal with ignore_index > C or < 0 cases.
        # when the target value is > C or < 0, it doesn't matter which
        # value we are taking in gather_weight, since it will be set to 0
        # in the following if-block
        # use numpy.int32 to make it compatible with x86 machines
        gather_weight = numpy.take(
            weight, numpy.array(target, dtype=numpy.int32), mode='clip')
        # set `ignore_index`'s loss weight to 0.
        # The loss tensor will be multiplied by this weight tensor,
        # so `ignore_index`'s loss value will be eliminated.
        if ignore_index is not None:
            gather_weight = numpy.where(
                ignored, 0, gather_weight).astype(dtype=x.dtype)
    elif ignore_index != -1:
        gather_weight = numpy.where(ignored, 0, 1).astype(dtype=x.dtype)

    # if input is not 3-d, make it 3-d
    if len(input_shape) != 3:
        x = x.reshape((N, C, -1))
        target = target.reshape((N, -1))
        kept = ~ignored.reshape((N, -1))
    else:
        kept = ~ignored

    # Get a dimension from the reshaped input.
    # If the original input shape is [N, C, H, W],
    # the D here should be H * W because we reshape
    # [N, C, H, W] to [N, C, H * W].
    D = x.shape[2]
    neg_gather_element_input = numpy.zeros((N, D), dtype=x.dtype)

    # gathers -x[i, target[i, d], d] for every kept position (i, d),
    # ignored positions are never used as an index so an out-of-range
    # ignore_index is harmless, they keep their initial value 0
    rows, cols = numpy.nonzero(kept)
    neg_gather_element_input[rows, cols] = -x[rows, target[rows, cols], cols]

    loss = neg_gather_element_input

    # if the input was not 3-d, reshape to the right shape
    if len(input_shape) != 3:
        loss = loss.reshape(target_shape)

    # apply the weights when required
    if gather_weight is not None:
        loss = gather_weight * loss
        if reduction == b'mean':
            # weighted mean: normalized by the sum of the weights
            loss = loss.sum() / gather_weight.sum()
            return (loss, )

    if reduction == b'mean':
        loss = numpy.mean(loss)
    elif reduction == b'sum':
        loss = numpy.sum(loss)
    return (loss, )
class NegativeLogLikelihoodLoss(OpRun):
    """
    Runtime implementation in python of operator
    *NegativeLogLikelihoodLoss*. The computation is delegated
    to function *_compute_negative_log_likelihood_loss*.
    """

    # expected attributes and their default values
    atts = dict(reduction=b'mean', ignore_index=-1)

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=NegativeLogLikelihoodLoss.atts,
            **options)

    def _run(self, x, target, weight=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # both attributes are read from the instance,
        # parameters attributes, verbose, fLOG are not used here
        reduction = self.reduction  # pylint: disable=E1101
        ignore_index = self.ignore_index  # pylint: disable=E1101
        return _compute_negative_log_likelihood_loss(
            x, target, weight=weight, reduction=reduction,
            ignore_index=ignore_index)