Coverage for onnxcustom/training/sgd_learning_penalty.py: 100%

60 statements  

« prev     ^ index     » next       coverage.py v7.0.5, created at 2023-01-17 01:42 +0100

1# pylint: disable=W0105 

2""" 

3@file 

4@brief Helper for :epkg:`onnxruntime-training`. 

5""" 

6from onnxruntime import SessionOptions, InferenceSession, RunOptions 

7from ..utils.onnx_function import function_onnx_graph 

8from ..utils.onnxruntime_helper import device_to_providers 

9from ._base_onnx_function import BaseLearningOnnx 

10 

11 

class BaseLearningPenalty(BaseLearningOnnx):
    """
    Class handling the penalty on the coefficients for class
    @see cl OrtGradientForwardBackwardOptimizer.

    This base class only stores the run options and offers the
    factory @see meth select. Methods *penalty_loss* and
    *update_weights* raise ``NotImplementedError`` and must be
    overwritten.
    """

    def __init__(self):
        BaseLearningOnnx.__init__(self)
        # Run options given to every call to run_with_iobinding.
        self.ro_ = RunOptions()

    def _call_iobinding(self, sess, bind):
        """
        Runs session *sess* on binding *bind* with the run options
        stored in ``self.ro_``.

        :param sess: object exposing method *run_with_iobinding*
        :param bind: binding given to *run_with_iobinding*
        """
        sess.run_with_iobinding(bind, self.ro_)

    @staticmethod
    def select(class_name, **kwargs):
        """
        Returns an instance of a given class initialized with
        *kwargs*.

        :param class_name: an instance of @see cl BaseLearningPenalty
            or a string among the following class names (see below)
        :param kwargs: parameters given to the constructor of the
            selected class, ignored if *class_name* is already
            an instance
        :return: instance of @see cl BaseLearningPenalty
        :raises ValueError: if *class_name* is not recognized

        Possible values for *class_name*:
        * None or `''` or `'NoLearningPenalty'`:
          see @see cl NoLearningPenalty
        * `'elastic'` or `'l1l2'` or `'ElasticLearningPenalty'`:
          see @see cl ElasticLearningPenalty
        """
        if isinstance(class_name, BaseLearningPenalty):
            return class_name
        cls = {NoLearningPenalty: [None, ''],
               ElasticLearningPenalty: ['elastic', 'l1l2']}
        for cl, aliases in cls.items():
            # cl is a class: cl.__name__ is its own name whereas
            # cl.__class__.__name__ is the name of its metaclass
            # and never matches a class name listed above.
            if class_name == cl.__name__ or class_name in aliases:
                return cl(**kwargs)
        raise ValueError(  # pragma: no cover
            "Unexpected class name %r. It should be one of %r." % (
                class_name, [c.__name__ for c in cls]))

    def penalty_loss(self, device, loss, *weights):
        """
        Adds the penalty to the loss. This method is not implemented
        in the base class and must be overwritten.

        :param device: device where the training takes place
        :param loss: loss without penalty
        :param weights: any weights to be penalized
        :return: loss
        :raises NotImplementedError: always
        """
        raise NotImplementedError(
            "penalty_loss must be overwritten.")

    def update_weights(self, device, statei):
        """
        Updates a weight according to the penalty. This method is not
        implemented in the base class and must be overwritten.
        Classes @see cl NoLearningPenalty and
        @see cl ElasticLearningPenalty overwrite it with an additional
        first parameter *n_bind*.

        :param device: device where the training takes place
        :param statei: weight to update
        :return: weight
        :raises NotImplementedError: always
        """
        raise NotImplementedError(
            "update_weights must be overwritten.")

70 

71 

class NoLearningPenalty(BaseLearningPenalty):
    """
    Penalty which does nothing: the loss and the weights
    are returned as they were received.
    """

    def __init__(self):
        BaseLearningPenalty.__init__(self)

    def build_onnx_function(self, opset, device, n_tensors):
        """
        Does nothing, this penalty does not build any ONNX graph.

        :param opset: unused
        :param device: unused
        :param n_tensors: unused
        """

    def penalty_loss(self, device, loss, *weights):
        """
        Returns the loss it receives without any modification.

        :param device: device where the training takes place, unused
        :param loss: loss without penalty
        :param weights: any weights to be penalized, unused
        :return: *loss*, the very same object
        """
        return loss

    def update_weights(self, n_bind, device, statei):
        """
        Returns the weight it receives without any modification.

        :param n_bind: unused
        :param device: device where the training takes place, unused
        :param statei: weight
        :return: *statei*, the very same object
        """
        return statei

104 

105 

class ElasticLearningPenalty(BaseLearningPenalty):
    """
    Implements a L1 or L2 regularization on weights.

    :param l1: stored as ``self.l1``, given to the ONNX graphs
        as the L1 coefficient
    :param l2: stored as ``self.l2``, given to the ONNX graphs
        as the L2 coefficient
    """

    def __init__(self, l1=0.5, l2=0.5):
        BaseLearningPenalty.__init__(self)
        self.l1 = l1
        self.l2 = l2

    def build_onnx_function(self, opset, device, n_tensors):
        """
        Creates the two ONNX graphs used by this penalty, one
        inference session for each of them and their bindings.

        :param opset: opset given to *function_onnx_graph*
        :param device: device converted into providers
            with *device_to_providers*
        :param n_tensors: number of tensors, one binding
            is created for every of them to update the weights
        """
        options = SessionOptions()
        options.log_severity_level = 4

        # graph computing the penalized loss
        self.penalty_onnx_ = function_onnx_graph(
            "n_penalty_elastic_error",
            target_opset=opset,
            n_tensors=n_tensors,
            loss_shape=None,
            l1_weight=self.l1,
            l2_weight=self.l2)
        loss_bytes = self.penalty_onnx_.SerializeToString()
        self.penalty_sess_ = InferenceSession(
            loss_bytes, options, providers=device_to_providers(device))
        loss_binding = self.penalty_sess_.io_binding()
        self.penalty_sess_bind_ = loss_binding._iobinding
        # input names, in the order the graph declares them
        self.names_ = [
            graph_input.name
            for graph_input in self.penalty_onnx_.graph.input]

        # graph updating one weight
        self.penalty_grad_onnx_ = function_onnx_graph(
            "update_penalty_elastic_error",
            target_opset=opset,
            l1=self.l1,
            l2=self.l2)
        grad_bytes = self.penalty_grad_onnx_.SerializeToString()
        self.penalty_grad_sess_ = InferenceSession(
            grad_bytes, options, providers=device_to_providers(device))
        # one binding per tensor, all created from the same session
        grad_bindings = []
        for _ in range(n_tensors):
            grad_bindings.append(
                self.penalty_grad_sess_.io_binding()._iobinding)
        self.penalty_grad_sess_binds_ = grad_bindings

    def penalty_loss(self, device, *inputs):
        """
        Runs the penalty graph on the loss and the weights
        and returns its first output.

        :param device: device where the training takes place
        :param inputs: loss without penalty and weights, as many
            values as the penalty graph has inputs
        :return: loss + penalties
        :raises RuntimeError: if *build_onnx_function* was not called
            or if the number of inputs is not the expected one
        """
        is_built = (hasattr(self, "penalty_onnx_") and
                    hasattr(self, "penalty_sess_bind_"))
        if not is_built:
            raise RuntimeError(  # pragma: no cover
                "Attributes 'penalty_sess_bind_' or 'penalty_onnx_' "
                "is missing. Method 'build_onnx_function' has not "
                "been called.")
        if len(inputs) != len(self.names_):
            raise RuntimeError(  # pragma: no cover
                f"Mismatched number of inputs: {len(self.names_)} != {len(inputs)}.")

        binding = self.penalty_sess_bind_
        for input_name, value in zip(self.names_, inputs):
            self._bind_input_ortvalue(
                input_name, binding, value, device, cache=True)
        # output 'Y' is bound to the first input, the loss
        self._bind_output_ortvalue('Y', binding, inputs[0], cache=True)
        self._call_iobinding(self.penalty_sess_._sess, binding)
        outputs = binding.get_outputs()
        return outputs[0]

    def update_weights(self, n_bind, device, statei):
        """
        Runs the weight update graph on one weight
        and returns its first output.

        :param n_bind: index of the binding to use among the bindings
            created by *build_onnx_function*
        :param device: device where the training takes place
        :param statei: weight, bound as input 'X' and as output 'Y'
        :return: first output of the binding
        :raises RuntimeError: if *build_onnx_function* was not called
        """
        is_built = (hasattr(self, "penalty_grad_onnx_") and
                    hasattr(self, "penalty_grad_sess_binds_"))
        if not is_built:
            raise RuntimeError(  # pragma: no cover
                "Attributes 'penalty_grad_sess_binds_' or "
                "'penalty_grad_onnx_' is missing. "
                "Method 'build_onnx_function' has not been called.")
        binding = self.penalty_grad_sess_binds_[n_bind]
        self._bind_input_ortvalue("X", binding, statei, device, cache=True)
        self._bind_output_ortvalue('Y', binding, statei, cache=True)
        self._call_iobinding(self.penalty_grad_sess_._sess, binding)
        outputs = binding.get_outputs()
        return outputs[0]  # X