Coverage for mlprodict/npy/xop_helper.py: 100%

13 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# pylint: disable=E0602 

2""" 

3@file 

4@brief Xop helpers. 

5 

6.. versionadded:: 0.9 

7""" 

8from .xop_variable import Variable 

9 

10 

11def _infer_node_output(node, inputs): 

12 """ 

13 Infers node outputs for a specific type. 

14 

15 :param node: :epkg:`NodeProto` 

16 :param outputs: known inputs 

17 :return: dtype 

18 """ 

19 if not isinstance(inputs, dict): 

20 raise TypeError( # pragma: no cover 

21 f"inputs should be OrderedDict not {type(inputs)!r}.") 

22 

23 if node.op_type == 'Concat': 

24 type_set = set() 

25 for v in inputs.values(): 

26 if not isinstance(v, Variable): 

27 raise TypeError( # pragma: no cover 

28 f"Unexpected type {type(v)!r} for {v!r}.") 

29 type_set.add(v.dtype) 

30 if len(type_set) != 1: 

31 raise RuntimeError( # pragma: no cover 

32 f"Unable to guess output type from {type_set!r} (inputs={inputs!r}).") 

33 dtype = type_set.pop() 

34 if dtype is None: 

35 raise RuntimeError( # pragma: no cover 

36 f"Guessed output type is None from inputs={inputs!r}.") 

37 return dtype, [None, None] 

38 

39 raise NotImplementedError( # pragma: no cover 

40 f"Unable to infer type for node type {node.op_type!r} and inputs={inputs!r}.")