Coverage for mlprodict/onnxrt/ops_cpu/op_cdist.py: 96%
27 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from scipy.spatial.distance import cdist
8from ._op import OpRunBinaryNum
9from ._new_ops import OperatorSchema
class CDist(OpRunBinaryNum):
    """
    Runtime for the custom operator *CDist*. It computes the distances
    between each pair of the two input collections by calling
    :func:`scipy.spatial.distance.cdist`.

    Expected attributes (defaults are stored in ``atts``):

    * ``metric``: name of the distance forwarded to *scipy*
    * ``p``: parameter *p* of the distance, only forwarded to *scipy*
      when ``metric`` is ``'minkowski'``
    """

    atts = {'metric': 'sqeuclidean', 'p': 2.}

    def __init__(self, onnx_node, desc=None, **options):
        OpRunBinaryNum.__init__(self, onnx_node, desc=desc,
                                expected_attributes=CDist.atts,
                                **options)

    def _metric_name(self):
        """
        Returns attribute ``metric`` as a string.

        The value is decoded from ASCII when it is ``bytes``. A ``str``
        value, such as the default declared in ``atts``, is returned
        unchanged instead of failing on a missing ``decode`` method.
        """
        metric = self.metric
        if isinstance(metric, bytes):
            return metric.decode('ascii')
        return metric

    def _run(self, a, b, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Computes ``cdist(a, b, metric=...)``.

        :param a: first input
        :param b: second input
        :param attributes: unused by this implementation
        :param verbose: unused by this implementation
        :param fLOG: unused by this implementation
        :return: tuple with one element, the distances cast to ``a.dtype``
        """
        metric = self._metric_name()
        # p is only given to scipy for the Minkowski distance
        extra = {'p': self.p} if metric == 'minkowski' else {}
        res = cdist(a, b, metric=metric, **extra)
        # scipy may change the output type
        return (res.astype(a.dtype), )

    def _find_custom_operator_schema(self, op_name):
        """
        Returns the schema of operator ``'CDist'``.

        :param op_name: operator name
        :return: instance of ``CDistSchema``
        :raises RuntimeError: if *op_name* is not ``'CDist'``
        """
        if op_name == "CDist":
            return CDistSchema()
        raise RuntimeError(  # pragma: no cover
            f"Unable to find a schema for operator '{op_name}'.")

    def to_python(self, inputs):
        """
        Returns python code equivalent to this operator.

        :param inputs: names of the two inputs
        :return: tuple ``(import statement, return statement)``
        """
        metric = self._metric_name()
        if metric == 'minkowski':
            code = (f"return cdist({inputs[0]}, {inputs[1]}, "
                    f"metric='{metric}', p={self.p})")
        else:
            code = f"return cdist({inputs[0]}, {inputs[1]}, metric='{metric}')"
        return ('from scipy.spatial.distance import cdist', code)
class CDistSchema(OperatorSchema):
    """
    Schema of the custom operator *CDist* added by this package.
    Its attributes are the ones declared in ``CDist.atts``.
    """

    def __init__(self):
        super().__init__('CDist')
        # same dictionary object as CDist.atts, not a copy
        self.attributes = CDist.atts