{"cells": [{"cell_type": "markdown", "id": "51bc89fc", "metadata": {}, "source": ["# ONNX and FFT\n", "\n", "ONNX does not fully support complex numbers yet. It does not have any FFT operator either. What if we need them anyway?"]}, {"cell_type": "code", "execution_count": 1, "id": "7b2add97", "metadata": {}, "outputs": [{"data": {"text/html": ["
run previous cell, wait for 2 seconds
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "acfdc3b0", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "code", "execution_count": 3, "id": "abb5fa88", "metadata": {}, "outputs": [{"data": {"text/plain": ["'1.21.5'"]}, "execution_count": 4, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "numpy.__version__"]}, {"cell_type": "markdown", "id": "2e4f68e4", "metadata": {}, "source": ["## Python implementation of RFFT\n", "\n", "We try to replicate [numpy.rfft](https://numpy.org/doc/stable/reference/generated/numpy.fft.rfft.html)."]}, {"cell_type": "code", "execution_count": 4, "id": "cb1cc910", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-0.33227623+0.j , -1.53729601-0.93413037j,\n", " 4.47973719+2.89019374j, 1.36392938-2.59133368j],\n", " [ 0.07591467+0.j , 0.51947711+0.624144j ,\n", " -2.48242622-1.56579382j, -0.98728199+2.81434946j],\n", " [-0.55875075+0.j , -0.83228203+2.25251549j,\n", " 0.48281369+2.69338405j, -0.86559293+0.08437194j],\n", " [ 0.26185111+0.j , -1.18143684+1.73623491j,\n", " 0.96002386+0.39340971j, 3.53861562-1.32858241j],\n", " [ 1.06276855+0.j , 3.07258661-2.71505518j,\n", " -0.82579331-1.91852778j, 4.10811113-0.46836687j]])"]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "\n", "\n", "def almost_equal(a, b, error=1e-5):\n", " \"\"\"\n", " The function compares two matrices, one may be complex. 
In that case,\n", " this matrix is changed into a new matrix with a new first dimension,\n", " [0,::] means real part, [1,::] means imaginary part.\n", " \"\"\"\n", " if a.dtype in (numpy.complex64, numpy.complex128):\n", " dtype = numpy.float64 if a.dtype == numpy.complex128 else numpy.float32\n", " new_a = numpy.empty((2,) + a.shape).astype(dtype)\n", " new_a[0] = numpy.real(a)\n", " new_a[1] = numpy.imag(a)\n", " return almost_equal(new_a, b, error)\n", " if b.dtype in (numpy.complex64, numpy.complex128):\n", " return almost_equal(b, a, error)\n", " if a.shape != b.shape:\n", " raise AssertionError(\"Shape mismatch %r != %r.\" % (a.shape, b.shape))\n", " diff = numpy.abs(a.ravel() - b.ravel()).max()\n", " if diff > error:\n", " raise AssertionError(\"Mismatch max diff=%r > %r.\" % (diff, error))\n", "\n", "\n", "def dft_real_cst(N, fft_length):\n", " n = numpy.arange(N)\n", " k = n.reshape((N, 1)).astype(numpy.float64)\n", " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", " both = numpy.empty((2,) + M.shape)\n", " both[0, :, :] = numpy.real(M)\n", " both[1, :, :] = numpy.imag(M)\n", " return both\n", "\n", "\n", "def dft_real(x, fft_length=None, transpose=True):\n", " if len(x.shape) == 1:\n", " x = x.reshape((1, -1))\n", " N = 1\n", " else:\n", " N = x.shape[0] \n", " C = x.shape[-1] if transpose else x.shape[-2]\n", " if fft_length is None:\n", " fft_length = x.shape[-1]\n", " size = fft_length // 2 + 1\n", "\n", " cst = dft_real_cst(C, fft_length)\n", " if transpose:\n", " x = numpy.transpose(x, (1, 0))\n", " a = cst[:, :, :fft_length]\n", " b = x[:fft_length]\n", " res = numpy.matmul(a, b)\n", " res = res[:, :size, :]\n", " return numpy.transpose(res, (0, 2, 1))\n", " else:\n", " a = cst[:, :, :fft_length]\n", " b = x[:fft_length]\n", " return numpy.matmul(a, b)\n", "\n", "\n", "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", "fft_np = numpy.fft.rfft(rnd)\n", "fft_cus = dft_real(rnd)\n", "fft_np"]}, {"cell_type": "markdown", "id": "0c052ea1", 
"metadata": {}, "source": ["Function `almost_equal` verifies both functions return the same results."]}, {"cell_type": "code", "execution_count": 5, "id": "3ca040cb", "metadata": {}, "outputs": [], "source": ["almost_equal(fft_np, fft_cus)"]}, {"cell_type": "markdown", "id": "7fe77440", "metadata": {}, "source": ["Let's do the same with `fft_length < shape[1]`."]}, {"cell_type": "code", "execution_count": 6, "id": "3a747a4a", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-0.86976612+0.j , 2.20926839+0.35688821j],\n", " [ 0.33280143+0.j , -1.41451804+0.2065253j ],\n", " [-2.30690554+0.j , 0.51297992+0.62331197j],\n", " [-0.72842433+0.j , 1.84198139+1.07546916j],\n", " [ 4.17533261+0.j , 0.86360028+0.36508775j]])"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["fft_np3 = numpy.fft.rfft(rnd, n=3)\n", "fft_cus3 = dft_real(rnd, fft_length=3)\n", "fft_np3"]}, {"cell_type": "code", "execution_count": 7, "id": "0db6247b", "metadata": {}, "outputs": [], "source": ["almost_equal(fft_np3, fft_cus3)"]}, {"cell_type": "markdown", "id": "31a6ac9c", "metadata": {}, "source": ["## RFFT in ONNX\n", "\n", "Let's assume first the number of column of the input matrix is fixed. The result of function `dft_real_cst` can be considered as constant."]}, {"cell_type": "code", "execution_count": 8, "id": "efb67b9b", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/plain": ["array([[[-0.33227617, -1.5372959 , 4.4797373 , 1.3639294 ],\n", " [ 0.07591468, 0.51947707, -2.4824262 , -0.98728204],\n", " [-0.5587506 , -0.8322822 , 0.48281363, -0.86559296],\n", " [ 0.26185107, -1.1814368 , 0.96002394, 3.5386157 ],\n", " [ 1.0627685 , 3.0725865 , -0.8257934 , 4.108111 ]],\n", "\n", " [[ 0. , -0.93413043, 2.890194 , -2.5913336 ],\n", " [ 0. , 0.624144 , -1.5657941 , 2.8143494 ],\n", " [ 0. , 2.2525156 , 2.6933842 , 0.08437189],\n", " [ 0. , 1.7362347 , 0.39340976, -1.3285824 ],\n", " [ 0. 
, -2.7150555 , -1.9185277 , -0.4683669 ]]],\n", " dtype=float32)"]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["from typing import Any\n", "import mlprodict.npy.numpy_onnx_impl as npnx\n", "from mlprodict.npy import onnxnumpy_np\n", "from mlprodict.npy.onnx_numpy_annotation import NDArrayType\n", "# from mlprodict.onnxrt import OnnxInference\n", "\n", "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", "def onnx_rfft(x, fft_length=None):\n", " if fft_length is None:\n", " raise RuntimeError(\"fft_length must be specified.\")\n", " \n", " size = fft_length // 2 + 1\n", " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", " xt = npnx.transpose(x, (1, 0))\n", " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", " return npnx.transpose(res, (0, 2, 1))\n", "\n", "fft_onx = onnx_rfft(rnd, fft_length=rnd.shape[1])\n", "fft_onx"]}, {"cell_type": "code", "execution_count": 9, "id": "c4b6b1a5", "metadata": {}, "outputs": [], "source": ["almost_equal(fft_cus, fft_onx)"]}, {"cell_type": "markdown", "id": "a8c35327", "metadata": {}, "source": ["The corresponding ONNX graph is the following:"]}, {"cell_type": "code", "execution_count": 10, "id": "4d1a85b0", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 11, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview onnx_rfft.to_onnx()"]}, {"cell_type": "code", "execution_count": 11, "id": "6cf18aca", "metadata": {}, "outputs": [], "source": ["fft_onx3 = onnx_rfft(rnd, fft_length=3)\n", "almost_equal(fft_cus3, fft_onx3)"]}, {"cell_type": "markdown", "id": "6b466fd4", "metadata": {}, "source": ["## FFT 2D\n", "\n", "Below the code for complex features."]}, {"cell_type": "code", "execution_count": 12, "id": "e0020084", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-4.14039719 +0.j , -1.06715605 +1.16770652j,\n", " -0.27080808 +1.93562775j, 5.28785846 +2.27915445j],\n", " [-2.57576449 +3.09907081j, -8.90391777 -5.56953367j,\n", " -1.6455202 +2.03337471j, 4.21121677 -1.85803104j],\n", " [ 1.84529583 -0.54705419j, 3.61232172 -4.11661604j,\n", " 1.00659205 +3.72264071j, -0.36878039 -8.21956881j],\n", " [ 1.84529583 +0.54705419j, -1.173484 +5.12345283j,\n", " -1.7897386 -10.15322422j, -0.17258219 +2.37388952j],\n", " [-2.57576449 -3.09907081j, 0.58355627 +1.62293628j,\n", " 0.71779814 +4.64582025j, -6.32441255 -4.21906685j]])"]}, "execution_count": 13, "metadata": {}, "output_type": "execute_result"}], "source": ["def _DFT_cst(N, fft_length, trunc=True):\n", " n = numpy.arange(N)\n", " k = n.reshape((N, 1)).astype(numpy.float64)\n", " M = numpy.exp(-2j * numpy.pi * k * n / fft_length)\n", " return M[:fft_length // 2 + 1] if trunc else M\n", "\n", "def DFT(x, fft_length=None, axis=1):\n", " if axis == 1:\n", " x = x.T\n", " if fft_length is None:\n", " fft_length = x.shape[0]\n", " cst = _DFT_cst(x.shape[0], fft_length, trunc=axis==1)\n", " if axis == 1:\n", " return numpy.matmul(cst, x).T\n", " return numpy.matmul(cst, x)\n", "\n", "def fft2d_(mat, fft_length):\n", " mat = mat[:fft_length[0], :fft_length[1]]\n", " res = mat.copy()\n", " res = DFT(res, fft_length[1], axis=1)\n", " res = DFT(res, fft_length[0], axis=0)\n", " return 
res[:fft_length[0], :fft_length[1]//2 + 1]\n", "\n", "\n", "rnd = numpy.random.randn(5, 7).astype(numpy.float32)\n", "fft2d_np_ = fft2d_(rnd, rnd.shape)\n", "fft2d_np = numpy.fft.rfft2(rnd)\n", "fft2d_np_"]}, {"cell_type": "code", "execution_count": 13, "id": "777d2775", "metadata": {}, "outputs": [], "source": ["almost_equal(fft2d_np_, fft2d_np)"]}, {"cell_type": "markdown", "id": "cfbbe2fd", "metadata": {}, "source": ["It requires two 1D FFTs, one along each axis. However, as ONNX does not support complex numbers, it needs to be rewritten with real numbers only. The algorithm can be summarized by the formula $FFT(FFT(x, axis=1), axis=0)$. If *x* is real, $FFT(x, .)$ is complex. Still assuming *x* is real, and since FFT is a linear operator ($FFT(ix)=i FFT(x)$), it becomes:\n", "\n", "* $y = FFT(x, axis=1)$\n", "* $z_r = FFT(Real(y), axis=0)$, $z_i = FFT(Imag(y), axis=0)$\n", "* $z = z_r + i z_i$\n", "\n", "$z_r$ and $z_i$ are both complex, hence $Real(z) = Real(z_r) - Imag(z_i)$ and $Imag(z) = Imag(z_r) + Real(z_i)$.\n", "\n", "*z* is the desired output. The following implementation is probably not the most efficient one. 
It avoids in-place computation, which ONNX does not support."]}, {"cell_type": "code", "execution_count": 14, "id": "dd4fc711", "metadata": {}, "outputs": [], "source": ["def fft2d(mat, fft_length):\n", "    mat = mat[:fft_length[0], :fft_length[1]]\n", "    res = mat.copy()\n", "    \n", "    # first FFT\n", "    res = dft_real(res, fft_length=fft_length[1], transpose=True)\n", "    \n", "    # second FFT decomposed on FFT on real part and imaginary part\n", "    res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False)\n", "    res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) \n", "    res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", "    res = res2_real + res2_imag2\n", "    size = fft_length[1]//2 + 1\n", "    return res[:, :fft_length[0], :size]\n", "\n", "\n", "fft2d_np = numpy.fft.rfft2(rnd)\n", "fft2d_cus = fft2d(rnd, rnd.shape)\n", "almost_equal(fft2d_np, fft2d_cus)"]}, {"cell_type": "code", "execution_count": 15, "id": "bb8667e6", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[-4.14039719 +0.j , -1.06715605 +1.16770652j,\n", " -0.27080808 +1.93562775j, 5.28785846 +2.27915445j],\n", " [-2.57576449 +3.09907081j, -8.90391777 -5.56953367j,\n", " -1.6455202 +2.03337471j, 4.21121677 -1.85803104j],\n", " [ 1.84529583 -0.54705419j, 3.61232172 -4.11661604j,\n", " 1.00659205 +3.72264071j, -0.36878039 -8.21956881j],\n", " [ 1.84529583 +0.54705419j, -1.173484 +5.12345283j,\n", " -1.7897386 -10.15322422j, -0.17258219 +2.37388952j],\n", " [-2.57576449 -3.09907081j, 0.58355627 +1.62293628j,\n", " 0.71779814 +4.64582025j, -6.32441255 -4.21906685j]])"]}, "execution_count": 16, "metadata": {}, "output_type": "execute_result"}], "source": ["fft2d_np"]}, {"cell_type": "code", "execution_count": 16, "id": "56a94d97", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[[ -4.14039719, -1.06715605, -0.27080808, 5.28785846],\n", " [ -2.57576449, -8.90391777, -1.6455202 , 4.21121677],\n", " [ 1.84529583, 3.61232172, 1.00659205, -0.36878039],\n", " [ 
1.84529583, -1.173484 , -1.7897386 , -0.17258219],\n", " [ -2.57576449, 0.58355627, 0.71779814, -6.32441255]],\n", "\n", " [[ 0. , 1.16770652, 1.93562775, 2.27915445],\n", " [ 3.09907081, -5.56953367, 2.03337471, -1.85803104],\n", " [ -0.54705419, -4.11661604, 3.72264071, -8.21956881],\n", " [ 0.54705419, 5.12345283, -10.15322422, 2.37388952],\n", " [ -3.09907081, 1.62293628, 4.64582025, -4.21906685]]])"]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["fft2d_cus"]}, {"cell_type": "markdown", "id": "faa21909", "metadata": {}, "source": ["And with a different `fft_length`."]}, {"cell_type": "code", "execution_count": 17, "id": "bf98995f", "metadata": {}, "outputs": [], "source": ["fft2d_np = numpy.fft.rfft2(rnd, (4, 6))\n", "fft2d_cus = fft2d(rnd, (4, 6))\n", "almost_equal(fft2d_np[:4, :], fft2d_cus)"]}, {"cell_type": "markdown", "id": "caee1f84", "metadata": {}, "source": ["## FFT 2D in ONNX\n", "\n", "We use again the numpy API for ONNX."]}, {"cell_type": "code", "execution_count": 18, "id": "ca641274", "metadata": {}, "outputs": [], "source": ["def onnx_rfft_1d(x, fft_length=None, transpose=True):\n", " if fft_length is None:\n", " raise RuntimeError(\"fft_length must be specified.\")\n", " \n", " size = fft_length // 2 + 1\n", " cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", " if transpose:\n", " xt = npnx.transpose(x, (1, 0))\n", " res = npnx.matmul(cst[:, :, :fft_length], xt[:fft_length])[:, :size, :]\n", " return npnx.transpose(res, (0, 2, 1))\n", " else:\n", " return npnx.matmul(cst[:, :, :fft_length], x[:fft_length])\n", "\n", "\n", "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", "def onnx_rfft_2d(x, fft_length=None):\n", " mat = x[:fft_length[0], :fft_length[1]]\n", " \n", " # first FFT\n", " res = onnx_rfft_1d(mat, fft_length=fft_length[1], transpose=True)\n", " \n", " # second FFT decomposed on FFT on real part and imaginary part\n", " res2_real = 
onnx_rfft_1d(res[0], fft_length=fft_length[0], transpose=False)\n", " res2_imag = onnx_rfft_1d(res[1], fft_length=fft_length[0], transpose=False) \n", " res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", " res = res2_real + res2_imag2\n", " size = fft_length[1]//2 + 1\n", " return res[:, :fft_length[0], :size]\n", "\n", "\n", "fft2d_cus = fft2d(rnd, rnd.shape)\n", "fft2d_onx = onnx_rfft_2d(rnd, fft_length=rnd.shape)\n", "almost_equal(fft2d_cus, fft2d_onx)"]}, {"cell_type": "markdown", "id": "20fcd8a9", "metadata": {}, "source": ["The corresponding ONNX graph."]}, {"cell_type": "code", "execution_count": 19, "id": "b1379b06", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 20, "metadata": {}, "output_type": "execute_result"}], "source": ["%onnxview onnx_rfft_2d.to_onnx()"]}, {"cell_type": "code", "execution_count": 20, "id": "3034da60", "metadata": {}, "outputs": [], "source": ["with open(\"fft2d.onnx\", \"wb\") as f:\n", " f.write(onnx_rfft_2d.to_onnx().SerializeToString())"]}, {"cell_type": "markdown", "id": "3a747f0c", "metadata": {}, "source": ["With a different `fft_length`."]}, {"cell_type": "code", "execution_count": 21, "id": "16732cbb", "metadata": {}, "outputs": [], "source": ["fft2d_cus = fft2d(rnd, (4, 5))\n", "fft2d_onx = onnx_rfft_2d(rnd, fft_length=(4, 5))\n", "almost_equal(fft2d_cus, fft2d_onx)"]}, {"cell_type": "markdown", "id": "04924e7d", "metadata": {}, "source": ["This implementation of FFT in ONNX assumes shapes and fft lengths are constant. Otherwise, the matrix returned by function `dft_real_cst` must be converted as well. That's left as an exercise."]}, {"cell_type": "markdown", "id": "c9da88a0", "metadata": {}, "source": ["## FFT2D with shape (3,1,4)\n", "\n", "Previous implementation expects the input matrix to have two dimensions. 
It fails with three."]}, {"cell_type": "code", "execution_count": 22, "id": "66ba70ee", "metadata": {}, "outputs": [{"data": {"text/plain": ["(3, 1, 4)"]}, "execution_count": 23, "metadata": {}, "output_type": "execute_result"}], "source": ["shape = (3, 1, 4)\n", "fft_length = (1, 4)\n", "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", "fft2d_numpy.shape"]}, {"cell_type": "code", "execution_count": 23, "id": "a4d123e1", "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[[-1.04513007+0.j , 0.7261328 -0.1488841j ,\n", " -0.76143177+0.j , 0.7261328 +0.1488841j ]],\n", "\n", " [[ 0.13626025+0.j , -0.37364573+0.49485394j,\n", " -0.5746009 +0.j , -0.37364573-0.49485394j]],\n", "\n", " [[ 1.52022177+0.j , 0.35786384+1.09477997j,\n", " 2.16783673+0.j , 0.35786384-1.09477997j]]])"]}, "execution_count": 24, "metadata": {}, "output_type": "execute_result"}], "source": ["fft2d_numpy"]}, {"cell_type": "code", "execution_count": 24, "id": "4b1bd05b", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["axes don't match array\n"]}], "source": ["try:\n", "    fft2d_cus = fft2d(rnd, fft_length)\n", "except Exception as e:\n", "    print(e)\n", "# fft2d_onx = onnx_rfft_2d(rnd, fft_length=fft_length)"]}, {"cell_type": "markdown", "id": "7bd79a00", "metadata": {}, "source": ["### numpy version\n", "\n", "Let's do it again with numpy first. [fft2](https://numpy.org/doc/stable/reference/generated/numpy.fft.fft2.html) computes a 2D FFT over the last two axes, once for every index along the leading axes. 
The goal is still to have an implementation which works for any dimension."]}, {"cell_type": "code", "execution_count": 25, "id": "3b618335", "metadata": {}, "outputs": [], "source": ["conc = []\n", "for i in range(rnd.shape[0]):\n", " f2 = fft2d(rnd[i], fft_length)\n", " conc.append(numpy.expand_dims(f2, 0))\n", "res = numpy.vstack(conc).transpose(1, 0, 2, 3)\n", "almost_equal(fft2d_numpy[:, :, :3], res)"]}, {"cell_type": "markdown", "id": "7c837e7a", "metadata": {}, "source": ["It works. And now a more efficient implementation. It is better to read [matmul](https://numpy.org/doc/stable/reference/generated/numpy.matmul.html) description before. To summarize, a third axis is equivalent to many matrix multiplications over the last two axes, as many as the dimension of the first axis: ``matmul(A[I,J,K], B[I,K,L]) --> C[I,J,L]``. Broadcasting also works... ``matmul(A[1,J,K], B[I,K,L]) --> C[I,J,L]``."]}, {"cell_type": "code", "execution_count": 26, "id": "29055cb2", "metadata": {}, "outputs": [], "source": ["def dft_real_d3(x, fft_length=None, transpose=True):\n", " if len(x.shape) != 3:\n", " raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)\n", " N = x.shape[1]\n", " C = x.shape[-1] if transpose else x.shape[-2]\n", " if fft_length is None:\n", " fft_length = x.shape[-1]\n", " size = fft_length // 2 + 1\n", "\n", " cst = dft_real_cst(C, fft_length)\n", " if transpose:\n", " x = numpy.transpose(x, (0, 2, 1))\n", " a = cst[:, :, :fft_length]\n", " b = x[:, :fft_length, :]\n", " a = numpy.expand_dims(a, 0)\n", " b = numpy.expand_dims(b, 1)\n", " res = numpy.matmul(a, b)\n", " res = res[:, :, :size, :]\n", " return numpy.transpose(res, (1, 0, 3, 2))\n", " else:\n", " a = cst[:, :, :fft_length]\n", " b = x[:, :fft_length, :]\n", " a = numpy.expand_dims(a, 0)\n", " b = numpy.expand_dims(b, 1)\n", " res = numpy.matmul(a, b)\n", " return numpy.transpose(res, (1, 0, 2, 3))\n", "\n", "\n", "def fft2d_d3(mat, fft_length):\n", " mat = mat[:, :fft_length[-2], 
:fft_length[-1]]\n", " res = mat.copy()\n", " \n", " # first FFT\n", " res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)\n", " \n", " # second FFT decomposed on FFT on real part and imaginary part\n", " res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)\n", " res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)\n", " res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])\n", " res = res2_real + res2_imag2\n", " size = fft_length[-1]//2 + 1\n", " return res[:, :, :fft_length[-2], :size]\n", "\n", "\n", "def fft2d_any(mat, fft_length):\n", " new_shape = (-1, ) + mat.shape[-2:]\n", " mat2 = mat.reshape(new_shape)\n", " f2 = fft2d_d3(mat2, fft_length)\n", " new_shape = (2, ) + mat.shape[:-2] + f2.shape[-2:]\n", " return f2.reshape(new_shape)\n", "\n", "\n", "shape = (3, 1, 4)\n", "fft_length = (1, 4)\n", "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", "fft2d_numpy = numpy.fft.fft2(rnd, fft_length)\n", "fft2d_cus = fft2d_any(rnd, fft_length)\n", "almost_equal(fft2d_numpy[..., :3], fft2d_cus)"]}, {"cell_type": "markdown", "id": "0128b3f2", "metadata": {}, "source": ["We check with more shapes to see if the implementation works for all of them."]}, {"cell_type": "code", "execution_count": 27, "id": "82f5fc78", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 1, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 1, 2) or (2, 3, 1, 2)\n", "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 1, 1) or (2, 3, 1, 1)\n", "OK x.shape=(5, 7) length=(5, 7) output shape=(5, 7) or (2, 5, 4)\n", "OK x.shape=(5, 7) length=(1, 7) output shape=(1, 7) or (2, 1, 4)\n", "OK x.shape=(5, 7) length=(2, 7) output shape=(2, 7) or (2, 2, 4)\n", "OK 
x.shape=(5, 7) length=(5, 2) output shape=(5, 2) or (2, 5, 2)\n", "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 5, 7) or (2, 3, 5, 4)\n", "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 1, 7) or (2, 3, 1, 4)\n", "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 2, 7) or (2, 3, 2, 4)\n", "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 5, 2) or (2, 3, 5, 2)\n", "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 3, 4) or (2, 3, 3, 3)\n", "OK x.shape=(7, 5) length=(7, 5) output shape=(7, 5) or (2, 7, 3)\n", "OK x.shape=(7, 5) length=(1, 5) output shape=(1, 5) or (2, 1, 3)\n", "OK x.shape=(7, 5) length=(2, 5) output shape=(2, 5) or (2, 2, 3)\n", "OK x.shape=(7, 5) length=(7, 2) output shape=(7, 2) or (2, 7, 2)\n", "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n"]}], "source": ["for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", "    # NOTE: the last entry clamps both lengths with shape[-2];\n", "    # for shape (3, 1, 4) it produces fft_length=(1, 1).\n", "    for fft_length in [shape[-2:], (1, shape[-1]),\n", "                       (min(2, shape[-2]), shape[-1]),\n", "                       (shape[-2], 2),\n", "                       (min(3, shape[-2]), min(4, shape[-2]))]:\n", "        x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", "        # numpy reference; fft2d_any only keeps the first\n", "        # fft_length[-1] // 2 + 1 columns, hence the slicing below\n", "        fnp = numpy.fft.fft2(x, fft_length)\n", "        try:\n", "            cus = fft2d_any(x, fft_length)\n", "        except IndexError as e:\n", "            print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", "            continue\n", "        try:\n", "            almost_equal(fnp[..., :cus.shape[-1]], cus)\n", "        except (AssertionError, IndexError) as e:\n", "            print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", "                x.shape, fft_length, e, fnp.shape, cus.shape))\n", "            continue\n", "        print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", "            x.shape, fft_length, fnp.shape, cus.shape))"]}, {"cell_type": "markdown", "id": "c5f5229a", "metadata": {}, "source": ["### ONNX version\n", "\n", "Let's look into the differences first."]}, {"cell_type": "code", 
"execution_count": 28, "id": "025c2d88", "metadata": {}, "outputs": [], "source": ["%load_ext pyquickhelper"]}, {"cell_type": "code", "execution_count": 29, "id": "8a9d153c", "metadata": {}, "outputs": [{"data": {"text/html": ["\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["%%html\n", ""]}, {"cell_type": "code", "execution_count": 30, "id": "82664bc5", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:00<00:00, 573.03it/s]\n"]}, {"data": {"text/html": ["\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
00def dft_real(x, fft_length=None, transpose=True):def dft_real_d3(x, fft_length=None, transpose=True):
11 if len(x.shape) == 1: if len(x.shape) != 3:
22 x = x.reshape((1, -1)) raise RuntimeError(\"Not implemented for shape=%r.\" % x.shape)
3 N = 1
4 else:
53 N = x.shape[0] N = x.shape[1]
64 C = x.shape[-1] if transpose else x.shape[-2] C = x.shape[-1] if transpose else x.shape[-2]
75 if fft_length is None: if fft_length is None:
86 fft_length = x.shape[-1] fft_length = x.shape[-1]
97 size = fft_length // 2 + 1 size = fft_length // 2 + 1
108
119 cst = dft_real_cst(C, fft_length) cst = dft_real_cst(C, fft_length)
1210 if transpose: if transpose:
1311 x = numpy.transpose(x, (1, 0)) x = numpy.transpose(x, (0, 2, 1))
1412 a = cst[:, :, :fft_length] a = cst[:, :, :fft_length]
1513 b = x[:fft_length] b = x[:, :fft_length, :]
14 a = numpy.expand_dims(a, 0)
15 b = numpy.expand_dims(b, 1)
1616 res = numpy.matmul(a, b) res = numpy.matmul(a, b)
1717 res = res[:, :size, :] res = res[:, :, :size, :]
1818 return numpy.transpose(res, (0, 2, 1)) return numpy.transpose(res, (1, 0, 3, 2))
1919 else: else:
2020 a = cst[:, :, :fft_length] a = cst[:, :, :fft_length]
2121 b = x[:fft_length] b = x[:, :fft_length, :]
22 a = numpy.expand_dims(a, 0)
23 b = numpy.expand_dims(b, 1)
2224 return numpy.matmul(a, b) res = numpy.matmul(a, b)
25 return numpy.transpose(res, (1, 0, 2, 3))
2326
"], "text/plain": [""]}, "execution_count": 31, "metadata": {}, "output_type": "execute_result"}], "source": ["import inspect\n", "text1 = inspect.getsource(dft_real)\n", "text2 = inspect.getsource(dft_real_d3)\n", "%codediff text1 text2 --verbose 1 --two 1"]}, {"cell_type": "code", "execution_count": 31, "id": "cd7e14d4", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 15/15 [00:00<00:00, 791.61it/s]\n"]}, {"data": {"text/html": ["\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
00def fft2d(mat, fft_length):def fft2d_d3(mat, fft_length):
11 mat = mat[:fft_length[0], :fft_length[1]] mat = mat[:, :fft_length[-2], :fft_length[-1]]
22 res = mat.copy() res = mat.copy()
33
44 # first FFT # first FFT
55 res = dft_real(res, fft_length=fft_length[1], transpose=True) res = dft_real_d3(res, fft_length=fft_length[-1], transpose=True)
66
77 # second FFT decomposed on FFT on real part and imaginary part # second FFT decomposed on FFT on real part and imaginary part
88 res2_real = dft_real(res[0], fft_length=fft_length[0], transpose=False) res2_real = dft_real_d3(res[0], fft_length=fft_length[-2], transpose=False)
99 res2_imag = dft_real(res[1], fft_length=fft_length[0], transpose=False) res2_imag = dft_real_d3(res[1], fft_length=fft_length[-2], transpose=False)
1010 res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]]) res2_imag2 = numpy.vstack([-res2_imag[1:2], res2_imag[:1]])
1111 res = res2_real + res2_imag2 res = res2_real + res2_imag2
1212 size = fft_length[1]//2 + 1 size = fft_length[-1]//2 + 1
1313 return res[:, :fft_length[0], :size] return res[:, :, :fft_length[-2], :size]
1414
"], "text/plain": [""]}, "execution_count": 32, "metadata": {}, "output_type": "execute_result"}], "source": ["text1 = inspect.getsource(fft2d)\n", "text2 = inspect.getsource(fft2d_d3)\n", "%codediff text1 text2 --verbose 1 --two 1"]}, {"cell_type": "code", "execution_count": 32, "id": "51e7a4f7", "metadata": {"scrolled": false}, "outputs": [], "source": ["def onnx_rfft_3d_1d(x, fft_length=None, transpose=True):\n", "    \"\"\"\n", "    DFT of a real 3D tensor ``x`` of shape (N, R, C) along one axis,\n", "    written with the numpy API for ONNX.\n", "\n", "    * ``transpose=True``: transforms the last axis and returns a tensor\n", "      of shape (2, N, R, fft_length // 2 + 1),\n", "    * ``transpose=False``: transforms axis 1 and returns a tensor\n", "      of shape (2, N, fft_length, C).\n", "\n", "    Index 0 of the first axis holds the real part, index 1 the imaginary part.\n", "    The DFT matrix is a (fft_length, fft_length) constant, so the\n", "    transformed axis is expected to have exactly *fft_length* elements.\n", "    \"\"\"\n", "    if fft_length is None:\n", "        raise RuntimeError(\"fft_length must be specified.\")\n", "\n", "    size = fft_length // 2 + 1\n", "    cst = dft_real_cst(fft_length, fft_length).astype(numpy.float32)\n", "    if transpose:\n", "        xt = npnx.transpose(x, (0, 2, 1))\n", "        a = cst[:, :, :fft_length]\n", "        b = xt[:, :fft_length, :]\n", "        a = npnx.expand_dims(a, 0)\n", "        b = npnx.expand_dims(b, 1)\n", "        res = npnx.matmul(a, b)\n", "        # res is (N, 2, fft_length, R): truncate the frequency axis (axis 2),\n", "        # not axis 1 which only holds the real and imaginary parts.\n", "        res2 = res[:, :, :size, :]\n", "        return npnx.transpose(res2, (1, 0, 3, 2))\n", "    else:\n", "        a = cst[:, :, :fft_length]\n", "        b = x[:, :fft_length, :]\n", "        a = npnx.expand_dims(a, 0)\n", "        b = npnx.expand_dims(b, 1)\n", "        res = npnx.matmul(a, b)\n", "        return npnx.transpose(res, (1, 0, 2, 3))\n", "\n", "\n", "def onnx_rfft_3d_2d(x, fft_length=None):\n", "    \"\"\"\n", "    2D real FFT over the last two axes of a 3D tensor ``x`` (N, R, C),\n", "    ONNX counterpart of ``fft2d_d3``. The result has shape\n", "    (2, N, fft_length[-2], fft_length[-1] // 2 + 1), real part first.\n", "    \"\"\"\n", "    mat = x[:, :fft_length[-2], :fft_length[-1]]\n", "\n", "    # first FFT, along the last axis\n", "    res = onnx_rfft_3d_1d(mat, fft_length=fft_length[-1], transpose=True)\n", "\n", "    # second FFT decomposed on FFT on real part and imaginary part\n", "    res2_real = onnx_rfft_3d_1d(res[0], fft_length=fft_length[-2], transpose=False)\n", "    res2_imag = onnx_rfft_3d_1d(res[1], fft_length=fft_length[-2], transpose=False)\n", "    # multiplying by i maps (real, imag) to (-imag, real)\n", "    res2_imag2 = npnx.vstack(-res2_imag[1:2], res2_imag[:1])\n", "    res = res2_real + res2_imag2\n", "    size = fft_length[-1] // 2 + 1\n", "    return res[:, :, :fft_length[-2], :size]\n", "\n", "\n", "@onnxnumpy_np(signature=NDArrayType((\"T:all\", ), dtypes_out=('T',)))\n", "def onnx_rfft_2d_any(x, fft_length=None):\n", "    \"\"\"\n", "    2D real FFT over the last two axes of a tensor with any number of\n", "    leading axes: they are flattened into a single batch axis, the 3D\n", "    implementation is applied, then the leading axes are restored.\n", "    The output gets an extra first axis of size 2 (real, imaginary).\n", "    \"\"\"\n", "    new_shape = npnx.concat(\n", "        numpy.array([-1], dtype=numpy.int64), x.shape[-2:], axis=0)\n", "    mat2 = x.reshape(new_shape)\n", "    f2 = onnx_rfft_3d_2d(mat2, fft_length)\n", "    new_shape = npnx.concat(\n", "        numpy.array([2], dtype=numpy.int64), x.shape[:-2], f2.shape[-2:],\n", "        axis=0)\n", "    return f2.reshape(new_shape)\n", "\n", "\n", "shape = (3, 1, 4)\n", "fft_length = (1, 4)\n", "rnd = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", "fft2d_cus = fft2d_any(rnd, fft_length)\n", "fft2d_onx = onnx_rfft_2d_any(rnd, fft_length=fft_length)\n", "almost_equal(fft2d_cus, fft2d_onx)"]}, {"cell_type": "markdown", "id": "37c45ae7", "metadata": {}, "source": ["Let's do the same comparison."]}, 
{"cell_type": "code", "execution_count": 33, "id": "11c1e596", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 4) output shape=(3, 4) or (2, 3, 1, 3)\n", "OK x.shape=(3, 1, 4) length=(1, 2) output shape=(3, 4) or (2, 3, 1, 2)\n", "OK x.shape=(3, 1, 4) length=(1, 1) output shape=(3, 4) or (2, 3, 1, 1)\n", "OK x.shape=(5, 7) length=(5, 7) output shape=(3, 4) or (2, 5, 4)\n", "OK x.shape=(5, 7) length=(1, 7) output shape=(3, 4) or (2, 1, 4)\n", "OK x.shape=(5, 7) length=(2, 7) output shape=(3, 4) or (2, 2, 4)\n", "OK x.shape=(5, 7) length=(5, 2) output shape=(3, 4) or (2, 5, 2)\n", "OK x.shape=(5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n", "OK x.shape=(3, 5, 7) length=(5, 7) output shape=(3, 4) or (2, 3, 5, 4)\n", "OK x.shape=(3, 5, 7) length=(1, 7) output shape=(3, 4) or (2, 3, 1, 4)\n", "OK x.shape=(3, 5, 7) length=(2, 7) output shape=(3, 4) or (2, 3, 2, 4)\n", "OK x.shape=(3, 5, 7) length=(5, 2) output shape=(3, 4) or (2, 3, 5, 2)\n", "OK x.shape=(3, 5, 7) length=(3, 4) output shape=(3, 4) or (2, 3, 3, 3)\n", "OK x.shape=(7, 5) length=(7, 5) output shape=(3, 4) or (2, 7, 3)\n", "OK x.shape=(7, 5) length=(1, 5) output shape=(3, 4) or (2, 1, 3)\n", "OK x.shape=(7, 5) length=(2, 5) output shape=(3, 4) or (2, 2, 3)\n", "OK x.shape=(7, 5) length=(7, 2) output shape=(3, 4) or (2, 7, 2)\n", "OK x.shape=(7, 5) length=(3, 4) output shape=(3, 4) or (2, 3, 3)\n"]}], "source": ["# compares the ONNX implementation with the numpy one (fft2d_any)\n", "for shape in [(3, 1, 4), (5, 7), (3, 5, 7), (7, 5)]:\n", "    for fft_length in [shape[-2:], (1, shape[-1]),\n", "                       (min(2, shape[-2]), shape[-1]),\n", "                       (shape[-2], 2),\n", "                       (min(3, shape[-2]), min(4, shape[-2]))]:\n", "        x = numpy.random.randn(*list(shape)).astype(numpy.float32)\n", "        try:\n", "            cus = fft2d_any(x, fft_length)\n", "        except IndexError as e:\n", "            print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", "            continue\n", "        try:\n", "            onx = onnx_rfft_2d_any(x, fft_length=fft_length)\n", "        except IndexError as e:\n", "            print(\"ERR x.shape=%r length=%r error=%r\" % (x.shape, fft_length, e))\n", "            continue\n", "        try:\n", "            almost_equal(onx, cus)\n", "        except (AssertionError, IndexError) as e:\n", "            print(\"DIS x.shape=%r length=%r error=%r output shape=%r or %r\" % (\n", "                x.shape, fft_length, e, onx.shape, cus.shape))\n", "            continue\n", "        print(\"OK x.shape=%r length=%r output shape=%r or %r\" % (\n", "            x.shape, fft_length, onx.shape, cus.shape))"]}, 
{"cell_type": "markdown", "id": "d197467f", "metadata": {}, "source": ["There is one issue with ``fft_length=(1, 1)`` but that case is out of scope."]}, {"cell_type": "markdown", "id": "33b5897e", "metadata": {}, "source": ["### ONNX graph"]}, {"cell_type": "code", "execution_count": 34, "id": "d45e9a99", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 35, "metadata": {}, "output_type": "execute_result"}], "source": ["# Display the ONNX graph of the first compiled version of onnx_rfft_2d_any.\n", "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", "%onnxview onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_"]}, {"cell_type": "code", "execution_count": 35, "id": "2ab7a3d0", "metadata": {}, "outputs": [], "source": ["# Serialize the graph before opening the file so that a failed lookup\n", "# does not leave an empty or truncated fft2d_any.onnx behind.\n", "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", "model_bytes = onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_.SerializeToString()\n", "with open(\"fft2d_any.onnx\", \"wb\") as f:\n", "    f.write(model_bytes)"]}, {"cell_type": "markdown", "id": "3c17b577", "metadata": {}, "source": ["Let's check the intermediate results."]}, {"cell_type": "code", "execution_count": 36, "id": "9e5507f7", "metadata": {}, "outputs": [{"data": {"text/plain": ["FctVersion((numpy.float32,), ((1, 4),))"]}, "execution_count": 37, "metadata": {}, "output_type": "execute_result"}], "source": ["# Key of the compiled version inspected below.\n", "key = list(onnx_rfft_2d_any.signed_compiled)[0]\n", "key"]}, {"cell_type": "code", "execution_count": 37, "id": "376036f8", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["+ki='init': (1,) (dtype=int64 min=0 max=0)\n", "+ki='init_1': (1,) (dtype=int64 min=-2 max=-2)\n", "+ki='init_3': (1,) (dtype=int64 min=-1 max=-1)\n", "+ki='init_4': (2,) (dtype=int64 min=0 max=0)\n", "+ki='init_5': (2,) (dtype=int64 min=1 max=4)\n", "+ki='init_6': (2,) (dtype=int64 min=1 max=2)\n", "+ki='init_8': (1,) (dtype=int64 min=4 max=4)\n", "+ki='init_9': (1,) (dtype=int64 min=1 max=1)\n", "+ki='init_b11': (2, 4, 4) (dtype=float32 min=-1.0 max=1.0)\n", "+ki='init_b14': (1,) (dtype=int64 min=3 max=3)\n", "+ki='init_b16': () (dtype=int64 min=1 max=1)\n", "+ki='init_b21': (2, 1, 1) (dtype=float32 min=0.0 max=1.0)\n", "+ki='init_b23': () (dtype=int64 min=0 max=0)\n", "+ki='init_b28': (1,) (dtype=int64 min=2 max=2)\n", "+ki='init_b37': (2,) (dtype=int64 min=1 max=3)\n", "+ki='init_b38': (2,) (dtype=int64 min=2 max=3)\n", "-- OnnxInference: run 38 nodes\n", 
"Onnx-Shape(x) -> out_sha_0 (name='_shape')\n", "+kr='out_sha_0': (3,) (dtype=int64 min=1 max=4)\n", "Onnx-Shape(out_sha_0) -> out_sha_0_1 (name='_shape_1')\n", "+kr='out_sha_0_1': (1,) (dtype=int64 min=3 max=3)\n", "Onnx-Gather(out_sha_0_1, init) -> out_gat_0 (name='_gather')\n", "+kr='out_gat_0': (1,) (dtype=int64 min=3 max=3)\n", "Onnx-Slice(out_sha_0, init_1, out_gat_0, init) -> out_sli_0 (name='_slice')\n", "+kr='out_sli_0': (2,) (dtype=int64 min=1 max=4)\n", "Onnx-Concat(init_3, out_sli_0) -> out_con_0 (name='_concat')\n", "+kr='out_con_0': (3,) (dtype=int64 min=-1 max=4)\n", "Onnx-Reshape(x, out_con_0) -> out_res_0 (name='_reshape')\n", "+kr='out_res_0': (3, 1, 4) (dtype=float32 min=-2.0340726375579834 max=2.391742706298828)\n", "Onnx-Slice(out_res_0, init_4, init_5, init_6) -> out_sli_0_1 (name='_slice_1')\n", "+kr='out_sli_0_1': (3, 1, 4) (dtype=float32 min=-2.0340726375579834 max=2.391742706298828)\n", "Onnx-Transpose(out_sli_0_1) -> out_tra_0 (name='_transpose')\n", "+kr='out_tra_0': (3, 4, 1) (dtype=float32 min=-2.0340726375579834 max=2.391742706298828)\n", "Onnx-Slice(out_tra_0, init, init_8, init_9) -> out_sli_0_2 (name='_slice_2')\n", "+kr='out_sli_0_2': (3, 4, 1) (dtype=float32 min=-2.0340726375579834 max=2.391742706298828)\n", "Onnx-Unsqueeze(out_sli_0_2, init_9) -> out_uns_0 (name='_unsqueeze')\n", "+kr='out_uns_0': (3, 1, 4, 1) (dtype=float32 min=-2.0340726375579834 max=2.391742706298828)\n", "Onnx-Unsqueeze(init_b11, init) -> out_uns_0_1 (name='_unsqueeze_1')\n", "+kr='out_uns_0_1': (1, 2, 4, 4) (dtype=float32 min=-1.0 max=1.0)\n", "Onnx-MatMul(out_uns_0_1, out_uns_0) -> out_mat_0 (name='_matmul')\n", "+kr='out_mat_0': (3, 2, 4, 1) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Slice(out_mat_0, init, init_b14, init_9) -> out_sli_0_3 (name='_slice_3')\n", "+kr='out_sli_0_3': (3, 2, 4, 1) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Transpose(out_sli_0_3) -> out_tra_0_1 (name='_transpose_1')\n", 
"+kr='out_tra_0_1': (2, 3, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Gather(out_tra_0_1, init_b16) -> out_gat_0_1 (name='_gather_1')\n", "+kr='out_gat_0_1': (3, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Slice(out_gat_0_1, init, init_9, init_9) -> out_sli_0_4 (name='_slice_4')\n", "+kr='out_sli_0_4': (3, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Unsqueeze(out_sli_0_4, init_9) -> out_uns_0_2 (name='_unsqueeze_2')\n", "+kr='out_uns_0_2': (3, 1, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Unsqueeze(init_b21, init) -> out_uns_0_3 (name='_unsqueeze_3')\n", "+kr='out_uns_0_3': (1, 2, 1, 1) (dtype=float32 min=0.0 max=1.0)\n", "Onnx-MatMul(out_uns_0_3, out_uns_0_2) -> out_mat_0_1 (name='_matmul_1')\n", "+kr='out_mat_0_1': (3, 2, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Gather(out_tra_0_1, init_b23) -> out_gat_0_2 (name='_gather_2')\n", "+kr='out_gat_0_2': (3, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Transpose(out_mat_0_1) -> out_tra_0_2 (name='_transpose_2')\n", "+kr='out_tra_0_2': (2, 3, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Slice(out_gat_0_2, init, init_9, init_9) -> out_sli_0_5 (name='_slice_5')\n", "+kr='out_sli_0_5': (3, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Slice(out_tra_0_2, init_9, init_b28, init) -> out_sli_0_6 (name='_slice_6')\n", "+kr='out_sli_0_6': (1, 3, 1, 4) (dtype=float32 min=0.0 max=0.0)\n", "Onnx-Unsqueeze(out_sli_0_5, init_9) -> out_uns_0_4 (name='_unsqueeze_4')\n", "+kr='out_uns_0_4': (3, 1, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Slice(out_tra_0_2, init, init_9, init) -> out_sli_0_7 (name='_slice_7')\n", "+kr='out_sli_0_7': (1, 3, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Neg(out_sli_0_6) -> 
out_neg_0 (name='_neg')\n", "+kr='out_neg_0': (1, 3, 1, 4) (dtype=float32 min=-0.0 max=-0.0)\n", "Onnx-MatMul(out_uns_0_3, out_uns_0_4) -> out_mat_0_2 (name='_matmul_2')\n", "+kr='out_mat_0_2': (3, 2, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Concat(out_neg_0, out_sli_0_7) -> out_con_0_1 (name='_concat_1')\n", "+kr='out_con_0_1': (2, 3, 1, 4) (dtype=float32 min=-2.054079532623291 max=2.054079532623291)\n", "Onnx-Transpose(out_mat_0_2) -> out_tra_0_3 (name='_transpose_3')\n", "+kr='out_tra_0_3': (2, 3, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Add(out_tra_0_3, out_con_0_1) -> out_add_0 (name='_add')\n", "+kr='out_add_0': (2, 3, 1, 4) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Slice(out_add_0, init_4, init_b37, init_b38) -> out_sli_0_8 (name='_slice_8')\n", "+kr='out_sli_0_8': (2, 3, 1, 3) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n", "Onnx-Shape(out_sli_0_8) -> out_sha_0_2 (name='_shape_2')\n", "+kr='out_sha_0_2': (4,) (dtype=int64 min=1 max=3)\n", "Onnx-Shape(out_sha_0_2) -> out_sha_0_3 (name='_shape_3')\n", "+kr='out_sha_0_3': (1,) (dtype=int64 min=4 max=4)\n", "Onnx-Gather(out_sha_0_3, init) -> out_gat_0_3 (name='_gather_3')\n", "+kr='out_gat_0_3': (1,) (dtype=int64 min=4 max=4)\n", "Onnx-Slice(out_sha_0_2, init_1, out_gat_0_3, init) -> out_sli_0_9 (name='_slice_9')\n", "+kr='out_sli_0_9': (2,) (dtype=int64 min=1 max=3)\n", "Onnx-Slice(out_sha_0, init, init_1, init) -> out_sli_0_b10 (name='_slice_b10')\n", "+kr='out_sli_0_b10': (1,) (dtype=int64 min=3 max=3)\n", "Onnx-Concat(init_b28, out_sli_0_b10, out_sli_0_9) -> out_con_0_2 (name='_concat_2')\n", "+kr='out_con_0_2': (4,) (dtype=int64 min=1 max=3)\n", "Onnx-Reshape(out_sli_0_8, out_con_0_2) -> y (name='_reshape_1')\n", "+kr='y': (2, 3, 1, 3) (dtype=float32 min=-2.188795566558838 max=3.3646905422210693)\n"]}, {"data": {"text/plain": ["{'y': array([[[[-8.3439898e-01, 6.9026375e-01, 
3.2907667e+00]],\n", " \n", " [[ 3.3646905e+00, -2.9031307e-01, -2.0941215e+00]],\n", " \n", " [[ 2.1246734e+00, 5.1293659e-01, -2.1887956e+00]]],\n", " \n", " \n", " [[[ 0.0000000e+00, -2.0055625e+00, 8.1667386e-16]],\n", " \n", " [[ 0.0000000e+00, 2.0540795e+00, -8.0671079e-16]],\n", " \n", " [[ 0.0000000e+00, -3.2617974e-01, -5.5504507e-16]]]],\n", " dtype=float32)}"]}, "execution_count": 38, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlprodict.onnxrt import OnnxInference\n", "\n", "# Run the selected ONNX graph with the Python runtime; verbose=1 with fLOG=print\n", "# prints each initializer, each executed node and its intermediate result.\n", "x = numpy.random.randn(3, 1, 4).astype(numpy.float32)\n", "onnx_model = onnx_rfft_2d_any.signed_compiled[key].compiled.onnx_\n", "oinf = OnnxInference(onnx_model)\n", "oinf.run({'x': x}, verbose=1, fLOG=print)"]}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5"}}, "nbformat": 4, "nbformat_minor": 5}