Coverage for mlprodict/onnxrt/ops_cpu/op_softmax_cross_entropy_loss.py: 96%

56 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Runtime operator. 

4""" 

5import numpy 

6from ._op import OpRun 

7 

8 

def softmaxcrossentropy(x, target, weight=None, reduction='mean',
                        ignore_index=None, get_log_prob=None):
    """
    Computes the softmax cross entropy loss between scores *x*
    and class indices *target*.

    Modified version of `softmaxcrossentropy.py
    <https://github.com/onnx/onnx/blob/main/onnx/backend/
    test/case/node/softmaxcrossentropy.py>`_ to handle other type
    than float32.

    :param x: scores, shape ``(N, C)`` or ``(N, C, D1, ..., Dk)``,
        the softmax is taken over axis 1
    :param target: integer class indices, shape ``(N,)`` or
        ``(N, D1, ..., Dk)``
    :param weight: optional per-class weights indexed by *target*
    :param reduction: ``'mean'``, ``'sum'`` or anything else for no
        reduction, accepted as *str* or as *bytes* (the form ONNX
        attributes take)
    :param ignore_index: positions whose target equals this value
        contribute neither to the loss nor to the weight sum,
        None to keep every position
    :param get_log_prob: if True, the log-softmax of *x* is returned
        as a second result
    :return: ``(loss, )`` or ``(loss, log_prob)`` if *get_log_prob* is True
    :raises RuntimeError: if *x* has fewer than two dimensions
    """
    input_shape = x.shape
    if len(input_shape) < 2:
        raise RuntimeError(f"Unsupported shape {input_shape!r}.")

    # The default value is a str whereas ONNX attributes are bytes:
    # both spellings must select the same reduction.
    if isinstance(reduction, str):
        reduction = reduction.encode()

    target_shape = target.shape
    N = input_shape[0]
    C = input_shape[1]

    # compute log_softmax, the maximum is subtracted to avoid overflow in exp
    max_x = numpy.max(x, axis=1, keepdims=True)
    exp_x = numpy.exp(x - max_x)
    inp = numpy.log(exp_x / numpy.sum(exp_x, axis=1, keepdims=True))
    log_prob = numpy.copy(inp) if get_log_prob is True else None

    # initialize the positional weights when required
    gather_weight = None
    if weight is not None:
        # mode='clip' keeps an out-of-range (ignored) label from raising
        gather_weight = numpy.take(
            weight, numpy.array(target, dtype=numpy.int32), mode='clip')
        if ignore_index is not None:
            gather_weight = numpy.where(
                target == ignore_index, 0, gather_weight).astype(dtype=x.dtype)
    elif ignore_index is not None:
        gather_weight = numpy.where(
            target == ignore_index, 0, 1).astype(dtype=x.dtype)

    # if input is not 3-d, make it 3-d: [N, C, H, W] becomes [N, C, H * W]
    if len(input_shape) != 3:
        inp = inp.reshape((N, C, -1))
        target = target.reshape((N, -1))

    # Gather -inp[i, target[i, d], d] for every (i, d) in one indexing
    # operation. Ignored positions are redirected to class 0 so that an
    # ignored label is never used as an index, then their loss is zeroed.
    labels = target
    kept = None
    if ignore_index is not None:
        kept = target != ignore_index
        labels = numpy.where(kept, target, 0)
    rows = numpy.arange(N)[:, numpy.newaxis]
    cols = numpy.arange(inp.shape[2])[numpy.newaxis, :]
    loss = -inp[rows, labels, cols]
    if kept is not None:
        loss[~kept] = 0
    loss = loss.astype(x.dtype, copy=False)

    # if the input was reshaped, go back to the shape of the target
    if len(input_shape) != 3:
        loss = loss.reshape(target_shape)

    # apply the weights when required
    if gather_weight is not None:
        loss = gather_weight * loss

    if reduction == b'mean':
        if gather_weight is not None:
            # weighted mean: normalize by the sum of the applied weights
            loss = loss.sum() / gather_weight.sum()
        else:
            loss = numpy.mean(loss)
    elif reduction == b'sum':
        loss = numpy.sum(loss)

    if get_log_prob is True:
        return loss, log_prob
    return (loss, )

85 

86 

87class SoftmaxCrossEntropyLoss(OpRun): 

88 """ 

89 Python runtime for function *SoftmaxCrossEntropyLoss*. 

90 """ 

91 

92 atts = {'reduction': b'mean', 'ignore_index': -1} 

93 

94 def __init__(self, onnx_node, desc=None, **options): 

95 OpRun.__init__(self, onnx_node, desc=desc, 

96 expected_attributes=SoftmaxCrossEntropyLoss.atts, 

97 **options) 

98 

99 def _run(self, x, target, weight=None, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

100 n_outputs = len(self.onnx_node.output) 

101 return softmaxcrossentropy( 

102 x, target, weight=weight, reduction=self.reduction, # pylint: disable=E1101 

103 ignore_index=self.ignore_index, # pylint: disable=E1101 

104 get_log_prob=n_outputs == 2)