{"cells": [{"cell_type": "markdown", "id": "c760c855", "metadata": {}, "source": ["# ONNX FFTs\n", "\n", "Implementation of a couple of variations of FFT (see [FFT](https://www.tensorflow.org/xla/operation_semantics#fft)) in ONNX."]}, {"cell_type": "code", "execution_count": 1, "id": "ddecaddb", "metadata": {}, "outputs": [{"data": {"text/html": ["
run previous cell, wait for 2 seconds
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "6f17f4f5", "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 3, "id": "75f2064c", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "markdown", "id": "05a249a3", "metadata": {}, "source": ["## Signature\n", "\n", "We try to use function [FFT](https://www.tensorflow.org/xla/operation_semantics#fft) or [torch.fft.fftn](https://pytorch.org/docs/stable/generated/torch.fft.fftn.html#torch.fft.fftn)."]}, {"cell_type": "code", "execution_count": 4, "id": "6e2fc017", "metadata": {}, "outputs": [{"data": {"text/plain": ["1302"]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "from numpy.testing import assert_almost_equal\n", "\n", "def numpy_fftn(x, fft_type, fft_length, axes):\n", " \"\"\"\n", " Implements FFT\n", "\n", " :param x: input\n", " :param fft_type: string (see below)\n", " :param fft_length: length on each axis of axes\n", " :param axes: axes\n", " :return: result\n", " \n", " * `'FFT`': complex-to-complex FFT. Shape is unchanged.\n", " * `'IFFT`': Inverse complex-to-complex FFT. 
Shape is unchanged.\n", " * `'RFFT`': Forward real-to-complex FFT.\n", " Shape of the innermost axis is reduced to fft_length[-1] // 2 + 1 if fft_length[-1]\n", " is a non-zero value, omitting the reversed conjugate part of \n", " the transformed signal beyond the Nyquist frequency.\n", " * `'IRFFT`': Inverse real-to-complex FFT (ie takes complex, returns real).\n", " Shape of the innermost axis is expanded to fft_length[-1] if fft_length[-1] \n", " is a non-zero value, inferring the part of the transformed signal beyond the Nyquist\n", " frequency from the reverse conjugate of the 1 to fft_length[-1] // 2 + 1 entries.\n", " \"\"\"\n", " if fft_type == 'FFT':\n", " return numpy.fft.fftn(x, fft_length, axes=axes)\n", " raise NotImplementedError(\"Not implemented for fft_type=%r.\" % fft_type)\n", " \n", "\n", "def test_fct(fct1, fct2, fft_type='FFT', decimal=5):\n", " cases = list(range(4, 20))\n", " dims = [[c] for c in cases] + [[4,4,4,4], [4,5,6,7]]\n", " lengths_axes = [([c], [0]) for c in cases] + [\n", " ([2, 2, 2, 2], None), ([2, 6, 7, 2], None), ([2, 3, 4, 5], None),\n", " ([2], [3]), ([3], [2])]\n", " n_test = 0\n", " for ndim in range(1, 5):\n", " for dim in dims:\n", " for length, axes in lengths_axes:\n", " if axes is None:\n", " axes = range(ndim)\n", " di = dim[:ndim]\n", " axes = [min(len(di) - 1, a) for a in axes]\n", " le = length[:ndim]\n", " if len(length) > len(di):\n", " continue\n", " mat = numpy.random.randn(*di).astype(numpy.float32)\n", " try:\n", " v1 = fct1(mat, fft_type, le, axes=axes)\n", " except Exception as e:\n", " raise AssertionError(\n", " \"Unable to run %r mat.shape=%r ndim=%r di=%r fft_type=%r le=%r \"\n", " \"axes=%r exc=%r\" %(\n", " fct1, mat.shape, ndim, di, fft_type, le, axes, e))\n", " v2 = fct2(mat, fft_type, le, axes=axes)\n", " try:\n", " assert_almost_equal(v1, v2, decimal=decimal)\n", " except AssertionError as e:\n", " raise AssertionError(\n", " \"Failure mat.shape=%r, fft_type=%r, fft_length=%r\" % (\n", " 
mat.shape, fft_type, le)) from e\n", " n_test += 1\n", " return n_test\n", "\n", "\n", "test_fct(numpy_fftn, numpy_fftn)"]}, {"cell_type": "code", "execution_count": 5, "id": "84993aa6", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["1.81 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 test_fct(numpy_fftn, numpy_fftn)"]}, {"cell_type": "code", "execution_count": 6, "id": "9de9a43f", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["2.07 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["import torch\n", "\n", "def torch_fftn(x, fft_type, fft_length, axes):\n", " xt = torch.tensor(x)\n", " if fft_type == 'FFT':\n", " return torch.fft.fftn(xt, fft_length, axes).cpu().detach().numpy()\n", " \n", "%timeit -n 1 -r 1 test_fct(numpy_fftn, torch_fftn)"]}, {"cell_type": "markdown", "id": "e55d6dbf", "metadata": {}, "source": ["## Numpy implementation"]}, {"cell_type": "code", "execution_count": 7, "id": "aa74068d", "metadata": {}, "outputs": [], "source": ["import numpy\n", "\n", "\n", "def _dft_cst(N, fft_length, dtype):\n", " def _arange(dim, dtype, resh):\n", " return numpy.arange(dim).astype(dtype).reshape(resh)\n", "\n", " def _prod(n, k):\n", " return (-2j * numpy.pi * k / fft_length) * n\n", "\n", " def _exp(m):\n", " return numpy.exp(m)\n", " \n", " n = _arange(N, dtype, (-1, 1))\n", " k = _arange(fft_length, dtype, (1, -1))\n", " M = _exp(_prod(n, k))\n", " return M\n", "\n", "\n", "def custom_fft(x, fft_type, length, axis, dft_fct=None):\n", " # https://github.com/numpy/numpy/blob/4adc87dff15a247e417d50f10cc4def8e1c17a03/numpy/fft/_pocketfft.py#L56\n", " if dft_fct is None:\n", " dft_fct = _dft_cst\n", " if fft_type == 'FFT':\n", " if x.shape[axis] > length:\n", " # fft_length > shape on the same axis\n", " # the matrix is shortened\n", " slices = [slice(None)] * len(x.shape)\n", " slices[axis] = slice(0, 
length)\n", " new_x = x[tuple(slices)]\n", " elif x.shape[axis] == length:\n", " new_x = x\n", " else:\n", " # other, the matrix is completed with zeros\n", " shape = list(x.shape)\n", " shape[axis] = length\n", " slices = [slice(None)] * len(x.shape)\n", " slices[axis] = slice(0, length)\n", " zeros = numpy.zeros(tuple(shape), dtype=x.dtype)\n", " index = [slice(0, i) for i in x.shape]\n", " zeros[tuple(index)] = x\n", " new_x = zeros\n", "\n", " cst = dft_fct(new_x.shape[axis], length, x.dtype)\n", " perm = numpy.arange(len(x.shape)).tolist() \n", " if perm[axis] == perm[-1]:\n", " res = numpy.matmul(new_x, cst).transpose(perm)\n", " else:\n", " perm[axis], perm[-1] = perm[-1], perm[axis] \n", " rest = new_x.transpose(perm)\n", " res = numpy.matmul(rest, cst).transpose(perm)\n", " perm[axis], perm[0] = perm[0], perm[axis]\n", " return res\n", " raise ValueError(\"Unexpected value for fft_type=%r.\" % fft_type)\n", "\n", "\n", "def custom_fftn(x, fft_type, fft_length, axes, dft_fct=None):\n", " if len(axes) != len(fft_length):\n", " raise ValueError(\"Length mismatch axes=%r, fft_length=%r.\" % (\n", " axes, fft_length))\n", " if fft_type == 'FFT':\n", " res = x\n", " for i in range(len(fft_length) - 1, -1, -1):\n", " length = fft_length[i]\n", " axis = axes[i]\n", " res = custom_fft(res, fft_type, length, axis, dft_fct=dft_fct)\n", " return res\n", " raise ValueError(\"Unexpected value for fft_type=%r.\" % fft_type)\n", "\n", " \n", "shape = (4, )\n", "fft_length = [5,]\n", "axes = [0]\n", "rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "custom_fftn(rnd, 'FFT', fft_length, axes), numpy_fftn(rnd, 'FFT', fft_length, axes)\n", "assert_almost_equal(custom_fftn(rnd, 'FFT', fft_length, axes),\n", " numpy_fftn(rnd, 'FFT', fft_length, axes), decimal=5)\n", "\n", "shape = (4, 3)\n", "fft_length = [3, 2]\n", "axes = [0, 1]\n", "rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "custom_fftn(rnd, 'FFT', fft_length, axes), 
numpy_fftn(rnd, 'FFT', fft_length, axes)\n", "assert_almost_equal(custom_fftn(rnd, 'FFT', fft_length, axes),\n", " numpy_fftn(rnd, 'FFT', fft_length, axes), decimal=5)"]}, {"cell_type": "code", "execution_count": 8, "id": "5c454666", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["2.35 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["%timeit -n 1 -r 1 test_fct(numpy_fftn, custom_fftn, decimal=4)"]}, {"cell_type": "markdown", "id": "f27bd70d", "metadata": {}, "source": ["## Benchmark"]}, {"cell_type": "code", "execution_count": 9, "id": "507d8348", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:06<00:00, 3.91it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namecustom_fftnnumpy_fftntorch_fftn
length
80.0005850.0009110.003643
160.0016690.0013730.004087
240.0026820.0032730.005745
320.0042880.0032750.004657
400.0048180.0038310.005198
\n", "
"], "text/plain": ["name custom_fftn numpy_fftn torch_fftn\n", "length \n", "8 0.000585 0.000911 0.003643\n", "16 0.001669 0.001373 0.004087\n", "24 0.002682 0.003273 0.005745\n", "32 0.004288 0.003275 0.004657\n", "40 0.004818 0.003831 0.005198"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["from cpyquickhelper.numbers.speed_measure import measure_time\n", "from tqdm import tqdm\n", "from pandas import DataFrame\n", "\n", "def benchmark(fcts, power2=False):\n", " axes = [1]\n", " if power2:\n", " shape = [512, 1024]\n", " lengths = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]\n", " else:\n", " shape = [512, 150]\n", " lengths = list(range(8, 200, 8))\n", " rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "\n", " data = []\n", " for length in tqdm(lengths):\n", " fft_length = [length]\n", " for name, fct in fcts.items():\n", " obs = measure_time(lambda: fct(rnd, 'FFT', fft_length, axes),\n", " repeat=5, number=5)\n", " obs['name'] = name\n", " obs['length'] = length\n", " data.append(obs)\n", "\n", " df = DataFrame(data)\n", " return df\n", "\n", "\n", "df = benchmark({'numpy_fftn': numpy_fftn, 'custom_fftn': custom_fftn, 'torch_fftn': torch_fftn})\n", "piv = df.pivot(\"length\", \"name\", \"average\")\n", "piv[:5]"]}, {"cell_type": "code", "execution_count": 10, "id": "6f201494", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["piv.plot(logy=True, logx=True, title=\"FFT benchmark\", figsize=(12, 4));"]}, {"cell_type": "code", "execution_count": 11, "id": "9c29380c", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 10/10 [00:13<00:00, 1.33s/it]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namecustom_fftnnumpy_fftntorch_fftn
length
20.0004340.0011670.023980
40.0011170.0016710.022530
80.0014280.0020770.022102
160.0046540.0028740.019792
320.0031720.0026890.017474
640.0069660.0046120.018116
1280.0309040.0116080.023369
2560.1238210.0258530.023532
5120.4768020.0433520.033228
10241.5279170.1098680.052858
\n", "
"], "text/plain": ["name custom_fftn numpy_fftn torch_fftn\n", "length \n", "2 0.000434 0.001167 0.023980\n", "4 0.001117 0.001671 0.022530\n", "8 0.001428 0.002077 0.022102\n", "16 0.004654 0.002874 0.019792\n", "32 0.003172 0.002689 0.017474\n", "64 0.006966 0.004612 0.018116\n", "128 0.030904 0.011608 0.023369\n", "256 0.123821 0.025853 0.023532\n", "512 0.476802 0.043352 0.033228\n", "1024 1.527917 0.109868 0.052858"]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["df = benchmark({'numpy_fftn': numpy_fftn, 'custom_fftn': custom_fftn, 'torch_fftn': torch_fftn},\n", " power2=True)\n", "piv = df.pivot(\"length\", \"name\", \"average\")\n", "piv"]}, {"cell_type": "code", "execution_count": 12, "id": "40847bf5", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["piv.plot(logy=True, logx=True, title=\"FFT benchmark (power2)\", figsize=(12, 4));"]}, {"cell_type": "markdown", "id": "616099d8", "metadata": {}, "source": ["## Profiling"]}, {"cell_type": "code", "execution_count": 13, "id": "531658bb", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["f -- 1 1 -- 0.01752 0.54515 -- :8:f (f)\n", " custom_fftn -- 100 100 -- 0.00234 0.52763 -- :57:custom_fftn (custom_fftn)\n", " custom_fft -- 100 100 -- 0.19936 0.52516 -- :20:custom_fft (custom_fft)\n", " _dft_cst -- 100 100 -- 0.31917 0.32366 -- :4:_dft_cst (_dft_cst)\n", " _arange -- 200 200 -- 0.00088 0.00449 -- :5:_arange (_arange)\n", " -- 200 200 -- 0.00128 0.00128 -- ~:0: ()\n", " -- 200 200 -- 0.00064 0.00064 -- ~:0: ()\n", " -- 200 200 -- 0.00169 0.00169 -- ~:0: () +++\n", " -- 100 100 -- 0.00011 0.00011 -- ~:0: () +++\n", " -- 100 100 -- 0.00024 0.00024 -- ~:0: ()\n", " -- 100 100 -- 0.00076 0.00076 -- ~:0: ()\n", " -- 100 100 -- 0.00102 0.00102 -- ~:0: () +++\n", " -- 300 300 -- 0.00013 0.00013 -- ~:0: () +++\n", " -- 400 400 -- 0.00024 0.00024 -- ~:0: ()\n", " -- 300 300 -- 0.00271 0.00271 -- ~:0: ()\n"]}], "source": ["from pyquickhelper.pycode.profiling import profile2graph, profile\n", "\n", "shape = [512, 128]\n", "fft_length = [128]\n", "axes = [1]\n", "rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "\n", "def f():\n", " for i in range(100):\n", " custom_fftn(rnd, 'FFT', fft_length, axes)\n", "\n", "stat, text = profile(f)\n", "gr = profile2graph(stat)\n", "print(gr[0].to_text(fct_width=40))"]}, {"cell_type": "markdown", "id": "7690454d", "metadata": {}, "source": ["We can see that function `_dft_cst` is the bottle neck and more precisely the exponential. 
We need to use the symmetries of the matrix it builds."]}, {"cell_type": "markdown", "id": "4e250ff9", "metadata": {}, "source": ["## Faster _dft_cst\n", "\n", "The function builds the matrix $M_{nk} = \\left( \\exp\\left(\\frac{-2i\\pi nk}{K}\\right) \\right)_{nk}$ where $0 \\leqslant n \\leqslant N-1$, $0 \\leqslant k \\leqslant K-1$ and $K$ is `fft_length`. So it computes powers of the $K$-th roots of unity.\n", "\n", "$$\n", "\\exp\\left(\\frac{-2i\\pi nk}{K}\\right) = \\exp\\left(\\frac{-2i\\pi k}{K}\\right)^n = \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{nk}\n", "$$\n", "\n", "We use that expression to reduce the number of exponentials to compute."]}, {"cell_type": "code", "execution_count": 14, "id": "8b60fd16", "metadata": {}, "outputs": [{"data": {"text/plain": ["((3, 4), dtype('complex64'))"]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "from numpy.testing import assert_almost_equal\n", "\n", "def _dft_cst(N, fft_length, dtype=numpy.float32):\n", " def _arange(dim, dtype, resh):\n", " return numpy.arange(dim).astype(dtype).reshape(resh)\n", "\n", " n = _arange(N, dtype, (-1, 1))\n", " k = _arange(fft_length, dtype, (1, -1))\n", " M = (-2j * numpy.pi * k / fft_length) * n\n", " numpy.exp(M, out=M)\n", " return M\n", "\n", "\n", "M = _dft_cst(3, 4, numpy.float32)\n", "M.shape, M.dtype"]}, {"cell_type": "code", "execution_count": 15, "id": "0600a293", "metadata": {}, "outputs": [{"data": {"text/plain": ["((4, 3), dtype('complex128'))"]}, "execution_count": 16, "metadata": {}, "output_type": "execute_result"}], "source": ["M = _dft_cst(4, 3, numpy.float64)\n", "M.shape, M.dtype"]}, {"cell_type": "code", "execution_count": 16, "id": "38d760e2", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[ 1. +0.00000000e+00j, 1. +0.00000000e+00j, 1. +0.00000000e+00j],\n", " [ 1. +0.00000000e+00j, -0.5-8.66025404e-01j, -0.5+8.66025404e-01j],\n", " [ 1. +0.00000000e+00j, -0.5+8.66025404e-01j, -0.5-8.66025404e-01j],\n", " [ 1. 
+0.00000000e+00j, 1. +2.44929360e-16j, 1. +4.89858720e-16j]])"]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["M"]}, {"cell_type": "code", "execution_count": 17, "id": "30466d1b", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[ 1. +0.00000000e+00j, 1. +0.00000000e+00j, 1. +0.00000000e+00j],\n", " [ 1. +0.00000000e+00j, -0.5-8.66025404e-01j, -0.5+8.66025404e-01j],\n", " [ 1. +0.00000000e+00j, -0.5+8.66025404e-01j, -0.5-8.66025404e-01j],\n", " [ 1. +0.00000000e+00j, 1. +6.10622664e-16j, 1. +1.22124533e-15j]])"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["def _dft_cst_power(N, fft_length, dtype=numpy.float32):\n", " if dtype == numpy.float32:\n", " ctype = numpy.complex64\n", " else:\n", " ctype = numpy.complex128\n", " M = numpy.empty((N, fft_length), dtype=ctype)\n", " M[0, :] = 1\n", " M[1, 0] = 1\n", " root = numpy.exp(numpy.pi / fft_length * (-2j))\n", " current = root\n", " M[1, 1] = root\n", " for i in range(2, M.shape[1]):\n", " current *= root\n", " M[1, i] = current\n", " for i in range(2, M.shape[0]):\n", " numpy.multiply(M[i-1, :], M[1, :], out=M[i, :])\n", " return M\n", "\n", "M_pow = _dft_cst_power(4, 3, numpy.float64)\n", "M_pow"]}, {"cell_type": "code", "execution_count": 18, "id": "965651d7", "metadata": {}, "outputs": [], "source": ["assert_almost_equal(M, M_pow)"]}, {"cell_type": "code", "execution_count": 19, "id": "a7194b92", "metadata": {}, "outputs": [], "source": ["dims = (10, 15)\n", "assert_almost_equal(_dft_cst(*dims, dtype=numpy.float32), \n", " _dft_cst_power(*dims, dtype=numpy.float32),\n", " decimal=5)"]}, {"cell_type": "markdown", "id": "36b1a3d7", "metadata": {}, "source": ["## Benchmark again"]}, {"cell_type": "code", "execution_count": 20, "id": "40a61cb1", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["1.46 s \u00b1 0 ns per loop (mean \u00b1 std. dev. 
of 1 run, 1 loop each)\n"]}], "source": ["def custom_fftn_power(*args, **kwargs):\n", " return custom_fftn(*args, dft_fct=_dft_cst_power, **kwargs)\n", "\n", "\n", "%timeit -r 1 -n 1 test_fct(numpy_fftn, custom_fftn_power, decimal=4)"]}, {"cell_type": "code", "execution_count": 21, "id": "3a2707eb", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:07<00:00, 3.19it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namecustom_fftncustom_fftn_powernumpy_fftntorch_fftn
length
80.0009910.0008370.0011770.007033
160.0027580.0025910.0020690.006228
240.0030870.0028160.0024990.005564
320.0037670.0030680.0033060.005985
400.0047100.0039750.0040440.005733
\n", "
"], "text/plain": ["name custom_fftn custom_fftn_power numpy_fftn torch_fftn\n", "length \n", "8 0.000991 0.000837 0.001177 0.007033\n", "16 0.002758 0.002591 0.002069 0.006228\n", "24 0.003087 0.002816 0.002499 0.005564\n", "32 0.003767 0.003068 0.003306 0.005985\n", "40 0.004710 0.003975 0.004044 0.005733"]}, "execution_count": 22, "metadata": {}, "output_type": "execute_result"}], "source": ["df = benchmark({\n", " 'numpy_fftn': numpy_fftn, 'torch_fftn': torch_fftn, 'custom_fftn': custom_fftn, \n", " 'custom_fftn_power': custom_fftn_power})\n", "piv = df.pivot(\"length\", \"name\", \"average\")\n", "piv[:5]"]}, {"cell_type": "code", "execution_count": 22, "id": "eb2c6d54", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["piv.plot(logy=True, logx=True, title=\"FFT benchmark 2\", figsize=(12, 4));"]}, {"cell_type": "code", "execution_count": 23, "id": "ed6c4d1a", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["f -- 1 1 -- 0.02624 0.57688 -- :8:f (f)\n", " custom_fftn_power -- 100 100 -- 0.00094 0.55064 -- :1:custom_fftn_power (custom_fftn_power)\n", " custom_fftn -- 100 100 -- 0.00609 0.54970 -- :57:custom_fftn (custom_fftn)\n", " custom_fft -- 100 100 -- 0.46378 0.54342 -- :20:custom_fft (custom_fft)\n", " _dft_cst_power -- 100 100 -- 0.07599 0.07726 -- :1:_dft_cst_power (_dft_cst_power)\n", " -- 100 100 -- 0.00126 0.00126 -- ~:0: ()\n", " -- 100 100 -- 0.00008 0.00008 -- ~:0: () +++\n", " -- 100 100 -- 0.00025 0.00025 -- ~:0: ()\n", " -- 100 100 -- 0.00096 0.00096 -- ~:0: ()\n", " -- 100 100 -- 0.00109 0.00109 -- ~:0: ()\n", " -- 300 300 -- 0.00020 0.00020 -- ~:0: () +++\n", " -- 400 400 -- 0.00027 0.00027 -- ~:0: ()\n"]}], "source": ["from pyquickhelper.pycode.profiling import profile2graph, profile\n", "\n", "shape = [512, 128]\n", "fft_length = [128]\n", "axes = [1]\n", "rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "\n", "def f():\n", " for i in range(100):\n", " custom_fftn_power(rnd, 'FFT', fft_length, axes)\n", "\n", "stat, text = profile(f)\n", "gr = profile2graph(stat)\n", "print(gr[0].to_text(fct_width=40))"]}, {"cell_type": "markdown", "id": "9ef4af25", "metadata": {}, "source": ["## Cooley\u2013Tukey FFT algorithm\n", "\n", "See [Cooley\u2013Tukey FFT algorithm](https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm).\n", "\n", "The FFT matrix is defined by the matrix computation $F_{ak} = X_{an} M_{nk}$, then one coefficient is ($1 \\leqslant n, k \\leqslant K$):\n", "\n", "$$\n", "F_{ak} = \\sum_n X_{an} M_{nk} = \\sum_n X_{an} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{nk}\n", "$$\n", "\n", "Let's assume K is 
even, then $\\exp\\left(\\frac{-2i\\pi k}{K}\\right) = -\\exp\\left(\\frac{-2i\\pi \\left(k + \\frac{K}{2}\\right)}{K}\\right)$."]}, {"cell_type": "code", "execution_count": 24, "id": "803fcc3a", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots(1, 1, figsize=(6, 3))\n", "a = numpy.arange(0, 12) * (-2 * numpy.pi / 12)\n", "X = numpy.vstack([numpy.cos(a), numpy.sin(a)]).T\n", "ax.plot(X[:, 0], X[:, 1], 'o');\n", "for i in range(0, 12):\n", " ax.text(X[i, 0], X[i, 1], \"exp(-2pi %d/12)\" % i)\n", "ax.set_title('unit roots');"]}, {"cell_type": "markdown", "id": "f257dc6e", "metadata": {}, "source": ["Then:\n", "\n", "$$\n", "\\begin{array}{rcl}\n", "F_{a,k + \\frac{K}{2}} &=& \\sum_{n=1}^{N} X_{an} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{n\\left(k + \\frac{K}{2}\\right)} \\\\\n", "&=&\\sum_{n=1}^{N} X_{an} (-1)^n \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{nk} \\\\\n", "&=&\\sum_{m=1}^{\\frac{N}{2}} X_{a,2m} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} - \\sum_{m=1}^{\\frac{N}{2}} X_{a,2m-1} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{(2m-1)k} \\\\\n", "&=&\\sum_{m=1}^{\\frac{N}{2}} X_{a,2m} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} - \\sum_{m=1}^{\\frac{N}{2}} X_{a,2m-1} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{-k}\n", "\\end{array}\n", "$$\n", "\n", "Then:\n", "\n", "$$\n", "\\begin{array}{rcl}\n", "F_{a,k} + F_{a,k+\\frac{K}{2}} &=& 2\\sum_{m=1}^{\\frac{N}{2}} X_{a,2m} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk}\n", "= 2\\sum_{m=1}^{\\frac{N}{2}} X_{a,2m} \\exp\\left(\\frac{-2i\\pi}{\\frac{K}{2}}\\right)^{mk}\n", "\\end{array}\n", "$$\n", "\n", "Finally:\n", "\n", "$$\n", "\\begin{array}{rcl}\n", "F_{a,k} &=& \\sum_{m=1}^{\\frac{N}{2}} X_{a,2m} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} + \\sum_{m=1}^{\\frac{N}{2}} X_{a,2m-1} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{-k} \\\\\n", "F_{a,k + \\frac{K}{2}} &=&\\sum_{m=1}^{\\frac{N}{2}} X_{a,2m} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} - \\sum_{m=1}^{\\frac{N}{2}} X_{a,2m-1} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{2mk} 
\\exp\\left(\\frac{-2i\\pi}{K}\\right)^{-k}\n", "\\end{array}\n", "$$"]}, {"cell_type": "markdown", "id": "356b585d", "metadata": {}, "source": ["Now, what happens when *K* is odd? We fall back to the original computation:\n", "\n", "$$\n", "F_{ak} = \\sum_n X_{an} M_{nk} = \\sum_n X_{an} \\exp\\left(\\frac{-2i\\pi}{K}\\right)^{nk}\n", "$$"]}, {"cell_type": "code", "execution_count": 25, "id": "971be7bd", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["1.5 s \u00b1 0 ns per loop (mean \u00b1 std. dev. of 1 run, 1 loop each)\n"]}], "source": ["import functools\n", "\n", "\n", "def cooley_fft_2p(x, fft_length):\n", " cst = _dft_cst_power(x.shape[-1], fft_length, x.dtype)\n", " return numpy.matmul(x, cst)\n", "\n", "\n", "@functools.cache\n", "def _build_fact(p2_2, fft_length, dtype):\n", " first = numpy.exp(-2j * numpy.pi / fft_length)\n", " fact = numpy.ones(p2_2, dtype=dtype)\n", " for k in range(1, p2_2):\n", " fact[k] = fact[k-1] * first\n", " return fact.reshape((1, -1))\n", "\n", "\n", "def build_fact(p2_2, fft_length, dtype):\n", " return _build_fact(p2_2, fft_length, dtype)\n", "\n", "\n", "def cooley_fft_recursive(x, fft_length):\n", " if len(x.shape) != 2:\n", " raise RuntimeError(\n", " \"Unexpected x.shape=%r.\" % (x.shape, ))\n", " dtype = numpy.complex128 if x.dtype == numpy.float64 else numpy.complex64\n", " if fft_length == 1:\n", " return x[:, :1].astype(dtype)\n", "\n", " if fft_length % 2 == 0:\n", " def split(x):\n", " even = x[:, ::2]\n", " odd = x[:, 1::2]\n", " return even, odd\n", "\n", " def tmp1(even, odd, fft_length):\n", " p2_2 = fft_length // 2\n", " fft_even = cooley_fft_recursive(even, p2_2)\n", " fft_odd = cooley_fft_recursive(odd, p2_2)\n", " return fft_even, fft_odd, p2_2\n", "\n", " def tmp2(x, fft_even, fft_odd, p2_2):\n", " fact = build_fact(p2_2, fft_length, fft_even.dtype)\n", "\n", " fact_odd = fft_odd * fact\n", " return numpy.hstack([fft_even + fact_odd, fft_even - fact_odd])\n", 
"\n", " # inplace\n", " # result = numpy.empty((x.shape[0], fft_length), dtype=fft_even.dtype)\n", " # numpy.multiply(fft_odd, fact, out=result[:, :p2_2])\n", " # numpy.subtract(fft_even, result[:, :p2_2], out=result[:, p2_2:])\n", " # numpy.add(fft_even, result[:, :p2_2], out=result[:, :p2_2])\n", " # return result\n", " \n", " even, odd = split(x)\n", " fft_even, fft_odd, p2_2 = tmp1(even, odd, fft_length)\n", " result = tmp2(x, fft_even, fft_odd, p2_2)\n", " else:\n", " result = cooley_fft_2p(x, fft_length)\n", " \n", " return result\n", "\n", "\n", "\n", "def cooley_fft(x, fft_length):\n", " return cooley_fft_recursive(x, fft_length)\n", "\n", "\n", "def custom_fft_cooley(x, fft_type, length, axis):\n", " # https://github.com/numpy/numpy/blob/4adc87dff15a247e417d50f10cc4def8e1c17a03/numpy/fft/_pocketfft.py#L56\n", " if fft_type == 'FFT':\n", " if x.shape[axis] > length:\n", " # fft_length > shape on the same axis\n", " # the matrix is shortened\n", " slices = [slice(None)] * len(x.shape)\n", " slices[axis] = slice(0, length)\n", " new_x = x[tuple(slices)]\n", " elif x.shape[axis] == length:\n", " new_x = x\n", " else:\n", " # other, the matrix is completed with zeros\n", " shape = list(x.shape)\n", " shape[axis] = length\n", " slices = [slice(None)] * len(x.shape)\n", " slices[axis] = slice(0, length)\n", " zeros = numpy.zeros(tuple(shape), dtype=x.dtype)\n", " index = [slice(0, i) for i in x.shape]\n", " zeros[tuple(index)] = x\n", " new_x = zeros\n", "\n", " if axis == len(new_x.shape) - 1:\n", " if len(new_x.shape) != 2:\n", " xt = new_x.reshape((-1, new_x.shape[-1]))\n", " else:\n", " xt = new_x\n", " res = cooley_fft(xt, length)\n", " if len(new_x.shape) != 2:\n", " res = res.reshape(new_x.shape[:-1] + (-1, ))\n", " else:\n", " perm = numpy.arange(len(x.shape)).tolist() \n", " perm[axis], perm[-1] = perm[-1], perm[axis] \n", " rest = new_x.transpose(perm)\n", " shape = rest.shape[:-1]\n", " rest = rest.reshape((-1, rest.shape[-1]))\n", " res = 
cooley_fft(rest, length)\n", " res = res.reshape(shape + (-1, )).transpose(perm)\n", " perm[axis], perm[0] = perm[0], perm[axis]\n", " return res\n", " raise ValueError(\"Unexpected value for fft_type=%r.\" % fft_type)\n", "\n", "\n", "def custom_fftn_cooley(x, fft_type, fft_length, axes):\n", " if len(axes) != len(fft_length):\n", " raise ValueError(\"Length mismatch axes=%r, fft_length=%r.\" % (\n", " axes, fft_length))\n", " if fft_type == 'FFT':\n", " res = x\n", " for i in range(len(fft_length) - 1, -1, -1):\n", " length = fft_length[i]\n", " axis = axes[i]\n", " res = custom_fft_cooley(res, fft_type, length, axis)\n", " return res\n", " raise ValueError(\"Unexpected value for fft_type=%r.\" % fft_type)\n", " \n", "\n", "shape = (4, )\n", "fft_length = [3,]\n", "axes = [0]\n", "rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "assert_almost_equal(custom_fftn_cooley(rnd, 'FFT', fft_length, axes),\n", " numpy_fftn(rnd, 'FFT', fft_length, axes),\n", " decimal=5)\n", "%timeit -n 1 -r 1 test_fct(numpy_fftn, custom_fftn_cooley)"]}, {"cell_type": "code", "execution_count": 26, "id": "441d6c41", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:10<00:00, 2.35it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namecustom_fftn_cooleycustom_fftn_powernumpy_fftntorch_fftn
length
80.0028730.0006850.0014820.005463
160.0071970.0021210.0019220.005063
240.0094430.0029030.0027390.005169
320.0127830.0025560.0020030.004076
400.0141420.0039160.0039370.005118
\n", "
"], "text/plain": ["name custom_fftn_cooley custom_fftn_power numpy_fftn torch_fftn\n", "length \n", "8 0.002873 0.000685 0.001482 0.005463\n", "16 0.007197 0.002121 0.001922 0.005063\n", "24 0.009443 0.002903 0.002739 0.005169\n", "32 0.012783 0.002556 0.002003 0.004076\n", "40 0.014142 0.003916 0.003937 0.005118"]}, "execution_count": 27, "metadata": {}, "output_type": "execute_result"}], "source": ["df = benchmark({\n", " 'numpy_fftn': numpy_fftn, 'torch_fftn': torch_fftn,\n", " 'custom_fftn_power': custom_fftn_power, 'custom_fftn_cooley': custom_fftn_cooley})\n", "piv = df.pivot(\"length\", \"name\", \"average\")\n", "piv[:5]"]}, {"cell_type": "code", "execution_count": 27, "id": "6b579149", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["piv.plot(logy=True, logx=True, title=\"FFT benchmark 3\", figsize=(12, 4));"]}, {"cell_type": "code", "execution_count": 28, "id": "eb4f30d6", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 10/10 [00:11<00:00, 1.15s/it]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
namecustom_fftn_cooleycustom_fftn_powernumpy_fftntorch_fftn
length
20.0005750.0004710.0007220.019371
40.0011530.0003280.0011300.018366
80.0036780.0006240.0017790.019295
160.0068430.0022550.0021920.020169
320.0155740.0030450.0027360.017193
\n", "
"], "text/plain": ["name custom_fftn_cooley custom_fftn_power numpy_fftn torch_fftn\n", "length \n", "2 0.000575 0.000471 0.000722 0.019371\n", "4 0.001153 0.000328 0.001130 0.018366\n", "8 0.003678 0.000624 0.001779 0.019295\n", "16 0.006843 0.002255 0.002192 0.020169\n", "32 0.015574 0.003045 0.002736 0.017193"]}, "execution_count": 29, "metadata": {}, "output_type": "execute_result"}], "source": ["df = benchmark({\n", "    'numpy_fftn': numpy_fftn, 'torch_fftn': torch_fftn,\n", "    'custom_fftn_power': custom_fftn_power, 'custom_fftn_cooley': custom_fftn_cooley},\n", "    power2=True)\n", "# keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0\n", "piv = df.pivot(index=\"length\", columns=\"name\", values=\"average\")\n", "piv[:5]"]}, {"cell_type": "code", "execution_count": 29, "id": "769d9e41", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["ax = piv.plot(logy=True, logx=True, title=\"FFT benchmark 3 (power2)\", figsize=(12, 4))\n", "ax.set_ylabel(\"average time\");"]}, {"cell_type": "code", "execution_count": 30, "id": "3f6ef212", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["cooley_fft_recursive -- 100 51100 -- 0.24497 2.68339 -- :22:cooley_fft_recursive (cooley_fft_recursive)\n", " split -- 25500 25500 -- 0.06264 0.06264 -- :31:split (split)\n", " tmp1 -- 100 25500 -- 0.09438 2.54540 -- :36:tmp1 (tmp1)\n", " cooley_fft_recursive -- 51000 200 -- 0.24336 2.54421 -- :22:cooley_fft_recursive (cooley_fft_recursive) +++\n", " tmp2 -- 25500 25500 -- 0.95948 2.04473 -- :42:tmp2 (tmp2)\n", " hstack -- 25500 25500 -- 0.04799 1.05776 -- <__array_function__ internals>:177:hstack (hstack)\n", " _vhstack_dispatcher -- 25500 25500 -- 0.02712 0.07002 -- C:/Python395_x64/lib/site-packages/numpy/core/shape_base.py:218:_vhstack_dispatcher (_vhstack_dispatcher)\n", " _arrays_for...dispatcher -- 25500 25500 -- 0.02361 0.04290 -- C:/Python395_x64/lib/site-packages/numpy/core/shape_base.py:207:_arrays_for_stack_dispatcher (_arrays_for_stack_dispatcher)\n", " -- 25500 25500 -- 0.01929 0.01929 -- ~:0: ()\n", " -- 25500 25500 -- 0.03753 0.93975 -- ~:0: () +++\n", " build_fact -- 25500 25500 -- 0.02749 0.02749 -- :18:build_fact (build_fact)\n", " -- 51100 51100 -- 0.01521 0.01521 -- ~:0: () +++\n", " -- 25600 25600 -- 0.22146 0.22146 -- ~:0: ()\n", "f -- 1 1 -- 0.01449 2.70167 -- :8:f (f)\n", " custom_fftn_cooley -- 100 100 -- 0.00139 2.68718 -- :112:custom_fftn_cooley (custom_fftn_cooley)\n", " custom_fft_cooley -- 100 100 -- 0.00135 2.68568 -- :69:custom_fft_cooley (custom_fft_cooley)\n", " cooley_fft -- 100 100 -- 0.00082 2.68421 -- :65:cooley_fft (cooley_fft)\n", " cooley_fft_recursive -- 100 100 -- 0.00160 2.68339 -- :22:cooley_fft_recursive (cooley_fft_recursive) +++\n", " -- 300 300 -- 0.00012 0.00012 -- ~:0: () +++\n", " -- 300 300 -- 0.00011 0.00011 -- ~:0: () +++\n", " -- 77200 77200 -- 0.02367 0.02367 -- ~:0: ()\n", " -- 25500 76500 -- 0.58675 0.93975 -- ~:0: ()\n", " atleast_1d -- 25500 25500 -- 0.09562 0.13747 -- C:/Python395_x64/lib/site-packages/numpy/core/shape_base.py:23:atleast_1d (atleast_1d)\n", " -- 51000 51000 -- 0.01708 0.01708 -- ~:0: ()\n", " -- 25500 25500 -- 0.00822 0.00822 -- ~:0: () +++\n", " -- 51000 51000 -- 0.01655 0.01655 -- ~:0: ()\n", " hstack -- 25500 25500 -- 0.09871 0.90222 -- C:/Python395_x64/lib/site-packages/numpy/core/shape_base.py:285:hstack (hstack)\n", " concatenate -- 25500 25500 -- 0.04882 0.57709 -- <__array_function__ internals>:177:concatenate (concatenate)\n", " concatenate -- 25500 25500 -- 0.01049 0.01049 -- C:/Python395_x64/lib/site-packages/numpy/core/multiarray.py:148:concatenate (concatenate)\n", " -- 25500 25500 -- 0.51778 0.51778 -- ~:0: () +++\n", " atleast_1d -- 25500 25500 -- 0.04022 0.21751 -- <__array_function__ internals>:177:atleast_1d (atleast_1d)\n", " _atleast_1d_dispatcher -- 25500 25500 -- 0.00838 0.00838 -- C:/Python395_x64/lib/site-packages/numpy/core/shape_base.py:19:_atleast_1d_dispatcher (_atleast_1d_dispatcher)\n", " -- 25500 25500 -- 0.03144 0.16891 -- ~:0: () +++\n", " -- 25500 25500 -- 0.00892 0.00892 -- ~:0: ()\n"]}], "source": ["from pyquickhelper.pycode.profiling import profile2graph, profile\n", "\n", "shape = [512, 256]\n", "fft_length = [256]\n", "axes = [1]\n", "rnd = numpy.random.randn(*shape) + numpy.random.randn(*shape) * 1j\n", "\n", "def f():\n", "    \"\"\"Calls custom_fftn_cooley 100 times on the same random complex input.\"\"\"\n", "    for _ in range(100):\n", "        custom_fftn_cooley(rnd, 'FFT', fft_length, axes)\n", "\n", "stat, text = profile(f)\n", "gr = profile2graph(stat)\n", "print(gr[0].to_text(fct_width=40))"]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5"}}, "nbformat": 4, "nbformat_minor": 5}