Coverage for mlprodict/onnxrt/ops_cpu/op_negative_log_likelihood_loss.py: 98%

44 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Runtime operator. 

4""" 

5import numpy 

6from ._op import OpRun 

7 

8 

def _compute_negative_log_likelihood_loss(x, target, weight=None,
                                          reduction=b'mean', ignore_index=None):
    """
    Modified version of `negativeloglikelihoodloss.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/negativeloglikelihoodloss.py>`_ to handle dtypes
    other than float32.

    :param x: scores of shape ``(N, C)`` or ``(N, C, d1, ..., dk)``,
        the loss at every position is ``-x[n, target[n, ...], ...]``
    :param target: class indices of shape ``(N,)`` or ``(N, d1, ..., dk)``
    :param weight: optional per-class weights, gathered with *target*
    :param reduction: ``'none'``, ``'sum'`` or ``'mean'``, given as bytes
        (ONNX attribute) or str
    :param ignore_index: target value which contributes neither to the
        loss nor to the denominator of the mean, None to disable it
    :return: a 1-tuple holding the loss, shaped like *target* when
        *reduction* is ``'none'``, a scalar otherwise
    :raises RuntimeError: if *x* has fewer than two dimensions
    :raises ValueError: if *reduction* is not one of the expected values
    """
    # ONNX string attributes are stored as bytes, str is accepted as well.
    if isinstance(reduction, bytes):
        reduction = reduction.decode('utf-8')
    if reduction not in ('none', 'sum', 'mean'):
        raise ValueError(
            f"Unexpected value {reduction!r} for attribute reduction.")

    input_shape = x.shape
    if len(input_shape) < 2:
        raise RuntimeError(f"Unsupported shape {input_shape!r}.")

    target_shape = target.shape
    N, C = input_shape[:2]
    # All trailing dimensions are flattened into one: [N, C, H, W] becomes
    # [N, C, H * W]. An explicit size (rather than -1) keeps the reshape
    # valid when the batch is empty.
    D = int(numpy.prod(input_shape[2:], dtype=numpy.int64))

    # Boolean mask of the positions whose target equals ignore_index.
    # It is used for the gather, the weights and the mean denominator
    # alike, so every ignore_index value (including -1) is treated the
    # same way.
    if ignore_index is None:
        ignored = numpy.zeros(target_shape, dtype=bool)
    else:
        ignored = numpy.asarray(target == ignore_index)

    # Per-position weights, None when every position counts equally.
    gather_weight = None
    if weight is not None:
        # mode='clip' deals with an ignore_index outside [0, C): whatever
        # value is gathered for such a target is set to 0 just below.
        # numpy.int32 keeps the indices compatible with x86 machines.
        gather_weight = numpy.take(
            weight, numpy.array(target, dtype=numpy.int32), mode='clip')
        if ignore_index is not None:
            # The loss is multiplied by these weights, so zeroing them
            # removes the contribution of ignored positions.
            gather_weight = numpy.where(
                ignored, 0, gather_weight).astype(dtype=x.dtype)
    elif ignored.any():
        # Unit weights that exclude ignored positions from the mean.
        gather_weight = numpy.where(ignored, 0, 1).astype(dtype=x.dtype)

    x3 = x.reshape((N, C, D))
    target2 = numpy.asarray(target).reshape((N, D))
    ignored2 = ignored.reshape((N, D))

    # Vectorized gather of x[n, target[n, d], d]. Ignored positions use
    # class 0 as a placeholder (ignore_index may be out of range) and
    # their loss is forced to 0 right after.
    safe_target = numpy.where(ignored2, 0, target2)
    picked = numpy.take_along_axis(
        x3, safe_target[:, numpy.newaxis, :], axis=1)[:, 0, :]
    loss = numpy.where(ignored2, 0, -picked).astype(x.dtype, copy=False)

    # Back to the shape of target (a no-op for 3-D inputs).
    loss = loss.reshape(target_shape)

    # Apply the weights when required.
    if gather_weight is not None:
        loss = gather_weight * loss
        if reduction == 'mean':
            # Weighted mean: normalize by the total weight of the
            # positions that were kept.
            return (loss.sum() / gather_weight.sum(), )

    if reduction == 'mean':
        loss = numpy.mean(loss)
    elif reduction == 'sum':
        loss = numpy.sum(loss)
    return (loss, )

78 

79 

class NegativeLogLikelihoodLoss(OpRun):
    """
    Python runtime for function *NegativeLogLikelihoodLoss*:
    the computation is delegated to
    `_compute_negative_log_likelihood_loss`.
    """

    # Default attribute values handed to OpRun as expected_attributes.
    # NOTE(review): presumably OpRun exposes them as self.reduction and
    # self.ignore_index (read by _run below); confirm in OpRun.
    atts = {'reduction': b'mean', 'ignore_index': -1}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=NegativeLogLikelihoodLoss.atts, **options)

    def _run(self, x, target, weight=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # The attributes are not declared on the class (hence E1101);
        # they are read once here and forwarded unchanged.
        reduction = self.reduction  # pylint: disable=E1101
        ignore_index = self.ignore_index  # pylint: disable=E1101
        return _compute_negative_log_likelihood_loss(
            x, target, weight=weight,
            reduction=reduction, ignore_index=ignore_index)