Coverage for mlprodict/grammar/grammar_sklearn/grammar/gactions_num.py: 100%
36 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Action definition.
4"""
5from .gtypes import MLNumTypeFloat32, MLNumTypeFloat64, MLNumTypeBool
6from .gactions import MLActionBinary, MLActionFunctionCall
class MLActionAdd(MLActionBinary):
    """
    Addition (operator ``+``) of the results of two actions.
    """

    def __init__(self, act1, act2):
        """
        @param      act1    left operand, must expose an ``output`` attribute
        @param      act2    right operand, must expose an ``output`` attribute

        A *TypeError* is raised when the outputs of both operands
        are not of the same type.
        """
        MLActionBinary.__init__(self, act1, act2, "+")
        left_type = type(act1.output)
        right_type = type(act2.output)
        if left_type != right_type:
            raise TypeError(  # pragma: no cover
                f"Not the same input type {type(act1.output)} != {type(act2.output)}")

    def execute(self, **kwargs):
        """
        Delegates to ``MLActionBinary.execute``, then adds the first two
        entries of ``ChildrenResults``.

        @param      kwargs  forwarded as is to ``MLActionBinary.execute``
        @return             the sum, passed through ``self.output.validate``
        """
        MLActionBinary.execute(self, **kwargs)
        results = self.ChildrenResults
        total = results[0] + results[1]
        return self.output.validate(total)
class MLActionSign(MLActionFunctionCall):
    """
    Sign indicator of an expression: 1 when the value is
    greater than or equal to zero, 0 otherwise.
    """

    def __init__(self, act1):
        """
        @param      act1    operand, its ``output`` must be an instance of
                            *MLNumTypeFloat32* or *MLNumTypeFloat64*

        A *TypeError* is raised for any other output type.
        """
        MLActionFunctionCall.__init__(self, "sign", act1.output, act1)
        accepted_types = (MLNumTypeFloat32, MLNumTypeFloat64)
        if not isinstance(act1.output, accepted_types):
            raise TypeError(  # pragma: no cover
                f"The input action must produce float32 or float64 not '{type(act1.output)}'")

    def execute(self, **kwargs):
        """
        Delegates to ``MLActionFunctionCall.execute``, then maps the first
        entry of ``ChildrenResults`` to 1 (value >= 0) or 0.

        @param      kwargs  forwarded as is to ``MLActionFunctionCall.execute``
        @return             the indicator, passed through
                            ``self.output.softcast`` then ``self.output.validate``
        """
        MLActionFunctionCall.execute(self, **kwargs)
        value = self.ChildrenResults[0]
        # Zero falls on the "1" side of the test.
        if value >= 0:
            indicator = 1
        else:
            indicator = 0
        return self.output.validate(self.output.softcast(indicator))
class MLActionTestInf(MLActionBinary):
    """
    Operator ``<=`` (lower than or equal to).
    """

    def __init__(self, act1, act2):
        """
        @param      act1    left operand, must expose an ``output`` attribute
        @param      act2    right operand, must expose an ``output`` attribute

        A *TypeError* is raised when the outputs of both operands
        are not of the same type. The output of this action is
        set to a new *MLNumTypeBool* instance.
        """
        MLActionBinary.__init__(self, act1, act2, "<=")
        left_type = type(act1.output)
        right_type = type(act2.output)
        if left_type != right_type:
            raise TypeError(  # pragma: no cover
                f"Not the same input type {type(act1.output)} != {type(act2.output)}")
        self.output = MLNumTypeBool()

    def execute(self, **kwargs):
        """
        Delegates to ``MLActionBinary.execute``, then compares the first two
        entries of ``ChildrenResults`` with ``<=``.

        @param      kwargs  forwarded as is to ``MLActionBinary.execute``
        @return             the comparison result, passed through
                            ``self.output.softcast`` then ``self.output.validate``
        """
        MLActionBinary.execute(self, **kwargs)
        results = self.ChildrenResults
        is_lower_or_equal = results[0] <= results[1]
        return self.output.validate(self.output.softcast(is_lower_or_equal))
class MLActionTestEqual(MLActionBinary):
    """
    Operator ``==`` (equality test).
    """

    def __init__(self, act1, act2):
        """
        @param      act1    left operand, must expose an ``output`` attribute
        @param      act2    right operand, must expose an ``output`` attribute

        A *TypeError* is raised when the outputs of both operands
        are not of the same type. The output of this action is
        set to a new *MLNumTypeBool* instance.
        """
        MLActionBinary.__init__(self, act1, act2, "==")
        left_type = type(act1.output)
        right_type = type(act2.output)
        if left_type != right_type:
            raise TypeError(  # pragma: no cover
                f"Not the same input type {type(act1.output)} != {type(act2.output)}")
        self.output = MLNumTypeBool()

    def execute(self, **kwargs):
        """
        Delegates to ``MLActionBinary.execute``, then compares the first two
        entries of ``ChildrenResults`` with ``==``.

        @param      kwargs  forwarded as is to ``MLActionBinary.execute``
        @return             the comparison result, passed through
                            ``self.output.softcast`` then ``self.output.validate``
        """
        MLActionBinary.execute(self, **kwargs)
        results = self.ChildrenResults
        are_equal = results[0] == results[1]
        return self.output.validate(self.output.softcast(are_equal))