Coverage for src/sparkouille/fctmr/fast_parallel_fctmr.py: 91%

56 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-07-01 14:24 +0200

1# -*- coding: utf-8 -*- 

2""" 

3@file 

4@brief Simple parallelization of *mapper* and *reducer* based on :epkg:`numba`. 

5:epkg:`Python` does not make it easy to parallelize functions 

6because the :epkg:`GIL` blocks most attempts by imposing 

7a single tunnel for all allocations and the creation of :epkg:`python` 

8objects. The language implements parallelism, but in practice it is not effective. 

9This file is just an attempt to use :epkg:`numba` to parallelize 

10a mapper, but the number of round trips between :epkg:`python` 

11and compiled :epkg:`C` makes it difficult to write something 

12generic. 

13""" 

14import numpy 

15from numba import jit, njit, prange 

16 

17 

def create_array_numba(nb, sig):
    """
    Creates an uninitialized array of size *nb* knowing its signature.

    @param      nb      integer, number of elements
    @param      sig     signature of the element type, one of
                        ``'f4'``, ``'f8'``, ``'i4'``, ``'i8'``
    @return             container (a :epkg:`numpy` array)

    The function raises *NotImplementedError* when the signature
    is not one of the supported ones. The content of the returned
    array is not initialized (``numpy.empty``).
    """
    # Explicit whitelist: only shorthand names whose meaning is the same
    # for numba signatures and numpy dtypes.
    dtypes = {
        'f4': numpy.float32,
        'f8': numpy.float64,
        'i4': numpy.int32,
        'i8': numpy.int64,
    }
    # Non-string signatures (possibly unhashable) are rejected the same
    # way as unknown strings instead of failing in the dict lookup.
    dtype = dtypes.get(sig) if isinstance(sig, str) else None
    if dtype is None:
        raise NotImplementedError(
            "Cannot create a container for type '{0}'.".format(sig))
    return numpy.empty(nb, dtype=dtype)

31 

32 

def fast_parallel_mapper(fct, gen, chunk_size=100000, parallel=True,
                         nogil=False, nopython=True,
                         sigin=None, sigout=None):
    """
    Parallelizes a mapper based on :epkg:`numba` and more specifically
    `Automatic parallelization with @jit <https://numba.pydata.org/
    numba-doc/dev/user/parallel.html>`_.
    This page indicates what :epkg:`numba` optimizes when
    it parallelizes a map.

    @param      fct         function
    @param      gen         generator
    @param      chunk_size  see :ref:`l-parallel-mapper-chunk-size`,
                            must be at least 1
    @param      parallel    see `parallel <http://numba.pydata.org/numba-doc/latest/
                            user/jit.html?highlight=nopython#parallel>`_
    @param      nopython    see `nopython <http://numba.pydata.org/numba-doc/
                            latest/user/jit.html?highlight=nopython#nopython>`_
    @param      nogil       see `nogil <http://numba.pydata.org/numba-doc/
                            latest/user/jit.html?highlight=nopython#nogil>`_
    @param      sigin       signature of input type
    @param      sigout      signature of output type

    @return                 generator

    The parallelization can only happen if
    the array is known. So the function splits the
    input in chunks of size *chunk_size*: observations are
    buffered, the compiled loop is run once the buffer is full,
    and the results are yielded in the same order as the inputs.
    The remaining observations are processed when *gen* is exhausted.
    If both *sigin* and *sigout* are specified, the buffers are
    arrays created by *create_array_numba* and the functions are
    compiled with explicit signatures. Otherwise, the buffers are
    lists initialized from the first observation and the compilation
    is left to :epkg:`numba`.
    This attempt is not very efficient due to
    the genericity of the mapper. :epkg:`python`
    is not a good language to do that.
    See unit test
    `test_parallel_fctmr.py <https://github.com/sdpython/sparkouille/blob/
    master/_unittests/ut_fctmr/test_parallel_fctmr.py>`_.
    """
    if chunk_size < 1:
        raise ValueError(
            "chunk_size must be at least 1, not {0}.".format(chunk_size))

    jit_options = dict(nogil=nogil, parallel=parallel,
                       nopython=nopython, cache=True)

    # Assigned below; *loop* reads *fct_jit* through its closure
    # when it gets compiled.
    fct_jit = None
    loop_jit = None
    inputs = None
    outputs = None

    def loop(nb, inputs, outputs):
        "local function"
        for i in prange(nb):
            outputs[i] = fct_jit(inputs[i])

    if sigin is not None and sigout is not None:
        sig1 = '{0}({1})'.format(sigout, sigin)
        sig2 = 'void(i8, {0}[:], {1}[:])'.format(sigin, sigout)
        fct_jit = jit(sig1, **jit_options)(fct)
        loop_jit = njit(sig2, **jit_options)(loop)
        inputs = create_array_numba(chunk_size, sigin)
        outputs = create_array_numba(chunk_size, sigout)

    done = 0
    for obs in gen:
        if inputs is None:
            # No signature: the buffers are lists filled with the first
            # observation and its image so that they hold values
            # of the same kind as the next ones.
            inputs = [obs] * chunk_size
            outputs = [fct(obs)] * chunk_size
            fct_jit = jit(**jit_options)(fct)
            loop_jit = njit(**jit_options)(loop)

        # Store first, then flush: the observation which fills the
        # buffer must be part of the chunk, not dropped.
        inputs[done] = obs
        done += 1
        if done == chunk_size:
            loop_jit(done, inputs, outputs)
            for out in outputs[:done]:
                yield out
            done = 0

    if done > 0:
        # Last chunk, possibly partial: only the first *done* outputs
        # were computed, the rest of the buffer holds stale values.
        loop_jit(done, inputs, outputs)
        for out in outputs[:done]:
            yield out

98 

99 else: 

100 

101 def loop(nb, inputs, outputs): 

102 "local function" 

103 for i in prange(nb): 

104 outputs[i] = fct_jit(inputs[i]) 

105 

106 loop_jit = None 

107 fct_jit = None 

108 

109 inputs = None 

110 outputs = None 

111 

112 done = 0 

113 for obs in gen: 

114 if inputs is None: 

115 inputs = [obs] * chunk_size 

116 outputs = [fct(obs)] * chunk_size 

117 loop_jit = njit(nogil=nogil, parallel=parallel, 

118 nopython=nopython, cache=True)(loop) 

119 fct_jit = jit(nogil=nogil, parallel=parallel, 

120 nopython=nopython, cache=True)(fct) 

121 

122 if done < len(inputs): 

123 inputs[done] = obs 

124 done += 1 

125 else: 

126 loop_jit(done, inputs, outputs) 

127 for out in outputs: 

128 yield out 

129 done = 0 

130 if 0 < done < len(inputs): 

131 loop_jit(done, inputs, outputs) 

132 for out in outputs: 

133 yield out