Coverage for src/sparkouille/fctmr/fast_parallel_fctmr.py: 91%
56 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 14:24 +0200
« prev ^ index » next coverage.py v7.2.7, created at 2023-07-01 14:24 +0200
1# -*- coding: utf-8 -*-
2"""
3@file
4@brief Simple parallelization of *mapper* and *reducer* based on :epkg:`numba`.
5:epkg:`Python` does not make it easy to parallelize functions
6because the :epkg:`GIL` blocks most attempts by forcing every
7allocation and every creation of :epkg:`python` objects through
8a single channel. The language implements parallelization, but in practice it is not effective.
9This file is just an attempt to use :epkg:`numba` to parallelize
10a mapper, but the number of round trips between :epkg:`python`
11and compiled :epkg:`C` makes it difficult to write something
12generic.
13"""
14import numpy
15from numba import jit, njit, prange
def create_array_numba(nb, sig):
    """
    Creates an uninitialized array of size *nb* knowing its signature.

    @param nb integer, number of elements of the array
    @param sig signature of the element type, only ``'f8'``
        (64-bit float) is supported
    @return container, a one-dimensional :epkg:`numpy` array
        whose content is not initialized

    Raises *NotImplementedError* for any signature other than ``'f8'``.
    """
    # Guard clause: reject every signature this function cannot handle.
    if sig != 'f8':
        raise NotImplementedError(
            "Cannot create a container for type '{0}'.".format(sig))
    # numpy.empty does not initialize the memory, the caller must fill it.
    return numpy.empty(nb, dtype=numpy.float64)
def _map_by_chunks(iterator, inputs, outputs, loop_jit, done=0):
    """
    Buffers observations into *inputs*, applies *loop_jit* chunk by chunk
    and yields the mapped values in the order of the observations.

    @param iterator iterator on the observations not stored yet
    @param inputs preallocated input buffer, its length is the chunk size
    @param outputs preallocated output buffer, indexed like *inputs*
    @param loop_jit function called as ``loop_jit(nb, inputs, outputs)``,
        expected to fill the first *nb* slots of *outputs*
    @param done number of observations already stored in *inputs*
    @return generator on the mapped values
    """
    size = len(inputs)
    for obs in iterator:
        if done == size:
            # The buffer is full: process it *before* storing the current
            # observation, otherwise that observation would be lost.
            loop_jit(done, inputs, outputs)
            for out in outputs[:done]:
                yield out
            done = 0
        inputs[done] = obs
        done += 1
    if done > 0:
        # Last chunk, partial or exactly full. Only the first *done* slots
        # of *outputs* belong to this chunk, the others hold stale values.
        loop_jit(done, inputs, outputs)
        for out in outputs[:done]:
            yield out


def fast_parallel_mapper(fct, gen, chunk_size=100000, parallel=True,
                         nogil=False, nopython=True,
                         sigin=None, sigout=None):
    """
    Parallelizes a mapper based on :epkg:`numba` and more specifically
    `Automatic parallelization with @jit <https://numba.pydata.org/
    numba-doc/dev/user/parallel.html>`_.
    This page indicates what :epkg:`numba` optimizes when
    it parallelizes a map.

    @param fct function
    @param gen generator
    @param chunk_size see :ref:`l-parallel-mapper-chunk-size`,
        must be at least 1
    @param parallel see `parallel <http://numba.pydata.org/numba-doc/latest/
        user/jit.html?highlight=nopython#parallel>`_
    @param nopython see `nopython <http://numba.pydata.org/numba-doc/
        latest/user/jit.html?highlight=nopython#nopython>`_
    @param nogil see `nogil <http://numba.pydata.org/numba-doc/
        latest/user/jit.html?highlight=nopython#nogil>`_
    @param sigin signature of input type
    @param sigout signature of output type
    @return generator

    The parallelization can only happen if
    the array is known. So the function splits the
    array in chunks of size *chunk_size*.
    Exactly one value is produced for every observation of *gen*,
    in the same order, including when the last chunk is partial
    or exactly full.
    When *sigin* or *sigout* is missing, *fct* is called once
    in pure :epkg:`python` on the first observation to build
    the output buffer.
    This attempt is not very efficient due to
    the genericity of the mapper. :epkg:`python`
    is not a good language to do that.
    See unit test
    `test_parallel_fctmr.py <https://github.com/sdpython/sparkouille/blob/
    master/_unittests/ut_fctmr/test_parallel_fctmr.py>`_.

    Raises *ValueError* if *chunk_size* is lower than 1.
    """
    if chunk_size < 1:
        raise ValueError(
            "chunk_size must be a positive integer, not {0!r}.".format(
                chunk_size))

    def loop(nb, inputs, outputs):
        "local function"
        # fct_jit is a closure variable assigned below, before the
        # first compilation of this function.
        for i in prange(nb):
            outputs[i] = fct_jit(inputs[i])

    iterator = iter(gen)

    if sigin is not None and sigout is not None:
        # Types are known: compile eagerly with explicit signatures
        # and work on preallocated arrays.
        sig1 = '{0}({1})'.format(sigout, sigin)
        sig2 = 'void(i8, {0}[:], {1}[:])'.format(sigin, sigout)

        fct_jit = jit(sig1, nogil=nogil, parallel=parallel,
                      nopython=nopython, cache=True)(fct)
        # NOTE(review): njit is documented as jit(nopython=True);
        # confirm that forwarding *nopython* here is intended.
        loop_jit = njit(sig2, nogil=nogil, parallel=parallel,
                        nopython=nopython, cache=True)(loop)
        inputs = create_array_numba(chunk_size, sigin)
        outputs = create_array_numba(chunk_size, sigout)
        done = 0
    else:
        # Types are unknown: the first observation is needed to build
        # the buffers, nothing is compiled if *gen* is empty.
        missing = object()
        first = next(iterator, missing)
        if first is missing:
            return
        inputs = [first] * chunk_size
        outputs = [fct(first)] * chunk_size
        fct_jit = jit(nogil=nogil, parallel=parallel,
                      nopython=nopython, cache=True)(fct)
        # NOTE(review): njit is documented as jit(nopython=True);
        # confirm that forwarding *nopython* here is intended.
        loop_jit = njit(nogil=nogil, parallel=parallel,
                        nopython=nopython, cache=True)(loop)
        # inputs[0] already holds the first observation.
        done = 1

    for out in _map_by_chunks(iterator, inputs, outputs, loop_jit, done):
        yield out