Coverage for mlprodict/onnxrt/ops_cpu/op_dropout.py: 100%

37 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7 

8import numpy 

9from numpy.random import RandomState 

10from onnx.defs import onnx_opset_version 

11from ._op import OpRun 

12 

13 

14def _dropout(X, drop_probability=0.5, seed=0, 

15 training_mode=False, return_mask=False): 

16 if drop_probability == 0 or not training_mode: 

17 if return_mask: 

18 return X, numpy.ones(X.shape, dtype=bool) 

19 return (X, ) 

20 

21 rnd = RandomState(seed) 

22 mask = rnd.uniform(0, 1.0, X.shape) >= drop_probability 

23 scale = (1. / (1. - drop_probability)) 

24 return ( 

25 (mask * X * scale, mask.astype(bool)) 

26 if return_mask else (mask * X * scale, )) 

27 

28 

class DropoutBase(OpRun):
    """
    Shared base of the *Dropout* runtimes. It records how many outputs
    the node declares and delegates the computation to :func:`_dropout`.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)
        # Number of outputs declared by the node; a second output is the mask.
        self.nb_outputs = len(onnx_node.output)

    def _private_run(self, X, seed=None, ratio=0.5, training_mode=False):  # pylint: disable=W0221
        """
        Runs dropout on *X*; the mask is returned only if the node
        declares exactly two outputs.
        """
        wants_mask = self.nb_outputs == 2
        return _dropout(X, ratio, seed=seed,
                        training_mode=training_mode,
                        return_mask=wants_mask)

40 

41 

class Dropout_7(DropoutBase):
    """
    Runtime for *Dropout* where the drop ratio is the node attribute
    ``ratio`` (default 0.5). It is selected when the installed onnx
    opset is below 12. No training-mode input exists, so
    :meth:`_private_run` runs with its default ``training_mode=False``.
    The input is therefore returned unchanged, together with an all-true
    mask when a second output is declared.
    """

    atts = {'ratio': 0.5}

    def __init__(self, onnx_node, desc=None, **options):
        DropoutBase.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Dropout_7.atts,
                             **options)

    def _run(self, X, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # ``ratio`` must be passed by keyword: the second positional
        # parameter of ``_private_run`` is ``seed``, not ``ratio``.
        return self._private_run(X, ratio=self.ratio)

53 

54 

class Dropout_12(DropoutBase):
    """
    Runtime for *Dropout* (opset >= 12), where the ratio and the training
    mode are optional inputs and the seed is the attribute ``seed``
    (default 0).

    Inputs: ``X``, optional ``ratio`` (default 0.5), optional
    ``training_mode`` (default False). A missing input, or one given as
    ``None``, falls back to its default.
    """

    atts = {'seed': 0}

    def __init__(self, onnx_node, desc=None, **options):
        DropoutBase.__init__(self, onnx_node, desc=desc,
                             expected_attributes=Dropout_12.atts,
                             **options)

    def _run(self, *inputs, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        X = inputs[0]
        ratio = inputs[1] if len(inputs) > 1 else None
        training_mode = inputs[2] if len(inputs) > 2 else None
        # An unset optional input may arrive as None: use the ONNX defaults
        # rather than letting ``1. - None`` fail inside ``_dropout``.
        if ratio is None:
            ratio = 0.5
        if training_mode is None:
            training_mode = False
        return self._private_run(X, seed=self.seed, ratio=ratio,
                                 training_mode=training_mode)

70 

71 

# Expose the newest implementation the installed onnx package supports.
Dropout = Dropout_12 if onnx_opset_version() >= 12 else Dropout_7