{"cells": [{"cell_type": "markdown", "id": "3231ec55", "metadata": {}, "source": ["# Infer operator computation cost\n", "\n", "This notebook explores a way to predict the cost of the Transpose operator based on some features."]}, {"cell_type": "code", "execution_count": 1, "id": "c57ea015", "metadata": {}, "outputs": [{"data": {"text/html": ["
run previous cell, wait for 2 seconds
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "346ccb45", "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 3, "id": "f86d2b4e", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "markdown", "id": "d632066f", "metadata": {}, "source": ["## ONNX graph and measures"]}, {"cell_type": "code", "execution_count": 4, "id": "db95c3c3", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "from skl2onnx.common.data_types import FloatTensorType\n", "from skl2onnx.algebra.onnx_ops import OnnxTranspose\n", "\n", "\n", "def create_onnx_graph(perm=(0, 1, 2, 3), target_opset=14):\n", " tr = OnnxTranspose('X', perm=perm, output_names=['Y'], op_version=target_opset)\n", " return tr.to_onnx({'X': FloatTensorType([None] * len(perm))})\n", "\n", "\n", "onx = create_onnx_graph()\n", "\n", "%onnxview onx"]}, {"cell_type": "code", "execution_count": 5, "id": "5c7b91e1", "metadata": {}, "outputs": [{"data": {"text/plain": ["(6, 5, 8, 7)"]}, "execution_count": 6, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlprodict.onnxrt import OnnxInference\n", "\n", "onx = create_onnx_graph(perm=(1, 0, 3, 2))\n", "oinf = OnnxInference(onx)\n", "inputs = {'X': numpy.full((5, 6, 7, 8), 1, dtype=numpy.float32)}\n", "res = oinf.run(inputs)['Y']\n", "res.shape"]}, {"cell_type": "code", "execution_count": 6, "id": "161e6ca6", "metadata": {}, "outputs": [{"data": {"text/plain": ["(6, 5, 8, 7)"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnxruntime import InferenceSession\n", "sess = InferenceSession(onx.SerializeToString())\n", "res = sess.run(None, inputs)[0]\n", "res.shape"]}, {"cell_type": "code", "execution_count": 7, "id": "cd35aa1d", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'average': 0.0024677738666666646,\n", " 'deviation': 0.00022911153911864325,\n", " 'min_exec': 0.0022292380000000023,\n", " 'max_exec': 0.003265080000000005,\n", " 'repeat': 30,\n", " 'number': 50,\n", " 'context_size': 232}"]}, "execution_count": 8, "metadata": {}, "output_type": "execute_result"}], "source": ["from cpyquickhelper.numbers.speed_measure import measure_time\n", "\n", "def measure_time_onnx(sess, X, number=50, repeat=30):\n", " inputs = {'X': X}\n", " return 
measure_time(lambda: sess.run(None, inputs), context=dict(sess=sess, inputs=inputs),\n", " div_by_number=True, number=number, repeat=repeat)\n", "\n", "X = numpy.random.random((3, 224, 224, 4)).astype(numpy.float32)\n", "measure_time_onnx(sess, X)"]}, {"cell_type": "markdown", "id": "811e19bf", "metadata": {}, "source": ["## Simulation to build a database"]}, {"cell_type": "markdown", "id": "21ab13cf", "metadata": {}, "source": ["### Many dimensions, many permutations"]}, {"cell_type": "code", "execution_count": 8, "id": "c4c1f75a", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:04<00:00, 5.73it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
30.0000440.0000060.0000390.0000573050232(0, 2, 3, 1)(12, 13, 15, 18)0.7503164
10.0000480.0000030.0000450.0000583050232(0, 1, 3, 2)(12, 13, 15, 18)0.8208214
180.0000490.0000030.0000450.0000623050232(3, 0, 1, 2)(12, 13, 15, 18)0.8230704
90.0000490.0000010.0000480.0000533050232(1, 2, 3, 0)(12, 13, 15, 18)0.8306044
120.0000510.0000040.0000390.0000623050232(2, 0, 1, 3)(12, 13, 15, 18)0.8619944
40.0000520.0000050.0000470.0000733050232(0, 3, 1, 2)(12, 13, 15, 18)0.8897534
80.0000540.0000060.0000440.0000673050232(1, 2, 0, 3)(12, 13, 15, 18)0.9094774
20.0000540.0000070.0000490.0000813050232(0, 2, 1, 3)(12, 13, 15, 18)0.9223544
140.0000570.0000060.0000460.0000643050232(2, 1, 0, 3)(12, 13, 15, 18)0.9721984
00.0000590.0000190.0000340.0000933050232(0, 1, 2, 3)(12, 13, 15, 18)1.0000004
60.0000920.0000190.0000530.0001393050232(1, 0, 2, 3)(12, 13, 15, 18)1.5579034
110.0001360.0000200.0001190.0001863050232(1, 3, 2, 0)(12, 13, 15, 18)2.3015564
130.0001380.0000230.0001210.0001813050232(2, 0, 3, 1)(12, 13, 15, 18)2.3368264
100.0001380.0000180.0001180.0001763050232(1, 3, 0, 2)(12, 13, 15, 18)2.3461184
160.0001400.0000150.0001240.0001933050232(2, 3, 0, 1)(12, 13, 15, 18)2.3791684
150.0001440.0000190.0001190.0001963050232(2, 1, 3, 0)(12, 13, 15, 18)2.4433924
170.0001450.0000220.0001230.0001993050232(2, 3, 1, 0)(12, 13, 15, 18)2.4550984
230.0001450.0000170.0001250.0001963050232(3, 2, 1, 0)(12, 13, 15, 18)2.4564314
200.0001460.0000150.0001280.0001843050232(3, 1, 0, 2)(12, 13, 15, 18)2.4732504
220.0001500.0000170.0001270.0001703050232(3, 2, 0, 1)(12, 13, 15, 18)2.5398174
190.0001580.0000210.0001270.0001923050232(3, 0, 2, 1)(12, 13, 15, 18)2.6848764
210.0001640.0000450.0001240.0002313050232(3, 1, 2, 0)(12, 13, 15, 18)2.7781934
70.0002140.0000600.0001360.0002953050232(1, 0, 3, 2)(12, 13, 15, 18)3.6272404
50.0002150.0000710.0001430.0003403050232(0, 3, 2, 1)(12, 13, 15, 18)3.6401324
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "3 0.000044 0.000006 0.000039 0.000057 30 50 232 \n", "1 0.000048 0.000003 0.000045 0.000058 30 50 232 \n", "18 0.000049 0.000003 0.000045 0.000062 30 50 232 \n", "9 0.000049 0.000001 0.000048 0.000053 30 50 232 \n", "12 0.000051 0.000004 0.000039 0.000062 30 50 232 \n", "4 0.000052 0.000005 0.000047 0.000073 30 50 232 \n", "8 0.000054 0.000006 0.000044 0.000067 30 50 232 \n", "2 0.000054 0.000007 0.000049 0.000081 30 50 232 \n", "14 0.000057 0.000006 0.000046 0.000064 30 50 232 \n", "0 0.000059 0.000019 0.000034 0.000093 30 50 232 \n", "6 0.000092 0.000019 0.000053 0.000139 30 50 232 \n", "11 0.000136 0.000020 0.000119 0.000186 30 50 232 \n", "13 0.000138 0.000023 0.000121 0.000181 30 50 232 \n", "10 0.000138 0.000018 0.000118 0.000176 30 50 232 \n", "16 0.000140 0.000015 0.000124 0.000193 30 50 232 \n", "15 0.000144 0.000019 0.000119 0.000196 30 50 232 \n", "17 0.000145 0.000022 0.000123 0.000199 30 50 232 \n", "23 0.000145 0.000017 0.000125 0.000196 30 50 232 \n", "20 0.000146 0.000015 0.000128 0.000184 30 50 232 \n", "22 0.000150 0.000017 0.000127 0.000170 30 50 232 \n", "19 0.000158 0.000021 0.000127 0.000192 30 50 232 \n", "21 0.000164 0.000045 0.000124 0.000231 30 50 232 \n", "7 0.000214 0.000060 0.000136 0.000295 30 50 232 \n", "5 0.000215 0.000071 0.000143 0.000340 30 50 232 \n", "\n", " perm shape ratio dim \n", "3 (0, 2, 3, 1) (12, 13, 15, 18) 0.750316 4 \n", "1 (0, 1, 3, 2) (12, 13, 15, 18) 0.820821 4 \n", "18 (3, 0, 1, 2) (12, 13, 15, 18) 0.823070 4 \n", "9 (1, 2, 3, 0) (12, 13, 15, 18) 0.830604 4 \n", "12 (2, 0, 1, 3) (12, 13, 15, 18) 0.861994 4 \n", "4 (0, 3, 1, 2) (12, 13, 15, 18) 0.889753 4 \n", "8 (1, 2, 0, 3) (12, 13, 15, 18) 0.909477 4 \n", "2 (0, 2, 1, 3) (12, 13, 15, 18) 0.922354 4 \n", "14 (2, 1, 0, 3) (12, 13, 15, 18) 0.972198 4 \n", "0 (0, 1, 2, 3) (12, 13, 15, 18) 1.000000 4 \n", "6 (1, 0, 2, 3) (12, 13, 15, 18) 1.557903 4 \n", "11 (1, 3, 2, 0) (12, 
13, 15, 18) 2.301556 4 \n", "13 (2, 0, 3, 1) (12, 13, 15, 18) 2.336826 4 \n", "10 (1, 3, 0, 2) (12, 13, 15, 18) 2.346118 4 \n", "16 (2, 3, 0, 1) (12, 13, 15, 18) 2.379168 4 \n", "15 (2, 1, 3, 0) (12, 13, 15, 18) 2.443392 4 \n", "17 (2, 3, 1, 0) (12, 13, 15, 18) 2.455098 4 \n", "23 (3, 2, 1, 0) (12, 13, 15, 18) 2.456431 4 \n", "20 (3, 1, 0, 2) (12, 13, 15, 18) 2.473250 4 \n", "22 (3, 2, 0, 1) (12, 13, 15, 18) 2.539817 4 \n", "19 (3, 0, 2, 1) (12, 13, 15, 18) 2.684876 4 \n", "21 (3, 1, 2, 0) (12, 13, 15, 18) 2.778193 4 \n", "7 (1, 0, 3, 2) (12, 13, 15, 18) 3.627240 4 \n", "5 (0, 3, 2, 1) (12, 13, 15, 18) 3.640132 4 "]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["from itertools import permutations\n", "from tqdm import tqdm\n", "from pandas import DataFrame\n", "\n", "\n", "def process_shape(shape, rnd=False, number=50, repeat=30, bar=True):\n", " X = numpy.random.random(shape).astype(numpy.float32)\n", " obs = []\n", " perms = list(permutations(list(range(len(X.shape)))))\n", " baseline = None\n", " itergen = perms if (rnd or not bar) else tqdm(perms)\n", " for perm in itergen:\n", " if baseline is not None and rnd:\n", " if random.randint(0, 4) != 0:\n", " continue\n", " onx = create_onnx_graph(perm=perm)\n", " sess = InferenceSession(onx.SerializeToString())\n", " res = measure_time_onnx(sess, X, number=number, repeat=repeat)\n", " res['perm'] = perm\n", " res['shape'] = shape\n", " if baseline is None:\n", " baseline = res\n", " res[\"ratio\"] = res[\"average\"] / baseline[\"average\"]\n", " res['dim'] = len(shape)\n", " obs.append(res)\n", " return DataFrame(obs).sort_values('average')\n", "\n", "dfs = []\n", "df = process_shape((12, 13, 15, 18))\n", "dfs.append(df)\n", "df"]}, {"cell_type": "code", "execution_count": 9, "id": "daf805ba", "metadata": {"scrolled": false}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6/6 [00:01<00:00, 
4.70it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
30.0000730.0000090.0000620.0000943050232(1, 2, 0)(43, 44, 45)0.9855133
00.0000740.0000090.0000650.0001093050232(0, 1, 2)(43, 44, 45)1.0000003
10.0000770.0000080.0000690.0001013050232(0, 2, 1)(43, 44, 45)1.0327593
40.0000970.0000040.0000830.0001103050232(2, 0, 1)(43, 44, 45)1.3009153
20.0001130.0000290.0000610.0001413050232(1, 0, 2)(43, 44, 45)1.5157113
50.0003750.0001210.0002920.0007503050232(2, 1, 0)(43, 44, 45)5.0543013
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "3 0.000073 0.000009 0.000062 0.000094 30 50 232 \n", "0 0.000074 0.000009 0.000065 0.000109 30 50 232 \n", "1 0.000077 0.000008 0.000069 0.000101 30 50 232 \n", "4 0.000097 0.000004 0.000083 0.000110 30 50 232 \n", "2 0.000113 0.000029 0.000061 0.000141 30 50 232 \n", "5 0.000375 0.000121 0.000292 0.000750 30 50 232 \n", "\n", " perm shape ratio dim \n", "3 (1, 2, 0) (43, 44, 45) 0.985513 3 \n", "0 (0, 1, 2) (43, 44, 45) 1.000000 3 \n", "1 (0, 2, 1) (43, 44, 45) 1.032759 3 \n", "4 (2, 0, 1) (43, 44, 45) 1.300915 3 \n", "2 (1, 0, 2) (43, 44, 45) 1.515711 3 \n", "5 (2, 1, 0) (43, 44, 45) 5.054301 3 "]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["df = process_shape((43, 44, 45))\n", "dfs.append(df)\n", "df"]}, {"cell_type": "code", "execution_count": 10, "id": "0abe8e18", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6/6 [00:01<00:00, 3.05it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
20.0001000.0000090.0000900.0001253050232(1, 0, 2)(3, 244, 244)0.9552033
00.0001050.0000160.0000780.0001383050232(0, 1, 2)(3, 244, 244)1.0000003
10.0001230.0000130.0001080.0001613050232(0, 2, 1)(3, 244, 244)1.1788273
40.0001240.0000170.0001080.0001713050232(2, 0, 1)(3, 244, 244)1.1856663
30.0001510.0000160.0001360.0001973050232(1, 2, 0)(3, 244, 244)1.4384463
50.0006720.0000830.0006260.0010303050232(2, 1, 0)(3, 244, 244)6.4181953
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "2 0.000100 0.000009 0.000090 0.000125 30 50 232 \n", "0 0.000105 0.000016 0.000078 0.000138 30 50 232 \n", "1 0.000123 0.000013 0.000108 0.000161 30 50 232 \n", "4 0.000124 0.000017 0.000108 0.000171 30 50 232 \n", "3 0.000151 0.000016 0.000136 0.000197 30 50 232 \n", "5 0.000672 0.000083 0.000626 0.001030 30 50 232 \n", "\n", " perm shape ratio dim \n", "2 (1, 0, 2) (3, 244, 244) 0.955203 3 \n", "0 (0, 1, 2) (3, 244, 244) 1.000000 3 \n", "1 (0, 2, 1) (3, 244, 244) 1.178827 3 \n", "4 (2, 0, 1) (3, 244, 244) 1.185666 3 \n", "3 (1, 2, 0) (3, 244, 244) 1.438446 3 \n", "5 (2, 1, 0) (3, 244, 244) 6.418195 3 "]}, "execution_count": 11, "metadata": {}, "output_type": "execute_result"}], "source": ["df = process_shape((3, 244, 244))\n", "dfs.append(df)\n", "df"]}, {"cell_type": "code", "execution_count": 11, "id": "d91a9d22", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:19<00:00, 1.26it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
40.0000920.0000080.0000780.0001073050232(0, 3, 1, 2)(3, 244, 244, 1)0.8599034
00.0001070.0000180.0000840.0001573050232(0, 1, 2, 3)(3, 244, 244, 1)1.0000004
60.0001240.0000680.0000880.0003233050232(1, 0, 2, 3)(3, 244, 244, 1)1.1624564
120.0001260.0000170.0001070.0001853050232(2, 0, 1, 3)(3, 244, 244, 1)1.1809964
30.0001300.0000090.0001200.0001633050232(0, 2, 3, 1)(3, 244, 244, 1)1.2100774
180.0001370.0000470.0000900.0002503050232(3, 0, 1, 2)(3, 244, 244, 1)1.2766424
10.0001470.0000170.0001060.0001753050232(0, 1, 3, 2)(3, 244, 244, 1)1.3699784
80.0001850.0000170.0001640.0002463050232(1, 2, 0, 3)(3, 244, 244, 1)1.7253914
90.0001890.0000440.0001420.0002653050232(1, 2, 3, 0)(3, 244, 244, 1)1.7669054
20.0002010.0000540.0001210.0002893050232(0, 2, 1, 3)(3, 244, 244, 1)1.8788024
70.0005220.0000610.0004570.0007333050232(1, 0, 3, 2)(3, 244, 244, 1)4.8740094
100.0005330.0001570.0004560.0011283050232(1, 3, 0, 2)(3, 244, 244, 1)4.9739164
130.0006400.0001890.0004770.0012893050232(2, 0, 3, 1)(3, 244, 244, 1)5.9807964
160.0006600.0001060.0005030.0008603050232(2, 3, 0, 1)(3, 244, 244, 1)6.1677034
50.0006920.0001360.0005290.0010213050232(0, 3, 2, 1)(3, 244, 244, 1)6.4607594
190.0007490.0002060.0005080.0013243050232(3, 0, 2, 1)(3, 244, 244, 1)6.9963624
140.0007540.0001050.0006330.0009943050232(2, 1, 0, 3)(3, 244, 244, 1)7.0410074
110.0007910.0002640.0005610.0013863050232(1, 3, 2, 0)(3, 244, 244, 1)7.3894314
150.0008180.0002780.0006250.0015223050232(2, 1, 3, 0)(3, 244, 244, 1)7.6346464
170.0008930.0002120.0006460.0014773050232(2, 3, 1, 0)(3, 244, 244, 1)8.3399264
210.0009440.0002930.0005810.0016263050232(3, 1, 2, 0)(3, 244, 244, 1)8.8147854
200.0009760.0003470.0005840.0017423050232(3, 1, 0, 2)(3, 244, 244, 1)9.1122434
220.0010110.0003370.0005440.0018103050232(3, 2, 0, 1)(3, 244, 244, 1)9.4374034
230.0011280.0003220.0006290.0017373050232(3, 2, 1, 0)(3, 244, 244, 1)10.5301824
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "4 0.000092 0.000008 0.000078 0.000107 30 50 232 \n", "0 0.000107 0.000018 0.000084 0.000157 30 50 232 \n", "6 0.000124 0.000068 0.000088 0.000323 30 50 232 \n", "12 0.000126 0.000017 0.000107 0.000185 30 50 232 \n", "3 0.000130 0.000009 0.000120 0.000163 30 50 232 \n", "18 0.000137 0.000047 0.000090 0.000250 30 50 232 \n", "1 0.000147 0.000017 0.000106 0.000175 30 50 232 \n", "8 0.000185 0.000017 0.000164 0.000246 30 50 232 \n", "9 0.000189 0.000044 0.000142 0.000265 30 50 232 \n", "2 0.000201 0.000054 0.000121 0.000289 30 50 232 \n", "7 0.000522 0.000061 0.000457 0.000733 30 50 232 \n", "10 0.000533 0.000157 0.000456 0.001128 30 50 232 \n", "13 0.000640 0.000189 0.000477 0.001289 30 50 232 \n", "16 0.000660 0.000106 0.000503 0.000860 30 50 232 \n", "5 0.000692 0.000136 0.000529 0.001021 30 50 232 \n", "19 0.000749 0.000206 0.000508 0.001324 30 50 232 \n", "14 0.000754 0.000105 0.000633 0.000994 30 50 232 \n", "11 0.000791 0.000264 0.000561 0.001386 30 50 232 \n", "15 0.000818 0.000278 0.000625 0.001522 30 50 232 \n", "17 0.000893 0.000212 0.000646 0.001477 30 50 232 \n", "21 0.000944 0.000293 0.000581 0.001626 30 50 232 \n", "20 0.000976 0.000347 0.000584 0.001742 30 50 232 \n", "22 0.001011 0.000337 0.000544 0.001810 30 50 232 \n", "23 0.001128 0.000322 0.000629 0.001737 30 50 232 \n", "\n", " perm shape ratio dim \n", "4 (0, 3, 1, 2) (3, 244, 244, 1) 0.859903 4 \n", "0 (0, 1, 2, 3) (3, 244, 244, 1) 1.000000 4 \n", "6 (1, 0, 2, 3) (3, 244, 244, 1) 1.162456 4 \n", "12 (2, 0, 1, 3) (3, 244, 244, 1) 1.180996 4 \n", "3 (0, 2, 3, 1) (3, 244, 244, 1) 1.210077 4 \n", "18 (3, 0, 1, 2) (3, 244, 244, 1) 1.276642 4 \n", "1 (0, 1, 3, 2) (3, 244, 244, 1) 1.369978 4 \n", "8 (1, 2, 0, 3) (3, 244, 244, 1) 1.725391 4 \n", "9 (1, 2, 3, 0) (3, 244, 244, 1) 1.766905 4 \n", "2 (0, 2, 1, 3) (3, 244, 244, 1) 1.878802 4 \n", "7 (1, 0, 3, 2) (3, 244, 244, 1) 4.874009 4 \n", "10 (1, 3, 0, 2) (3, 
244, 244, 1) 4.973916 4 \n", "13 (2, 0, 3, 1) (3, 244, 244, 1) 5.980796 4 \n", "16 (2, 3, 0, 1) (3, 244, 244, 1) 6.167703 4 \n", "5 (0, 3, 2, 1) (3, 244, 244, 1) 6.460759 4 \n", "19 (3, 0, 2, 1) (3, 244, 244, 1) 6.996362 4 \n", "14 (2, 1, 0, 3) (3, 244, 244, 1) 7.041007 4 \n", "11 (1, 3, 2, 0) (3, 244, 244, 1) 7.389431 4 \n", "15 (2, 1, 3, 0) (3, 244, 244, 1) 7.634646 4 \n", "17 (2, 3, 1, 0) (3, 244, 244, 1) 8.339926 4 \n", "21 (3, 1, 2, 0) (3, 244, 244, 1) 8.814785 4 \n", "20 (3, 1, 0, 2) (3, 244, 244, 1) 9.112243 4 \n", "22 (3, 2, 0, 1) (3, 244, 244, 1) 9.437403 4 \n", "23 (3, 2, 1, 0) (3, 244, 244, 1) 10.530182 4 "]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["df = process_shape((3, 244, 244, 1))\n", "dfs.append(df)\n", "df"]}, {"cell_type": "code", "execution_count": 12, "id": "2f36031e", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:22<00:00, 1.07it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
80.0000920.0000140.0000780.0001323050232(1, 2, 0, 3)(1, 244, 244, 3)0.7530094
60.0000980.0000130.0000830.0001423050232(1, 0, 2, 3)(1, 244, 244, 3)0.8028084
90.0001070.0000180.0000750.0001373050232(1, 2, 3, 0)(1, 244, 244, 3)0.8739324
30.0001150.0000150.0000920.0001473050232(0, 2, 3, 1)(1, 244, 244, 3)0.9406064
00.0001220.0000280.0000940.0002013050232(0, 1, 2, 3)(1, 244, 244, 3)1.0000004
10.0001940.0000360.0001600.0003113050232(0, 1, 3, 2)(1, 244, 244, 3)1.5854794
40.0001950.0000190.0001630.0002583050232(0, 3, 1, 2)(1, 244, 244, 3)1.5987704
180.0002350.0000580.0001720.0003453050232(3, 0, 1, 2)(1, 244, 244, 3)1.9236544
20.0004080.0001560.0002290.0007183050232(0, 2, 1, 3)(1, 244, 244, 3)3.3454064
120.0005130.0002150.0003000.0014303050232(2, 0, 1, 3)(1, 244, 244, 3)4.2054774
100.0005580.0001310.0004580.0010233050232(1, 3, 0, 2)(1, 244, 244, 3)4.5726584
70.0006040.0001880.0004710.0010653050232(1, 0, 3, 2)(1, 244, 244, 3)4.9479374
140.0006200.0001420.0004100.0011213050232(2, 1, 0, 3)(1, 244, 244, 3)5.0783874
230.0006790.0000970.0005900.0009283050232(3, 2, 1, 0)(1, 244, 244, 3)5.5618884
220.0007100.0001610.0006200.0013903050232(3, 2, 0, 1)(1, 244, 244, 3)5.8180894
170.0007370.0002400.0004930.0011743050232(2, 3, 1, 0)(1, 244, 244, 3)6.0401894
110.0008240.0002880.0005150.0018793050232(1, 3, 2, 0)(1, 244, 244, 3)6.7526634
210.0009130.0002160.0006130.0014103050232(3, 1, 2, 0)(1, 244, 244, 3)7.4763784
200.0009180.0003280.0005720.0020793050232(3, 1, 0, 2)(1, 244, 244, 3)7.5214814
160.0010570.0006090.0005020.0027023050232(2, 3, 0, 1)(1, 244, 244, 3)8.6570764
50.0010610.0006120.0005390.0037903050232(0, 3, 2, 1)(1, 244, 244, 3)8.6938704
190.0012120.0004170.0007190.0025613050232(3, 0, 2, 1)(1, 244, 244, 3)9.9293084
150.0013110.0005050.0008560.0033773050232(2, 1, 3, 0)(1, 244, 244, 3)10.7393984
130.0014330.0005050.0007210.0023353050232(2, 0, 3, 1)(1, 244, 244, 3)11.7407724
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "8 0.000092 0.000014 0.000078 0.000132 30 50 232 \n", "6 0.000098 0.000013 0.000083 0.000142 30 50 232 \n", "9 0.000107 0.000018 0.000075 0.000137 30 50 232 \n", "3 0.000115 0.000015 0.000092 0.000147 30 50 232 \n", "0 0.000122 0.000028 0.000094 0.000201 30 50 232 \n", "1 0.000194 0.000036 0.000160 0.000311 30 50 232 \n", "4 0.000195 0.000019 0.000163 0.000258 30 50 232 \n", "18 0.000235 0.000058 0.000172 0.000345 30 50 232 \n", "2 0.000408 0.000156 0.000229 0.000718 30 50 232 \n", "12 0.000513 0.000215 0.000300 0.001430 30 50 232 \n", "10 0.000558 0.000131 0.000458 0.001023 30 50 232 \n", "7 0.000604 0.000188 0.000471 0.001065 30 50 232 \n", "14 0.000620 0.000142 0.000410 0.001121 30 50 232 \n", "23 0.000679 0.000097 0.000590 0.000928 30 50 232 \n", "22 0.000710 0.000161 0.000620 0.001390 30 50 232 \n", "17 0.000737 0.000240 0.000493 0.001174 30 50 232 \n", "11 0.000824 0.000288 0.000515 0.001879 30 50 232 \n", "21 0.000913 0.000216 0.000613 0.001410 30 50 232 \n", "20 0.000918 0.000328 0.000572 0.002079 30 50 232 \n", "16 0.001057 0.000609 0.000502 0.002702 30 50 232 \n", "5 0.001061 0.000612 0.000539 0.003790 30 50 232 \n", "19 0.001212 0.000417 0.000719 0.002561 30 50 232 \n", "15 0.001311 0.000505 0.000856 0.003377 30 50 232 \n", "13 0.001433 0.000505 0.000721 0.002335 30 50 232 \n", "\n", " perm shape ratio dim \n", "8 (1, 2, 0, 3) (1, 244, 244, 3) 0.753009 4 \n", "6 (1, 0, 2, 3) (1, 244, 244, 3) 0.802808 4 \n", "9 (1, 2, 3, 0) (1, 244, 244, 3) 0.873932 4 \n", "3 (0, 2, 3, 1) (1, 244, 244, 3) 0.940606 4 \n", "0 (0, 1, 2, 3) (1, 244, 244, 3) 1.000000 4 \n", "1 (0, 1, 3, 2) (1, 244, 244, 3) 1.585479 4 \n", "4 (0, 3, 1, 2) (1, 244, 244, 3) 1.598770 4 \n", "18 (3, 0, 1, 2) (1, 244, 244, 3) 1.923654 4 \n", "2 (0, 2, 1, 3) (1, 244, 244, 3) 3.345406 4 \n", "12 (2, 0, 1, 3) (1, 244, 244, 3) 4.205477 4 \n", "10 (1, 3, 0, 2) (1, 244, 244, 3) 4.572658 4 \n", "7 (1, 0, 3, 2) (1, 
244, 244, 3) 4.947937 4 \n", "14 (2, 1, 0, 3) (1, 244, 244, 3) 5.078387 4 \n", "23 (3, 2, 1, 0) (1, 244, 244, 3) 5.561888 4 \n", "22 (3, 2, 0, 1) (1, 244, 244, 3) 5.818089 4 \n", "17 (2, 3, 1, 0) (1, 244, 244, 3) 6.040189 4 \n", "11 (1, 3, 2, 0) (1, 244, 244, 3) 6.752663 4 \n", "21 (3, 1, 2, 0) (1, 244, 244, 3) 7.476378 4 \n", "20 (3, 1, 0, 2) (1, 244, 244, 3) 7.521481 4 \n", "16 (2, 3, 0, 1) (1, 244, 244, 3) 8.657076 4 \n", "5 (0, 3, 2, 1) (1, 244, 244, 3) 8.693870 4 \n", "19 (3, 0, 2, 1) (1, 244, 244, 3) 9.929308 4 \n", "15 (2, 1, 3, 0) (1, 244, 244, 3) 10.739398 4 \n", "13 (2, 0, 3, 1) (1, 244, 244, 3) 11.740772 4 "]}, "execution_count": 13, "metadata": {}, "output_type": "execute_result"}], "source": ["df = process_shape((1, 244, 244, 3))\n", "dfs.append(df)\n", "df"]}, {"cell_type": "code", "execution_count": 13, "id": "9661a0ae", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:14<00:00, 1.62it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
00.0010880.0000850.0009860.0012911515232(0, 1, 2, 3)(3, 244, 244, 3)1.0000004
40.0012270.0000880.0011520.0014741515232(0, 3, 1, 2)(3, 244, 244, 3)1.1281264
180.0012770.0001180.0010790.0014901515232(3, 0, 1, 2)(3, 244, 244, 3)1.1737214
60.0013110.0003200.0010070.0019251515232(1, 0, 2, 3)(3, 244, 244, 3)1.2051824
10.0014150.0003070.0012000.0024981515232(0, 1, 3, 2)(3, 244, 244, 3)1.3009014
30.0014260.0002210.0011910.0018631515232(0, 2, 3, 1)(3, 244, 244, 3)1.3113614
90.0015100.0004320.0011320.0024171515232(1, 2, 3, 0)(3, 244, 244, 3)1.3880684
80.0015520.0000300.0015000.0016021515232(1, 2, 0, 3)(3, 244, 244, 3)1.4271054
120.0017240.0001930.0014700.0021421515232(2, 0, 1, 3)(3, 244, 244, 3)1.5851554
20.0017900.0001910.0015660.0022381515232(0, 2, 1, 3)(3, 244, 244, 3)1.6457174
70.0025280.0001540.0023270.0029831515232(1, 0, 3, 2)(3, 244, 244, 3)2.3243844
190.0025710.0001860.0023830.0029221515232(3, 0, 2, 1)(3, 244, 244, 3)2.3634434
210.0025910.0002530.0024310.0034031515232(3, 1, 2, 0)(3, 244, 244, 3)2.3818604
220.0026980.0004120.0023460.0036891515232(3, 2, 0, 1)(3, 244, 244, 3)2.4803084
200.0028060.0007830.0021470.0042961515232(3, 1, 0, 2)(3, 244, 244, 3)2.5795174
160.0032120.0003040.0027730.0038511515232(2, 3, 0, 1)(3, 244, 244, 3)2.9530324
140.0032280.0007960.0020710.0047911515232(2, 1, 0, 3)(3, 244, 244, 3)2.9675234
110.0032570.0002870.0029120.0037391515232(1, 3, 2, 0)(3, 244, 244, 3)2.9940434
170.0035740.0004790.0030280.0050421515232(2, 3, 1, 0)(3, 244, 244, 3)3.2858424
100.0039420.0018600.0024460.0082411515232(1, 3, 0, 2)(3, 244, 244, 3)3.6241454
150.0042490.0012170.0031750.0080411515232(2, 1, 3, 0)(3, 244, 244, 3)3.9063614
50.0046850.0013430.0028270.0068681515232(0, 3, 2, 1)(3, 244, 244, 3)4.3070724
130.0055390.0021800.0029910.0096021515232(2, 0, 3, 1)(3, 244, 244, 3)5.0924224
230.0055750.0019300.0028760.0081571515232(3, 2, 1, 0)(3, 244, 244, 3)5.1255974
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "0 0.001088 0.000085 0.000986 0.001291 15 15 232 \n", "4 0.001227 0.000088 0.001152 0.001474 15 15 232 \n", "18 0.001277 0.000118 0.001079 0.001490 15 15 232 \n", "6 0.001311 0.000320 0.001007 0.001925 15 15 232 \n", "1 0.001415 0.000307 0.001200 0.002498 15 15 232 \n", "3 0.001426 0.000221 0.001191 0.001863 15 15 232 \n", "9 0.001510 0.000432 0.001132 0.002417 15 15 232 \n", "8 0.001552 0.000030 0.001500 0.001602 15 15 232 \n", "12 0.001724 0.000193 0.001470 0.002142 15 15 232 \n", "2 0.001790 0.000191 0.001566 0.002238 15 15 232 \n", "7 0.002528 0.000154 0.002327 0.002983 15 15 232 \n", "19 0.002571 0.000186 0.002383 0.002922 15 15 232 \n", "21 0.002591 0.000253 0.002431 0.003403 15 15 232 \n", "22 0.002698 0.000412 0.002346 0.003689 15 15 232 \n", "20 0.002806 0.000783 0.002147 0.004296 15 15 232 \n", "16 0.003212 0.000304 0.002773 0.003851 15 15 232 \n", "14 0.003228 0.000796 0.002071 0.004791 15 15 232 \n", "11 0.003257 0.000287 0.002912 0.003739 15 15 232 \n", "17 0.003574 0.000479 0.003028 0.005042 15 15 232 \n", "10 0.003942 0.001860 0.002446 0.008241 15 15 232 \n", "15 0.004249 0.001217 0.003175 0.008041 15 15 232 \n", "5 0.004685 0.001343 0.002827 0.006868 15 15 232 \n", "13 0.005539 0.002180 0.002991 0.009602 15 15 232 \n", "23 0.005575 0.001930 0.002876 0.008157 15 15 232 \n", "\n", " perm shape ratio dim \n", "0 (0, 1, 2, 3) (3, 244, 244, 3) 1.000000 4 \n", "4 (0, 3, 1, 2) (3, 244, 244, 3) 1.128126 4 \n", "18 (3, 0, 1, 2) (3, 244, 244, 3) 1.173721 4 \n", "6 (1, 0, 2, 3) (3, 244, 244, 3) 1.205182 4 \n", "1 (0, 1, 3, 2) (3, 244, 244, 3) 1.300901 4 \n", "3 (0, 2, 3, 1) (3, 244, 244, 3) 1.311361 4 \n", "9 (1, 2, 3, 0) (3, 244, 244, 3) 1.388068 4 \n", "8 (1, 2, 0, 3) (3, 244, 244, 3) 1.427105 4 \n", "12 (2, 0, 1, 3) (3, 244, 244, 3) 1.585155 4 \n", "2 (0, 2, 1, 3) (3, 244, 244, 3) 1.645717 4 \n", "7 (1, 0, 3, 2) (3, 244, 244, 3) 2.324384 4 \n", "19 (3, 0, 2, 1) (3, 
244, 244, 3) 2.363443 4 \n", "21 (3, 1, 2, 0) (3, 244, 244, 3) 2.381860 4 \n", "22 (3, 2, 0, 1) (3, 244, 244, 3) 2.480308 4 \n", "20 (3, 1, 0, 2) (3, 244, 244, 3) 2.579517 4 \n", "16 (2, 3, 0, 1) (3, 244, 244, 3) 2.953032 4 \n", "14 (2, 1, 0, 3) (3, 244, 244, 3) 2.967523 4 \n", "11 (1, 3, 2, 0) (3, 244, 244, 3) 2.994043 4 \n", "17 (2, 3, 1, 0) (3, 244, 244, 3) 3.285842 4 \n", "10 (1, 3, 0, 2) (3, 244, 244, 3) 3.624145 4 \n", "15 (2, 1, 3, 0) (3, 244, 244, 3) 3.906361 4 \n", "5 (0, 3, 2, 1) (3, 244, 244, 3) 4.307072 4 \n", "13 (2, 0, 3, 1) (3, 244, 244, 3) 5.092422 4 \n", "23 (3, 2, 1, 0) (3, 244, 244, 3) 5.125597 4 "]}, "execution_count": 14, "metadata": {}, "output_type": "execute_result"}], "source": ["df = process_shape((3, 244, 244, 3), number=15, repeat=15)\n", "dfs.append(df)\n", "df"]}, {"cell_type": "code", "execution_count": 14, "id": "d0427aa3", "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 24/24 [00:34<00:00, 1.43s/it]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
10.0022490.0001440.0020670.0026271515232(0, 1, 3, 2)(3, 244, 244, 6)0.6069614
30.0027110.0001710.0024580.0029951515232(0, 2, 3, 1)(3, 244, 244, 6)0.7317954
120.0027730.0006830.0022600.0041031515232(2, 0, 1, 3)(3, 244, 244, 6)0.7485784
40.0029530.0006770.0021870.0041321515232(0, 3, 1, 2)(3, 244, 244, 6)0.7970624
20.0032320.0009630.0023030.0050881515232(0, 2, 1, 3)(3, 244, 244, 6)0.8724274
60.0033630.0003720.0028830.0040251515232(1, 0, 2, 3)(3, 244, 244, 6)0.9078344
80.0033970.0002370.0028860.0038461515232(1, 2, 0, 3)(3, 244, 244, 6)0.9170114
90.0036530.0008740.0025670.0052441515232(1, 2, 3, 0)(3, 244, 244, 6)0.9860714
140.0036970.0001860.0034950.0041501515232(2, 1, 0, 3)(3, 244, 244, 6)0.9979014
00.0037050.0007970.0021110.0051641515232(0, 1, 2, 3)(3, 244, 244, 6)1.0000004
180.0037800.0008820.0027010.0054021515232(3, 0, 1, 2)(3, 244, 244, 6)1.0204324
100.0049380.0003670.0045320.0058441515232(1, 3, 0, 2)(3, 244, 244, 6)1.3330614
70.0059180.0010850.0045980.0083121515232(1, 0, 3, 2)(3, 244, 244, 6)1.5973574
130.0061060.0005560.0056190.0073051515232(2, 0, 3, 1)(3, 244, 244, 6)1.6483254
110.0067220.0018070.0050670.0112451515232(1, 3, 2, 0)(3, 244, 244, 6)1.8145524
200.0070710.0009820.0054540.0085591515232(3, 1, 0, 2)(3, 244, 244, 6)1.9086674
210.0074410.0017320.0061990.0121691515232(3, 1, 2, 0)(3, 244, 244, 6)2.0086354
150.0078150.0017570.0059320.0107791515232(2, 1, 3, 0)(3, 244, 244, 6)2.1094894
160.0085460.0013840.0058780.0106141515232(2, 3, 0, 1)(3, 244, 244, 6)2.3069514
50.0103390.0027890.0058780.0183011515232(0, 3, 2, 1)(3, 244, 244, 6)2.7908234
170.0106770.0014570.0085040.0140701515232(2, 3, 1, 0)(3, 244, 244, 6)2.8821914
230.0124210.0030520.0078180.0181061515232(3, 2, 1, 0)(3, 244, 244, 6)3.3527704
220.0134320.0044960.0065360.0212501515232(3, 2, 0, 1)(3, 244, 244, 6)3.6256804
190.0145790.0040260.0071440.0207391515232(3, 0, 2, 1)(3, 244, 244, 6)3.9354834
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "1 0.002249 0.000144 0.002067 0.002627 15 15 232 \n", "3 0.002711 0.000171 0.002458 0.002995 15 15 232 \n", "12 0.002773 0.000683 0.002260 0.004103 15 15 232 \n", "4 0.002953 0.000677 0.002187 0.004132 15 15 232 \n", "2 0.003232 0.000963 0.002303 0.005088 15 15 232 \n", "6 0.003363 0.000372 0.002883 0.004025 15 15 232 \n", "8 0.003397 0.000237 0.002886 0.003846 15 15 232 \n", "9 0.003653 0.000874 0.002567 0.005244 15 15 232 \n", "14 0.003697 0.000186 0.003495 0.004150 15 15 232 \n", "0 0.003705 0.000797 0.002111 0.005164 15 15 232 \n", "18 0.003780 0.000882 0.002701 0.005402 15 15 232 \n", "10 0.004938 0.000367 0.004532 0.005844 15 15 232 \n", "7 0.005918 0.001085 0.004598 0.008312 15 15 232 \n", "13 0.006106 0.000556 0.005619 0.007305 15 15 232 \n", "11 0.006722 0.001807 0.005067 0.011245 15 15 232 \n", "20 0.007071 0.000982 0.005454 0.008559 15 15 232 \n", "21 0.007441 0.001732 0.006199 0.012169 15 15 232 \n", "15 0.007815 0.001757 0.005932 0.010779 15 15 232 \n", "16 0.008546 0.001384 0.005878 0.010614 15 15 232 \n", "5 0.010339 0.002789 0.005878 0.018301 15 15 232 \n", "17 0.010677 0.001457 0.008504 0.014070 15 15 232 \n", "23 0.012421 0.003052 0.007818 0.018106 15 15 232 \n", "22 0.013432 0.004496 0.006536 0.021250 15 15 232 \n", "19 0.014579 0.004026 0.007144 0.020739 15 15 232 \n", "\n", " perm shape ratio dim \n", "1 (0, 1, 3, 2) (3, 244, 244, 6) 0.606961 4 \n", "3 (0, 2, 3, 1) (3, 244, 244, 6) 0.731795 4 \n", "12 (2, 0, 1, 3) (3, 244, 244, 6) 0.748578 4 \n", "4 (0, 3, 1, 2) (3, 244, 244, 6) 0.797062 4 \n", "2 (0, 2, 1, 3) (3, 244, 244, 6) 0.872427 4 \n", "6 (1, 0, 2, 3) (3, 244, 244, 6) 0.907834 4 \n", "8 (1, 2, 0, 3) (3, 244, 244, 6) 0.917011 4 \n", "9 (1, 2, 3, 0) (3, 244, 244, 6) 0.986071 4 \n", "14 (2, 1, 0, 3) (3, 244, 244, 6) 0.997901 4 \n", "0 (0, 1, 2, 3) (3, 244, 244, 6) 1.000000 4 \n", "18 (3, 0, 1, 2) (3, 244, 244, 6) 1.020432 4 \n", "10 (1, 3, 0, 2) (3, 
244, 244, 6) 1.333061 4 \n", "7 (1, 0, 3, 2) (3, 244, 244, 6) 1.597357 4 \n", "13 (2, 0, 3, 1) (3, 244, 244, 6) 1.648325 4 \n", "11 (1, 3, 2, 0) (3, 244, 244, 6) 1.814552 4 \n", "20 (3, 1, 0, 2) (3, 244, 244, 6) 1.908667 4 \n", "21 (3, 1, 2, 0) (3, 244, 244, 6) 2.008635 4 \n", "15 (2, 1, 3, 0) (3, 244, 244, 6) 2.109489 4 \n", "16 (2, 3, 0, 1) (3, 244, 244, 6) 2.306951 4 \n", "5 (0, 3, 2, 1) (3, 244, 244, 6) 2.790823 4 \n", "17 (2, 3, 1, 0) (3, 244, 244, 6) 2.882191 4 \n", "23 (3, 2, 1, 0) (3, 244, 244, 6) 3.352770 4 \n", "22 (3, 2, 0, 1) (3, 244, 244, 6) 3.625680 4 \n", "19 (3, 0, 2, 1) (3, 244, 244, 6) 3.935483 4 "]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["df = process_shape((3, 244, 244, 6), number=15, repeat=15)\n", "dfs.append(df)\n", "df"]}, {"cell_type": "markdown", "id": "308e3d6d", "metadata": {}, "source": ["### Random cases"]}, {"cell_type": "code", "execution_count": 15, "id": "90fdb3be", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/plain": ["7"]}, "execution_count": 16, "metadata": {}, "output_type": "execute_result"}], "source": ["import random\n", "\n", "if False: # comment out for more training data\n", " for i in tqdm(range(0, 30)):\n", " dim = random.randint(3, 5)\n", " total = 1e8\n", " while total > 1e6 or total < 0:\n", " if dim == 3:\n", " shape = [random.randint(3, 64), random.randint(3, 224), random.randint(3, 64)]\n", " elif dim == 4:\n", " shape = (\n", " [random.randint(3, 8)] + \n", " [random.randint(16, 224) for d in range(2)] +\n", " [random.randint(16, 64)])\n", " elif dim == 5:\n", " shape = (\n", " [random.randint(3, 8)] + \n", " [random.randint(16, 32) for d in range(3)] +\n", " [random.randint(16, 64)])\n", " else:\n", " raise NotImplementedError()\n", " ashape = numpy.array(shape, dtype=numpy.float64)\n", " total = numpy.prod(ashape)\n", "\n", " if total > 1000000:\n", " number, repeat = 2, 2\n", " elif total > 800000:\n", " number, repeat = 3, 3\n", " elif total > 
500000:\n", " number, repeat = 5, 5\n", " elif total > 200000:\n", " number, repeat = 7, 7\n", " else:\n", " number, repeat = 10, 10\n", "\n", " df = process_shape(tuple(shape), number=number, repeat=repeat, bar=False)\n", " dfs.append(df)\n", "\n", " for i in range(len(shape)):\n", " # same shape with dimension i collapsed to 1\n", " shape2 = shape.copy()\n", " shape2[i] = 1\n", " df = process_shape(tuple(shape2), number=number, repeat=repeat, bar=False)\n", " dfs.append(df)\n", " \n", "len(dfs)"]}, {"cell_type": "code", "execution_count": 16, "id": "771ad869", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execrepeatnumbercontext_sizepermshaperatiodim
1270.0103390.0027890.0058780.0183011515232(0, 3, 2, 1)(3, 244, 244, 6)2.7908234
1280.0106770.0014570.0085040.0140701515232(2, 3, 1, 0)(3, 244, 244, 6)2.8821914
1290.0124210.0030520.0078180.0181061515232(3, 2, 1, 0)(3, 244, 244, 6)3.3527704
1300.0134320.0044960.0065360.0212501515232(3, 2, 0, 1)(3, 244, 244, 6)3.6256804
1310.0145790.0040260.0071440.0207391515232(3, 0, 2, 1)(3, 244, 244, 6)3.9354834
\n", "
"], "text/plain": [" average deviation min_exec max_exec repeat number context_size \\\n", "127 0.010339 0.002789 0.005878 0.018301 15 15 232 \n", "128 0.010677 0.001457 0.008504 0.014070 15 15 232 \n", "129 0.012421 0.003052 0.007818 0.018106 15 15 232 \n", "130 0.013432 0.004496 0.006536 0.021250 15 15 232 \n", "131 0.014579 0.004026 0.007144 0.020739 15 15 232 \n", "\n", " perm shape ratio dim \n", "127 (0, 3, 2, 1) (3, 244, 244, 6) 2.790823 4 \n", "128 (2, 3, 1, 0) (3, 244, 244, 6) 2.882191 4 \n", "129 (3, 2, 1, 0) (3, 244, 244, 6) 3.352770 4 \n", "130 (3, 2, 0, 1) (3, 244, 244, 6) 3.625680 4 \n", "131 (3, 0, 2, 1) (3, 244, 244, 6) 3.935483 4 "]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["import pandas\n", "\n", "data = pandas.concat(dfs, axis=0).reset_index(drop=True)\n", "data.tail()"]}, {"cell_type": "code", "execution_count": 17, "id": "0290b6d7", "metadata": {}, "outputs": [{"data": {"text/plain": ["(132, 11)"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["data.shape"]}, {"cell_type": "code", "execution_count": 18, "id": "49c7afc2", "metadata": {"scrolled": false}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ratio
minmaxmeanmedian
dimshape
3(3, 244, 244)0.9552036.4181952.0293891.182247
(43, 44, 45)0.9855135.0543011.8148671.166837
4(1, 244, 244, 3)0.75300911.7407725.0233015.013162
(3, 244, 244, 1)0.85990310.5301824.8826805.477356
(3, 244, 244, 3)1.0000005.1255972.4812872.372651
(3, 244, 244, 6)0.6069613.9354831.7041691.465209
(12, 13, 15, 18)0.7503163.6401321.8666912.319191
\n", "
"], "text/plain": [" ratio \n", " min max mean median\n", "dim shape \n", "3 (3, 244, 244) 0.955203 6.418195 2.029389 1.182247\n", " (43, 44, 45) 0.985513 5.054301 1.814867 1.166837\n", "4 (1, 244, 244, 3) 0.753009 11.740772 5.023301 5.013162\n", " (3, 244, 244, 1) 0.859903 10.530182 4.882680 5.477356\n", " (3, 244, 244, 3) 1.000000 5.125597 2.481287 2.372651\n", " (3, 244, 244, 6) 0.606961 3.935483 1.704169 1.465209\n", " (12, 13, 15, 18) 0.750316 3.640132 1.866691 2.319191"]}, "execution_count": 19, "metadata": {}, "output_type": "execute_result"}], "source": ["data[['dim', 'shape', 'ratio']].groupby(['dim', 'shape']).agg({'ratio': ['min', 'max', 'mean', 'median']})"]}, {"cell_type": "markdown", "id": "275762e0", "metadata": {}, "source": ["## features\n", "\n"]}, {"cell_type": "markdown", "id": "44f4c4a0", "metadata": {}, "source": ["### Computing the features"]}, {"cell_type": "code", "execution_count": 19, "id": "8ecad9ae", "metadata": {}, "outputs": [{"data": {"text/plain": ["2"]}, "execution_count": 20, "metadata": {}, "output_type": "execute_result"}], "source": ["def _edit_distance(mot1, mot2):\n", " # Levenshtein distance between two sequences: dist[i, j] is the distance\n", " # between mot1[:i + 1] and mot2[:j + 1], index -1 stands for the empty prefix.\n", " dist = {(-1, -1): 0}\n", " pred = {(-1, -1): None}\n", " # first row: the empty prefix of mot1 against every prefix of mot2\n", " for j, d in enumerate(mot2):\n", " dist[-1, j] = dist[-1, j - 1] + 1\n", " pred[-1, j] = (-1, j - 1)\n", " for i, c in enumerate(mot1):\n", " # first column: prefix of mot1 against the empty prefix of mot2\n", " dist[i, -1] = dist[i - 1, -1] + 1\n", " pred[i, -1] = (i - 1, -1)\n", " for j, d in enumerate(mot2):\n", " opt = []\n", " if (i - 1, j) in dist:\n", " x = dist[i - 1, j] + 1\n", " opt.append((x, (i - 1, j)))\n", " if (i, j - 1) in dist:\n", " x = dist[i, j - 1] + 1\n", " opt.append((x, (i, j - 1)))\n", " if (i - 1, j - 1) in dist:\n", " x = dist[i - 1, j - 1] + (1 if c != d else 0)\n", " opt.append((x, (i - 1, j - 1)))\n", " mi = min(opt)\n", " dist[i, j] = mi[0]\n", " pred[i, j] = mi[1]\n", 
"\n", " return dist[len(mot1) - 1, len(mot2) - 1]\n", "\n", "_edit_distance(\"abdc\", \"cbda\")"]}, {"cell_type": "code", "execution_count": 20, "id": "7c818814", "metadata": {}, "outputs": [{"data": {"text/plain": ["2"]}, "execution_count": 21, "metadata": {}, "output_type": "execute_result"}], "source": ["_edit_distance((0, 1, 2, 3), (0, 2, 1, 3))"]}, {"cell_type": "code", "execution_count": 21, "id": "79e165cf", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'size': 105,\n", " 'begin': -105,\n", " 'end': -105,\n", " 'middle': 1,\n", " 'dim': 3,\n", " 'discont': 0,\n", " 'end16': -105,\n", " 'end32': -105,\n", " 'rbegin': -1.0,\n", " 'rend': -1.0,\n", " 'rmiddle': 0.009523809523809525,\n", " 'rdiscont': -0.0,\n", " 'rend16': -1.0,\n", " 'rend32': -1.0,\n", " 'iend2': -1.0,\n", " 'ibegin2': -1.0,\n", " 'iend4': -1.0,\n", " 'ibegin4': -1.0,\n", " 'iend8': -1.0,\n", " 'ibegin8': -1.0,\n", " 'iend16': -1.0,\n", " 'ibegin16': -1.0,\n", " 'iend32': -1.0,\n", " 'ibegin32': -1.0,\n", " 'iend64': -1.0,\n", " 'ibegin64': -1.0,\n", " 'CST_': -1,\n", " 'dbegin': -3,\n", " 'dend': -3,\n", " 'rot': -1,\n", " 'rev': 0,\n", " 'edit': 0,\n", " 'redit': 0.0}"]}, "execution_count": 22, "metadata": {}, "output_type": "execute_result"}], "source": ["from math import log\n", "\n", "\n", "def _is_rotation(perm):\n", " t = tuple(perm)\n", " c = list(range(len(perm)))\n", " for i in range(len(c)):\n", " for k in range(len(c)):\n", " c[k] = (k + i) % len(c)\n", " if t == tuple(c):\n", " return True\n", " return False\n", "\n", "\n", "def _relu(x, origin=0):\n", " return origin if x < origin else x\n", "\n", "\n", "def compute_features(shape, perm): \n", " total = numpy.prod(numpy.array(shape, dtype=numpy.int64))\n", " \n", " begin = 1\n", " dbegin = 0\n", " for i, p in enumerate(perm):\n", " if p != i:\n", " break\n", " dbegin += 1\n", " begin *= shape[i]\n", " \n", " end = 1\n", " dend = 0\n", " for i in range(len(perm)-1, -1, -1):\n", " if perm[i] != i:\n", " break\n", " dend += 
1\n", " end *= shape[i]\n", " \n", " dis_cont = 0\n", " for i in range(1, len(shape)):\n", " if perm[i] != perm[i-1] + 1:\n", " dis_cont += 1\n", " \n", " middle = max(1, int(total / (end * begin)))\n", " feat = dict(size=total, begin=begin, end=end, middle=middle,\n", " dim=len(shape), discont=dis_cont)\n", "\n", " for c in [16, 32]:\n", " feat[\"end%d\" % c] = _relu(end, c)\n", " \n", " keys = list(feat)\n", " for k in keys:\n", " if k in {'dim', 'cpu', 'size'}:\n", " continue\n", " feat['r%s' % k] = float(feat[k] / total)\n", " \n", " for c in [2, 4, 8, 16, 32, 64]:\n", " feat[\"iend%d\" % c] = float(end >= c)\n", " feat[\"ibegin%d\" % c] = float(begin >= c)\n", " \n", " # feat['CST'] = 1\n", " feat['CST_'] = -1\n", " feat['dbegin'] = - dbegin\n", " feat['dend'] = - dend\n", " \n", " keys = list(feat)\n", " for k in keys:\n", " if k.startswith('end') or k.startswith('begin'):\n", " feat[k] = - feat[k]\n", " elif k.startswith('rend') or k.startswith('rbegin'):\n", " feat[k] = - feat[k]\n", " elif k.startswith('iend') or k.startswith('ibegin'):\n", " feat[k] = - feat[k]\n", " elif k == \"rdiscont\":\n", " feat[k] = - feat[k]\n", "\n", " idp = list(range(len(perm)))\n", " feat[\"rot\"] = -1 if _is_rotation(perm) else 0\n", " feat[\"rev\"] = 1 if perm == tuple(idp[::-1]) else 0\n", " feat[\"edit\"] = _edit_distance(idp, perm)\n", " feat[\"redit\"] = feat[\"edit\"] / len(idp)\n", " return feat\n", "\n", "\n", "compute_features((3, 5, 7), (0, 1, 2))"]}, {"cell_type": "code", "execution_count": 22, "id": "75610130", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'size': 105,\n", " 'begin': -1,\n", " 'end': -1,\n", " 'middle': 105,\n", " 'dim': 3,\n", " 'discont': 2,\n", " 'end16': -16,\n", " 'end32': -32,\n", " 'rbegin': -0.009523809523809525,\n", " 'rend': -0.009523809523809525,\n", " 'rmiddle': 1.0,\n", " 'rdiscont': -0.01904761904761905,\n", " 'rend16': -0.1523809523809524,\n", " 'rend32': -0.3047619047619048,\n", " 'iend2': -0.0,\n", " 'ibegin2': -0.0,\n", 
" 'iend4': -0.0,\n", " 'ibegin4': -0.0,\n", " 'iend8': -0.0,\n", " 'ibegin8': -0.0,\n", " 'iend16': -0.0,\n", " 'ibegin16': -0.0,\n", " 'iend32': -0.0,\n", " 'ibegin32': -0.0,\n", " 'iend64': -0.0,\n", " 'ibegin64': -0.0,\n", " 'CST_': -1,\n", " 'dbegin': 0,\n", " 'dend': 0,\n", " 'rot': 0,\n", " 'rev': 1,\n", " 'edit': 2,\n", " 'redit': 0.6666666666666666}"]}, "execution_count": 23, "metadata": {}, "output_type": "execute_result"}], "source": ["compute_features((3, 5, 7), (2, 1, 0))"]}, {"cell_type": "code", "execution_count": 23, "id": "f6e81bd5", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'size': 105,\n", " 'begin': -1,\n", " 'end': -1,\n", " 'middle': 105,\n", " 'dim': 3,\n", " 'discont': 1,\n", " 'end16': -16,\n", " 'end32': -32,\n", " 'rbegin': -0.009523809523809525,\n", " 'rend': -0.009523809523809525,\n", " 'rmiddle': 1.0,\n", " 'rdiscont': -0.009523809523809525,\n", " 'rend16': -0.1523809523809524,\n", " 'rend32': -0.3047619047619048,\n", " 'iend2': -0.0,\n", " 'ibegin2': -0.0,\n", " 'iend4': -0.0,\n", " 'ibegin4': -0.0,\n", " 'iend8': -0.0,\n", " 'ibegin8': -0.0,\n", " 'iend16': -0.0,\n", " 'ibegin16': -0.0,\n", " 'iend32': -0.0,\n", " 'ibegin32': -0.0,\n", " 'iend64': -0.0,\n", " 'ibegin64': -0.0,\n", " 'CST_': -1,\n", " 'dbegin': 0,\n", " 'dend': 0,\n", " 'rot': -1,\n", " 'rev': 0,\n", " 'edit': 2,\n", " 'redit': 0.6666666666666666}"]}, "execution_count": 24, "metadata": {}, "output_type": "execute_result"}], "source": ["compute_features((3, 5, 7), (1, 2, 0))"]}, {"cell_type": "markdown", "id": "bd16d08b", "metadata": {}, "source": ["### Computing the features for all simulations"]}, {"cell_type": "code", "execution_count": 24, "id": "230d8ca9", "metadata": {"scrolled": false}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 132/132 [00:00<00:00, 9459.22it/s]\n", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 132/132 [00:00<00:00, 
3601.95it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CST_begindbegindenddimdisconteditendend16end32...reditrendrend16rend32revrmiddlerotsizeyryt
127-1-3-10432-1-16-32...0.50-9.331422e-07-0.000015-0.0000300.333333010716482.7908230.010339
128-1-100424-1-16-32...1.00-9.331422e-07-0.000015-0.0000301.000000010716482.8821910.010677
129-1-100434-1-16-32...1.00-9.331422e-07-0.000015-0.0000311.000000010716483.3527700.012421
130-1-100424-1-16-32...1.00-9.331422e-07-0.000015-0.0000301.000000010716483.6256800.013432
131-1-100433-1-16-32...0.75-9.331422e-07-0.000015-0.0000301.000000010716483.9354830.014579
\n", "

5 rows \u00d7 35 columns

\n", "
"], "text/plain": [" CST_ begin dbegin dend dim discont edit end end16 end32 ... \\\n", "127 -1 -3 -1 0 4 3 2 -1 -16 -32 ... \n", "128 -1 -1 0 0 4 2 4 -1 -16 -32 ... \n", "129 -1 -1 0 0 4 3 4 -1 -16 -32 ... \n", "130 -1 -1 0 0 4 2 4 -1 -16 -32 ... \n", "131 -1 -1 0 0 4 3 3 -1 -16 -32 ... \n", "\n", " redit rend rend16 rend32 rev rmiddle rot size \\\n", "127 0.50 -9.331422e-07 -0.000015 -0.00003 0 0.333333 0 1071648 \n", "128 1.00 -9.331422e-07 -0.000015 -0.00003 0 1.000000 0 1071648 \n", "129 1.00 -9.331422e-07 -0.000015 -0.00003 1 1.000000 0 1071648 \n", "130 1.00 -9.331422e-07 -0.000015 -0.00003 0 1.000000 0 1071648 \n", "131 0.75 -9.331422e-07 -0.000015 -0.00003 0 1.000000 0 1071648 \n", "\n", " yr yt \n", "127 2.790823 0.010339 \n", "128 2.882191 0.010677 \n", "129 3.352770 0.012421 \n", "130 3.625680 0.013432 \n", "131 3.935483 0.014579 \n", "\n", "[5 rows x 35 columns]"]}, "execution_count": 25, "metadata": {}, "output_type": "execute_result"}], "source": ["def compute_features_dataframe(df):\n", " \n", " def merge(row):\n", " feat = compute_features(row['shape'], row['perm'])\n", " feat['yt'] = row['average']\n", " feat['yr'] = row['ratio']\n", " return feat\n", " \n", " rows = []\n", " for i in tqdm(range(df.shape[0])):\n", " rows.append(dict(shape=df.loc[i, \"shape\"], perm=df.loc[i, \"perm\"],\n", " average=df.loc[i, \"average\"], ratio=df.loc[i, \"ratio\"]))\n", " obs = []\n", " for row in tqdm(rows):\n", " obs.append(merge(row))\n", " return DataFrame(obs)\n", "\n", "fdata = compute_features_dataframe(data)\n", "col_sort = list(sorted(fdata.columns))\n", "fdata = fdata[col_sort]\n", "fdata.tail()"]}, {"cell_type": "markdown", "id": "fae10b08", "metadata": {}, "source": ["### correlations"]}, {"cell_type": "code", "execution_count": 25, "id": "7f58d17c", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", 
" \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
CST_begindbegindenddimdisconteditendend16end32...reditrendrend16rend32revrmiddlerotsizeyryt
CST_NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN...NaNNaNNaNNaNNaNNaNNaNNaNNaNNaN
beginNaN1.0000000.5968160.5964140.0141180.4049520.4051750.9999980.9999980.999998...0.4180220.6815730.6815730.6815730.0382160.2563490.325594-0.1335810.127658-0.008816
dbeginNaN0.5968161.0000000.6768990.0771620.4868870.6695980.5963840.5963740.596363...0.6903330.8318870.8318950.8319030.1116360.6059900.2980900.0164110.2913180.139951
dendNaN0.5964140.6768991.0000000.0771620.4868870.6695980.5969360.5969070.596881...0.6903330.8330590.8329750.8329240.1116360.6235820.2980900.0164110.3054890.155098
dimNaN0.0141180.0771620.0771621.0000000.3053200.2726140.0141530.0141450.014135...0.1159020.1604070.1604170.160414-0.1603570.1066930.2409460.2126850.1389610.192305
discontNaN0.4049520.4868870.4868870.3053201.0000000.5312540.4049710.4049610.404948...0.5042060.5942190.5942260.5942230.1501440.2258540.8239370.0649370.3881400.203342
editNaN0.4051750.6695980.6695980.2726140.5312541.0000000.4052230.4052040.405189...0.9846550.5946880.5946390.5946190.2085680.6525320.3389940.0579810.4642250.283262
endNaN0.9999980.5963840.5969360.0141530.4049710.4052231.0000001.0000001.000000...0.4180620.6815650.6815650.6815650.0382360.2564790.325559-0.1336650.127730-0.008844
end16NaN0.9999980.5963740.5969070.0141450.4049610.4052041.0000001.0000001.000000...0.4180440.6815500.6815500.6815500.0382310.2564510.325557-0.1336710.127716-0.008852
end32NaN0.9999980.5963630.5968810.0141350.4049480.4051891.0000001.0000001.000000...0.4180290.6815330.6815330.6815330.0382280.2564300.325552-0.1336770.127707-0.008859
ibegin16NaN0.4885860.8540560.5537920.1608000.4762250.5281320.4879380.4879300.487919...0.5339810.7158700.7158890.7159010.0782150.5223990.2838000.0374620.2543520.136416
ibegin2NaN0.2977790.7922250.3264180.0800330.2303930.5390820.2972850.2972810.297277...0.5486050.4361110.4361260.4361480.1283380.6856050.049586-0.0271980.3248510.109929
ibegin32NaN0.4885860.8540560.5537920.1608000.4762250.5281320.4879380.4879300.487919...0.5339810.7158700.7158890.7159010.0782150.5223990.2838000.0374620.2543520.136416
ibegin4NaN0.4200230.8141780.4742320.1144320.3886890.5179510.4194330.4194240.419415...0.5282840.6153850.6154710.6155840.0909850.5947740.2076510.1164600.2788100.179014
ibegin64NaN0.5106590.8693570.5864300.0833330.4885120.5333760.5099990.5099900.509979...0.5555290.7482430.7482530.7482570.0748330.5010690.3072070.0177240.2431280.124459
ibegin8NaN0.4200230.8141780.4742320.1144320.3886890.5179510.4194330.4194240.419415...0.5282840.6153850.6154710.6155840.0909850.5947740.2076510.1164600.2788100.179014
iend16NaN0.4058580.4524740.8079890.1815030.3836540.5173110.4066140.4065750.406542...0.5139170.5971650.5970610.5970690.0940320.6199280.1917510.1256890.3076330.182267
iend2NaN0.2973230.3264180.7922250.0800330.2303930.5390820.2979300.2978950.297872...0.5486050.4375410.4373980.4373380.1283380.7245620.049586-0.0271980.3370710.146321
iend32NaN0.4682240.5242980.8412770.2334080.4655930.5241670.4690610.4690290.468993...0.5148050.6887140.6886400.6885530.0815110.5445990.2624480.0496430.2619300.139099
iend4NaN0.3605970.4001190.7929630.1414210.3218780.5196340.3612900.3612510.361222...0.5211200.5305940.5304670.5304390.1058300.6733840.136300-0.0393550.3511510.127468
iend64NaN0.4879590.5537920.8540560.1608000.4762250.5281320.4888160.4887860.488752...0.5339810.7176770.7176120.7175340.0782150.5237460.2838000.0309380.2555670.127728
iend8NaN0.4058580.4524740.8079890.1815030.3836540.5173110.4066140.4065750.406542...0.5139170.5971650.5970610.5970690.0940320.6199280.1917510.1256890.3076330.182267
middleNaN0.1268960.3038680.3190570.1783170.1520950.3778740.1269600.1269470.126937...0.3559800.1864720.1866690.1869180.0529910.4679810.0009030.728990-0.0081200.821357
rbeginNaN0.6815760.8327940.8319330.1602960.5941710.5945680.6815640.6815490.681532...0.6134170.9999920.9999930.9999930.0561080.3763570.4776490.0343280.1874110.095471
rdiscontNaN-0.132163-0.158903-0.158903-0.077379-0.320270-0.168278-0.132191-0.132195-0.132192...-0.163660-0.193464-0.193106-0.192632-0.054527-0.001602-0.2656720.5518800.0045170.386893
reditNaN0.4180220.6903330.6903330.1159020.5042060.9846550.4180620.4180440.418029...1.0000000.6135170.6134660.6134460.2441340.6556580.3171060.0246510.4509280.256097
rendNaN0.6815730.8318870.8330590.1604070.5942190.5946880.6815650.6815500.681533...0.6135171.0000001.0000000.9999990.0561530.3766580.4775790.0344120.1875510.095557
rend16NaN0.6815730.8318950.8329750.1604170.5942260.5946390.6815650.6815500.681533...0.6134661.0000001.0000001.0000000.0561290.3765740.4776130.0346790.1875590.095755
rend32NaN0.6815730.8319030.8329240.1604140.5942230.5946190.6815650.6815500.681533...0.6134460.9999991.0000001.0000000.0561160.3765570.4776300.0350050.1876070.095996
revNaN0.0382160.1116360.111636-0.1603570.1501440.2085680.0382360.0382310.038228...0.2441340.0561530.0561290.0561161.0000000.1804700.117200-0.0341060.2183870.094260
rmiddleNaN0.2563490.6059900.6235820.1066930.2258540.6525320.2564790.2564510.256430...0.6556580.3766580.3765740.3765570.1804701.000000-0.064351-0.0137710.4689250.195497
rotNaN0.3255940.2980900.2980900.2409460.8239370.3389940.3255590.3255570.325552...0.3171060.4775790.4776130.4776300.117200-0.0643511.0000000.0512460.2432940.126195
sizeNaN-0.1335810.0164110.0164110.2126850.0649370.057981-0.133665-0.133671-0.133677...0.0246510.0344120.0346790.035005-0.034106-0.0137710.0512461.000000-0.2362890.805926
yrNaN0.1276580.2913180.3054890.1389610.3881400.4642250.1277300.1277160.127707...0.4509280.1875510.1875590.1876070.2183870.4689250.243294-0.2362891.000000-0.013907
ytNaN-0.0088160.1399510.1550980.1923050.2033420.283262-0.008844-0.008852-0.008859...0.2560970.0955570.0957550.0959960.0942600.1954970.1261950.805926-0.0139071.000000
\n", "

35 rows \u00d7 35 columns

\n", "
"], "text/plain": [" CST_ begin dbegin dend dim discont edit \\\n", "CST_ NaN NaN NaN NaN NaN NaN NaN \n", "begin NaN 1.000000 0.596816 0.596414 0.014118 0.404952 0.405175 \n", "dbegin NaN 0.596816 1.000000 0.676899 0.077162 0.486887 0.669598 \n", "dend NaN 0.596414 0.676899 1.000000 0.077162 0.486887 0.669598 \n", "dim NaN 0.014118 0.077162 0.077162 1.000000 0.305320 0.272614 \n", "discont NaN 0.404952 0.486887 0.486887 0.305320 1.000000 0.531254 \n", "edit NaN 0.405175 0.669598 0.669598 0.272614 0.531254 1.000000 \n", "end NaN 0.999998 0.596384 0.596936 0.014153 0.404971 0.405223 \n", "end16 NaN 0.999998 0.596374 0.596907 0.014145 0.404961 0.405204 \n", "end32 NaN 0.999998 0.596363 0.596881 0.014135 0.404948 0.405189 \n", "ibegin16 NaN 0.488586 0.854056 0.553792 0.160800 0.476225 0.528132 \n", "ibegin2 NaN 0.297779 0.792225 0.326418 0.080033 0.230393 0.539082 \n", "ibegin32 NaN 0.488586 0.854056 0.553792 0.160800 0.476225 0.528132 \n", "ibegin4 NaN 0.420023 0.814178 0.474232 0.114432 0.388689 0.517951 \n", "ibegin64 NaN 0.510659 0.869357 0.586430 0.083333 0.488512 0.533376 \n", "ibegin8 NaN 0.420023 0.814178 0.474232 0.114432 0.388689 0.517951 \n", "iend16 NaN 0.405858 0.452474 0.807989 0.181503 0.383654 0.517311 \n", "iend2 NaN 0.297323 0.326418 0.792225 0.080033 0.230393 0.539082 \n", "iend32 NaN 0.468224 0.524298 0.841277 0.233408 0.465593 0.524167 \n", "iend4 NaN 0.360597 0.400119 0.792963 0.141421 0.321878 0.519634 \n", "iend64 NaN 0.487959 0.553792 0.854056 0.160800 0.476225 0.528132 \n", "iend8 NaN 0.405858 0.452474 0.807989 0.181503 0.383654 0.517311 \n", "middle NaN 0.126896 0.303868 0.319057 0.178317 0.152095 0.377874 \n", "rbegin NaN 0.681576 0.832794 0.831933 0.160296 0.594171 0.594568 \n", "rdiscont NaN -0.132163 -0.158903 -0.158903 -0.077379 -0.320270 -0.168278 \n", "redit NaN 0.418022 0.690333 0.690333 0.115902 0.504206 0.984655 \n", "rend NaN 0.681573 0.831887 0.833059 0.160407 0.594219 0.594688 \n", "rend16 NaN 0.681573 0.831895 0.832975 0.160417 
0.594226 0.594639 \n", "rend32 NaN 0.681573 0.831903 0.832924 0.160414 0.594223 0.594619 \n", "rev NaN 0.038216 0.111636 0.111636 -0.160357 0.150144 0.208568 \n", "rmiddle NaN 0.256349 0.605990 0.623582 0.106693 0.225854 0.652532 \n", "rot NaN 0.325594 0.298090 0.298090 0.240946 0.823937 0.338994 \n", "size NaN -0.133581 0.016411 0.016411 0.212685 0.064937 0.057981 \n", "yr NaN 0.127658 0.291318 0.305489 0.138961 0.388140 0.464225 \n", "yt NaN -0.008816 0.139951 0.155098 0.192305 0.203342 0.283262 \n", "\n", " end end16 end32 ... redit rend rend16 \\\n", "CST_ NaN NaN NaN ... NaN NaN NaN \n", "begin 0.999998 0.999998 0.999998 ... 0.418022 0.681573 0.681573 \n", "dbegin 0.596384 0.596374 0.596363 ... 0.690333 0.831887 0.831895 \n", "dend 0.596936 0.596907 0.596881 ... 0.690333 0.833059 0.832975 \n", "dim 0.014153 0.014145 0.014135 ... 0.115902 0.160407 0.160417 \n", "discont 0.404971 0.404961 0.404948 ... 0.504206 0.594219 0.594226 \n", "edit 0.405223 0.405204 0.405189 ... 0.984655 0.594688 0.594639 \n", "end 1.000000 1.000000 1.000000 ... 0.418062 0.681565 0.681565 \n", "end16 1.000000 1.000000 1.000000 ... 0.418044 0.681550 0.681550 \n", "end32 1.000000 1.000000 1.000000 ... 0.418029 0.681533 0.681533 \n", "ibegin16 0.487938 0.487930 0.487919 ... 0.533981 0.715870 0.715889 \n", "ibegin2 0.297285 0.297281 0.297277 ... 0.548605 0.436111 0.436126 \n", "ibegin32 0.487938 0.487930 0.487919 ... 0.533981 0.715870 0.715889 \n", "ibegin4 0.419433 0.419424 0.419415 ... 0.528284 0.615385 0.615471 \n", "ibegin64 0.509999 0.509990 0.509979 ... 0.555529 0.748243 0.748253 \n", "ibegin8 0.419433 0.419424 0.419415 ... 0.528284 0.615385 0.615471 \n", "iend16 0.406614 0.406575 0.406542 ... 0.513917 0.597165 0.597061 \n", "iend2 0.297930 0.297895 0.297872 ... 0.548605 0.437541 0.437398 \n", "iend32 0.469061 0.469029 0.468993 ... 0.514805 0.688714 0.688640 \n", "iend4 0.361290 0.361251 0.361222 ... 0.521120 0.530594 0.530467 \n", "iend64 0.488816 0.488786 0.488752 ... 
0.533981 0.717677 0.717612 \n", "iend8 0.406614 0.406575 0.406542 ... 0.513917 0.597165 0.597061 \n", "middle 0.126960 0.126947 0.126937 ... 0.355980 0.186472 0.186669 \n", "rbegin 0.681564 0.681549 0.681532 ... 0.613417 0.999992 0.999993 \n", "rdiscont -0.132191 -0.132195 -0.132192 ... -0.163660 -0.193464 -0.193106 \n", "redit 0.418062 0.418044 0.418029 ... 1.000000 0.613517 0.613466 \n", "rend 0.681565 0.681550 0.681533 ... 0.613517 1.000000 1.000000 \n", "rend16 0.681565 0.681550 0.681533 ... 0.613466 1.000000 1.000000 \n", "rend32 0.681565 0.681550 0.681533 ... 0.613446 0.999999 1.000000 \n", "rev 0.038236 0.038231 0.038228 ... 0.244134 0.056153 0.056129 \n", "rmiddle 0.256479 0.256451 0.256430 ... 0.655658 0.376658 0.376574 \n", "rot 0.325559 0.325557 0.325552 ... 0.317106 0.477579 0.477613 \n", "size -0.133665 -0.133671 -0.133677 ... 0.024651 0.034412 0.034679 \n", "yr 0.127730 0.127716 0.127707 ... 0.450928 0.187551 0.187559 \n", "yt -0.008844 -0.008852 -0.008859 ... 0.256097 0.095557 0.095755 \n", "\n", " rend32 rev rmiddle rot size yr yt \n", "CST_ NaN NaN NaN NaN NaN NaN NaN \n", "begin 0.681573 0.038216 0.256349 0.325594 -0.133581 0.127658 -0.008816 \n", "dbegin 0.831903 0.111636 0.605990 0.298090 0.016411 0.291318 0.139951 \n", "dend 0.832924 0.111636 0.623582 0.298090 0.016411 0.305489 0.155098 \n", "dim 0.160414 -0.160357 0.106693 0.240946 0.212685 0.138961 0.192305 \n", "discont 0.594223 0.150144 0.225854 0.823937 0.064937 0.388140 0.203342 \n", "edit 0.594619 0.208568 0.652532 0.338994 0.057981 0.464225 0.283262 \n", "end 0.681565 0.038236 0.256479 0.325559 -0.133665 0.127730 -0.008844 \n", "end16 0.681550 0.038231 0.256451 0.325557 -0.133671 0.127716 -0.008852 \n", "end32 0.681533 0.038228 0.256430 0.325552 -0.133677 0.127707 -0.008859 \n", "ibegin16 0.715901 0.078215 0.522399 0.283800 0.037462 0.254352 0.136416 \n", "ibegin2 0.436148 0.128338 0.685605 0.049586 -0.027198 0.324851 0.109929 \n", "ibegin32 0.715901 0.078215 0.522399 0.283800 0.037462 
0.254352 0.136416 \n", "ibegin4 0.615584 0.090985 0.594774 0.207651 0.116460 0.278810 0.179014 \n", "ibegin64 0.748257 0.074833 0.501069 0.307207 0.017724 0.243128 0.124459 \n", "ibegin8 0.615584 0.090985 0.594774 0.207651 0.116460 0.278810 0.179014 \n", "iend16 0.597069 0.094032 0.619928 0.191751 0.125689 0.307633 0.182267 \n", "iend2 0.437338 0.128338 0.724562 0.049586 -0.027198 0.337071 0.146321 \n", "iend32 0.688553 0.081511 0.544599 0.262448 0.049643 0.261930 0.139099 \n", "iend4 0.530439 0.105830 0.673384 0.136300 -0.039355 0.351151 0.127468 \n", "iend64 0.717534 0.078215 0.523746 0.283800 0.030938 0.255567 0.127728 \n", "iend8 0.597069 0.094032 0.619928 0.191751 0.125689 0.307633 0.182267 \n", "middle 0.186918 0.052991 0.467981 0.000903 0.728990 -0.008120 0.821357 \n", "rbegin 0.999993 0.056108 0.376357 0.477649 0.034328 0.187411 0.095471 \n", "rdiscont -0.192632 -0.054527 -0.001602 -0.265672 0.551880 0.004517 0.386893 \n", "redit 0.613446 0.244134 0.655658 0.317106 0.024651 0.450928 0.256097 \n", "rend 0.999999 0.056153 0.376658 0.477579 0.034412 0.187551 0.095557 \n", "rend16 1.000000 0.056129 0.376574 0.477613 0.034679 0.187559 0.095755 \n", "rend32 1.000000 0.056116 0.376557 0.477630 0.035005 0.187607 0.095996 \n", "rev 0.056116 1.000000 0.180470 0.117200 -0.034106 0.218387 0.094260 \n", "rmiddle 0.376557 0.180470 1.000000 -0.064351 -0.013771 0.468925 0.195497 \n", "rot 0.477630 0.117200 -0.064351 1.000000 0.051246 0.243294 0.126195 \n", "size 0.035005 -0.034106 -0.013771 0.051246 1.000000 -0.236289 0.805926 \n", "yr 0.187607 0.218387 0.468925 0.243294 -0.236289 1.000000 -0.013907 \n", "yt 0.095996 0.094260 0.195497 0.126195 0.805926 -0.013907 1.000000 \n", "\n", "[35 rows x 35 columns]"]}, "execution_count": 26, "metadata": {}, "output_type": "execute_result"}], "source": ["fdata.corr()"]}, {"cell_type": "code", "execution_count": 26, "id": "de885ceb", "metadata": {}, "outputs": [{"data": {"text/plain": ["CST_ NaN\n", "begin -0.008816\n", "dbegin 
0.139951\n", "dend 0.155098\n", "dim 0.192305\n", "discont 0.203342\n", "edit 0.283262\n", "end -0.008844\n", "end16 -0.008852\n", "end32 -0.008859\n", "ibegin16 0.136416\n", "ibegin2 0.109929\n", "ibegin32 0.136416\n", "ibegin4 0.179014\n", "ibegin64 0.124459\n", "ibegin8 0.179014\n", "iend16 0.182267\n", "iend2 0.146321\n", "iend32 0.139099\n", "iend4 0.127468\n", "iend64 0.127728\n", "iend8 0.182267\n", "middle 0.821357\n", "rbegin 0.095471\n", "rdiscont 0.386893\n", "redit 0.256097\n", "rend 0.095557\n", "rend16 0.095755\n", "rend32 0.095996\n", "rev 0.094260\n", "rmiddle 0.195497\n", "rot 0.126195\n", "size 0.805926\n", "yr -0.013907\n", "yt 1.000000\n", "Name: yt, dtype: float64"]}, "execution_count": 27, "metadata": {}, "output_type": "execute_result"}], "source": ["fdata.corr()['yt']"]}, {"cell_type": "markdown", "id": "8ada5292", "metadata": {}, "source": ["We check the sign of the correlation between every feature and *yt*. If it is positive, increasing the feature increases the processing time, and we try to get only positive correlations. *end* is the flattened size of the last dimensions left unchanged by the permutation. The bigger it is, the faster the transposition is. That's why the function computing all features multiplies this number by `-1`: the goal is a feature positively correlated with the processing time. On this dataset, however, the measured correlation of *end* with *yt* is close to zero (about -0.009). *end16* is equal to *end* when `end<-16` and to `-16` when `end>=-16`, i.e. `end16 = min(end, -16)`. This is a simplification of the cost of moving data from memory to the L1 cache: this cost is linear when the data to move is big enough, but almost constant for small chunks."]}, {"cell_type": "markdown", "id": "040ed07d", "metadata": {}, "source": ["## Linear regression\n", "\n", "We choose a linear regression because its predictions are not bounded. 
The training set does not include all configurations and surely does not include all the high values the model may have to predict.\n", "\n", "The goal is not necessarily to predict the fastest permutation but to predict the processing time, as the goal is to find the best combination of transpositions in an ONNX graph (einsum). The final goal is to predict which graph optimizes a series of transpositions.\n", "\n", "The target could be the processing time or the logarithm of this time. However, making mistakes on small times is not an issue, whereas errors on high processing times are.\n", "\n", "We could also try to predict a ratio *transposition time / copy time*, but it would still give more importance to small matrix sizes.\n", "\n", "Many variables are correlated. Variables should be selected."]}, {"cell_type": "markdown", "id": "af7f29f9", "metadata": {}, "source": ["### Dataset"]}, {"cell_type": "code", "execution_count": 27, "id": "249b9319", "metadata": {}, "outputs": [], "source": ["# Features: every column of fdata except the two targets yt and yr.\n", "X = fdata.drop([\"yt\", \"yr\"], axis=1)\n", "x_names = list(X.columns)\n", "# Target: yt scaled by 1000 (presumably seconds to milliseconds).\n", "yt = fdata['yt'] * 1000"]}, {"cell_type": "code", "execution_count": 28, "id": "1a9f6cde", "metadata": {}, "outputs": [{"data": {"text/plain": ["1.8809171132996723"]}, "execution_count": 29, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(yt)"]}, {"cell_type": "markdown", "id": "3dde6140", "metadata": {}, "source": ["### Simple model"]}, {"cell_type": "code", "execution_count": 29, "id": "66638541", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.8157414076410756, 0.6368865305095469)"]}, "execution_count": 30, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.linear_model import LinearRegression\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.pipeline import make_pipeline\n", "from sklearn.metrics import r2_score, mean_absolute_error\n", "\n", 
"# fit_intercept=False: the constant column CST_ presumably plays the role of the intercept.\n", "pipe = make_pipeline(StandardScaler(with_mean=False), LinearRegression(fit_intercept=False))\n", "pipe.fit(X, yt)\n", "model = pipe.steps[1][1]\n", "coef = {k: v for k, v in zip(X.columns, model.coef_)}\n", "coef['name'] = 'reg'\n", "coef['intercept_'] = model.intercept_\n", "# A processing time cannot be negative: predictions are clipped at 0.\n", "pred = numpy.maximum(pipe.predict(X), 0)\n", "# Metrics are computed on the training data (in-sample), they are optimistic.\n", "coef['r2'] = r2_score(yt, pred)\n", "coef['mae'] = mean_absolute_error(yt, pred)\n", "coef['model'] = pipe\n", "coefs = [coef]\n", "coef[\"r2\"], coef['mae']"]}, {"cell_type": "code", "execution_count": 30, "id": "bb8d7cbb", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["df = DataFrame([(k, v) for k, v in coef.items() if k not in {'name', 'model'}],\n", " columns=[\"feature\", \"value\"]).set_index(\"feature\")\n", "df.plot(kind=\"bar\", figsize=(14, 2));"]}, {"cell_type": "code", "execution_count": 31, "id": "7b0117ba", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
value
feature
CST_-3.076618e+08
begin-2.941725e+01
dbegin-1.854147e-01
dend-9.638954e-02
dim-1.037599e-01
discont5.204404e-01
edit3.582481e-01
end-1.046584e+12
end16-2.278042e+10
end321.069321e+12
ibegin16-3.713466e+00
ibegin21.439716e-02
ibegin323.784367e+00
ibegin4-6.813416e+00
ibegin64-7.576102e-02
ibegin86.927856e+00
iend162.028144e+07
iend28.225773e+06
iend324.322857e+07
iend41.097274e+07
iend641.996315e-01
iend82.028143e+07
middle1.541218e+00
rbegin4.940619e+01
rdiscont7.614642e-01
redit8.622710e-02
rend6.615750e+02
rend163.459172e+02
rend32-1.057057e+03
rev1.537206e-01
rmiddle-4.563712e-01
rot7.771901e-02
size1.295707e+00
intercept_0.000000e+00
r28.157414e-01
mae6.368865e-01
\n", "
"], "text/plain": [" value\n", "feature \n", "CST_ -3.076618e+08\n", "begin -2.941725e+01\n", "dbegin -1.854147e-01\n", "dend -9.638954e-02\n", "dim -1.037599e-01\n", "discont 5.204404e-01\n", "edit 3.582481e-01\n", "end -1.046584e+12\n", "end16 -2.278042e+10\n", "end32 1.069321e+12\n", "ibegin16 -3.713466e+00\n", "ibegin2 1.439716e-02\n", "ibegin32 3.784367e+00\n", "ibegin4 -6.813416e+00\n", "ibegin64 -7.576102e-02\n", "ibegin8 6.927856e+00\n", "iend16 2.028144e+07\n", "iend2 8.225773e+06\n", "iend32 4.322857e+07\n", "iend4 1.097274e+07\n", "iend64 1.996315e-01\n", "iend8 2.028143e+07\n", "middle 1.541218e+00\n", "rbegin 4.940619e+01\n", "rdiscont 7.614642e-01\n", "redit 8.622710e-02\n", "rend 6.615750e+02\n", "rend16 3.459172e+02\n", "rend32 -1.057057e+03\n", "rev 1.537206e-01\n", "rmiddle -4.563712e-01\n", "rot 7.771901e-02\n", "size 1.295707e+00\n", "intercept_ 0.000000e+00\n", "r2 8.157414e-01\n", "mae 6.368865e-01"]}, "execution_count": 32, "metadata": {}, "output_type": "execute_result"}], "source": ["df"]}, {"cell_type": "markdown", "id": "50593ad2", "metadata": {}, "source": ["Coefficients associated to features *end*, *end16* are almost opposed and it would better to get a model which keeps only one."]}, {"cell_type": "markdown", "id": "4934670e", "metadata": {}, "source": ["### Quantile Regression"]}, {"cell_type": "code", "execution_count": 32, "id": "92fcd452", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.7924498414927943, 0.5679387557069854)"]}, "execution_count": 33, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlinsights.mlmodel import QuantileLinearRegression\n", "pipe = make_pipeline(StandardScaler(with_mean=False), QuantileLinearRegression(fit_intercept=False))\n", "pipe.fit(X, yt)\n", "model = pipe.steps[1][1]\n", "coef = {k: v for k, v in zip(X.columns, model.coef_)}\n", "coef['name'] = 'med'\n", "coef['intercept_'] = model.intercept_\n", "pred = numpy.maximum(pipe.predict(X), 0)\n", "coef['r2'] = r2_score(yt, 
pred)\n", "coef['mae'] = mean_absolute_error(yt, pred)\n", "coef['model'] = pipe\n", "coefs.append(coef)\n", "coef[\"r2\"], coef['mae']"]}, {"cell_type": "code", "execution_count": 33, "id": "bc74043c", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
value
feature
CST_1433409.249051
begin27.13405
dbegin0.07931
dend0.087576
dim0.006919
discont0.413378
edit0.186032
end4876069525.422424
end16106134745.367844
end32-4982003112.711292
ibegin160.129918
ibegin2-0.069604
ibegin32-0.221099
ibegin4-0.045585
ibegin64-0.1085
ibegin80.073031
iend16-94492.918693
iend2-38324.37475
iend32-201401.795017
iend4-51122.392443
iend640.15928
iend8-94492.881923
middle1.588707
rbegin36.958438
rdiscont0.375421
redit0.071189
rend4424.263222
rend16-7664.018684
rend323202.681647
rev0.08288
rmiddle-0.207068
rot-0.095643
size0.938597
namemed
intercept_0
r20.79245
mae0.567939
model(StandardScaler(with_mean=False), QuantileLine...
\n", "
"], "text/plain": [" value\n", "feature \n", "CST_ 1433409.249051\n", "begin 27.13405\n", "dbegin 0.07931\n", "dend 0.087576\n", "dim 0.006919\n", "discont 0.413378\n", "edit 0.186032\n", "end 4876069525.422424\n", "end16 106134745.367844\n", "end32 -4982003112.711292\n", "ibegin16 0.129918\n", "ibegin2 -0.069604\n", "ibegin32 -0.221099\n", "ibegin4 -0.045585\n", "ibegin64 -0.1085\n", "ibegin8 0.073031\n", "iend16 -94492.918693\n", "iend2 -38324.37475\n", "iend32 -201401.795017\n", "iend4 -51122.392443\n", "iend64 0.15928\n", "iend8 -94492.881923\n", "middle 1.588707\n", "rbegin 36.958438\n", "rdiscont 0.375421\n", "redit 0.071189\n", "rend 4424.263222\n", "rend16 -7664.018684\n", "rend32 3202.681647\n", "rev 0.08288\n", "rmiddle -0.207068\n", "rot -0.095643\n", "size 0.938597\n", "name med\n", "intercept_ 0\n", "r2 0.79245\n", "mae 0.567939\n", "model (StandardScaler(with_mean=False), QuantileLine..."]}, "execution_count": 34, "metadata": {}, "output_type": "execute_result"}], "source": ["DataFrame(coef.items(), columns=[\"feature\", \"value\"]).set_index(\"feature\")"]}, {"cell_type": "markdown", "id": "c051acc4", "metadata": {}, "source": ["### Lasso\n", "\n", "To select features."]}, {"cell_type": "code", "execution_count": 34, "id": "cf893b4f", "metadata": {"scrolled": false}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 13/13 [00:00<00:00, 69.97it/s]\n"]}, {"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
r2maealphanulln
00.8097040.6294800.001433
10.8075460.6298860.0101033
20.7825410.6764990.1002333
30.7669110.6803440.2002833
40.7515460.7036840.3002933
50.7382230.7429620.4003033
60.7309370.7359580.5003133
70.7184370.7581430.6003033
80.7013290.8005030.7003033
90.6815900.8485490.8003033
100.6592180.8987700.9003033
110.6342180.9494931.0003033
120.2394131.6005422.0003033
\n", "
"], "text/plain": [" r2 mae alpha null n\n", "0 0.809704 0.629480 0.001 4 33\n", "1 0.807546 0.629886 0.010 10 33\n", "2 0.782541 0.676499 0.100 23 33\n", "3 0.766911 0.680344 0.200 28 33\n", "4 0.751546 0.703684 0.300 29 33\n", "5 0.738223 0.742962 0.400 30 33\n", "6 0.730937 0.735958 0.500 31 33\n", "7 0.718437 0.758143 0.600 30 33\n", "8 0.701329 0.800503 0.700 30 33\n", "9 0.681590 0.848549 0.800 30 33\n", "10 0.659218 0.898770 0.900 30 33\n", "11 0.634218 0.949493 1.000 30 33\n", "12 0.239413 1.600542 2.000 30 33"]}, "execution_count": 35, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.linear_model import Lasso\n", "\n", "scores = []\n", "models = []\n", "for a in tqdm([0.001, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 2.]):\n", " alpha = a * 1.\n", " pipe = make_pipeline(\n", " StandardScaler(with_mean=False),\n", " Lasso(alpha=alpha, fit_intercept=False, max_iter=5000))\n", " pipe.fit(X, yt)\n", " pred = numpy.maximum(pipe.predict(X), 0)\n", " model = pipe.steps[1][1]\n", " scores.append(dict(r2=r2_score(yt, pred), mae=mean_absolute_error(yt, pred),\n", " alpha=alpha, null=(numpy.abs(model.coef_) < 1e-6).sum(),\n", " n=len(model.coef_)))\n", " models.append(pipe)\n", " if alpha >= 0.01 and alpha <= 0.2:\n", " coef = {k: v for k, v in zip(X.columns, pipe.steps[1][1].coef_)}\n", " coef['name'] = \"Lasso-%f\" % alpha\n", " coef['model'] = pipe\n", " coef['r2'] = r2_score(yt, pred)\n", " coef['mae'] = mean_absolute_error(yt, pred)\n", " coefs.append(coef)\n", " \n", "DataFrame(scores)"]}, {"cell_type": "code", "execution_count": 35, "id": "b2a2e3b6", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["coef = {k: v for k, v in zip(X.columns, models[1].steps[1][1].coef_)}\n", "df = DataFrame(coef.items(), columns=[\"feature\", \"value\"]).set_index(\"feature\")\n", "df.plot(kind=\"bar\", figsize=(14, 2), title=\"alpha=%f\" % scores[1][\"alpha\"]);"]}, {"cell_type": "code", "execution_count": 36, "id": "62f9c62b", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["coef = {k: v for k, v in zip(X.columns, models[2].steps[1][1].coef_)}\n", "df = DataFrame(coef.items(), columns=[\"feature\", \"value\"]).set_index(\"feature\")\n", "df.plot(kind=\"bar\", figsize=(14, 2), title=\"alpha=%f\" % scores[2][\"alpha\"]);"]}, {"cell_type": "markdown", "id": "e2ca7fc5", "metadata": {}, "source": ["### Linear regression with positive weights"]}, {"cell_type": "code", "execution_count": 37, "id": "d73428ae", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.7905447080626958, 0.6768663007518693)"]}, "execution_count": 38, "metadata": {}, "output_type": "execute_result"}], "source": ["pipe = make_pipeline(StandardScaler(with_mean=False), LinearRegression(positive=True, fit_intercept=False))\n", "pipe.fit(X, yt)\n", "model = pipe.steps[1][1]\n", "coef = {k: v for k, v in zip(X.columns, model.coef_)}\n", "coef['name'] = 'pos'\n", "coef['intercept_'] = model.intercept_\n", "pred = numpy.maximum(pipe.predict(X), 0)\n", "coef['r2'] = r2_score(yt, pred)\n", "coef['mae'] = mean_absolute_error(yt, pred)\n", "coef['model'] = pipe\n", "coefs.append(coef)\n", "coef[\"r2\"], coef['mae']"]}, {"cell_type": "code", "execution_count": 38, "id": "f5690715", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["coef = {k: v for k, v in zip(X.columns, pipe.steps[1][1].coef_)}\n", "df = DataFrame(coef.items(), columns=[\"feature\", \"value\"]).set_index(\"feature\")\n", "df.plot(kind=\"bar\", figsize=(14, 2), title=\"positive\");"]}, {"cell_type": "markdown", "id": "bd2436d5", "metadata": {}, "source": ["### Quantile regression with positive weights"]}, {"cell_type": "code", "execution_count": 39, "id": "497f7a66", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.752689515971656, 0.6468340444504788)"]}, "execution_count": 40, "metadata": {}, "output_type": "execute_result"}], "source": ["pipe = make_pipeline(StandardScaler(with_mean=False), QuantileLinearRegression(positive=True, fit_intercept=False))\n", "pipe.fit(X, yt)\n", "model = pipe.steps[1][1]\n", "coef = {k: v for k, v in zip(X.columns, model.coef_)}\n", "coef['name'] = 'medpos'\n", "coef['intercept_'] = model.intercept_\n", "pred = numpy.maximum(pipe.predict(X), 0)\n", "coef['r2'] = r2_score(yt, pred)\n", "coef['mae'] = mean_absolute_error(yt, pred)\n", "coef['model'] = pipe\n", "coefs.append(coef)\n", "coef[\"r2\"], coef['mae']"]}, {"cell_type": "code", "execution_count": 40, "id": "893554e3", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["coef = {k: v for k, v in zip(X.columns, pipe.steps[1][1].coef_)}\n", "df = DataFrame(coef.items(), columns=[\"feature\", \"value\"]).set_index(\"feature\")\n", "df.plot(kind=\"bar\", figsize=(14, 2), title=\"positive\");"]}, {"cell_type": "markdown", "id": "b9229d31", "metadata": {}, "source": ["### Summary"]}, {"cell_type": "code", "execution_count": 41, "id": "ee71336c", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
6543210
CST_0.8294820.8210480.00.00.01433409.249051-307661768.128088
begin0.00.0-0.0-0.0-0.0344327.13405-29.417247
dbegin0.00.0-0.0-0.0-0.0447050.07931-0.185415
dend0.00.0-0.0-0.0-0.00.087576-0.09639
dim0.0238460.0-0.014763-0.030446-0.1209490.006919-0.10376
discont0.0606360.0562970.00.00.2104210.4133780.52044
edit0.038230.0948560.00.04180.3960520.1860320.358248
end0.00.0-0.0-0.0-0.0070534876069525.422424-1046583604803.358887
end160.00.0-0.0-0.0-0.000036106134745.367844-22780416305.902706
end320.00.0-0.0-0.0-0.00004-4982003112.7112921069320839370.567505
ibegin160.00.0-0.0-0.00.0666690.129918-3.713466
ibegin20.00.0-0.0-0.0-0.02181-0.0696040.014397
ibegin320.00.0-0.0-0.00.0-0.2210993.784367
ibegin40.00.0-0.0-0.00.0-0.045585-6.813416
ibegin640.00.0-0.0-0.00.0-0.1085-0.075761
ibegin80.00.0-0.0-0.00.00.0730316.927856
iend160.00.0-0.0-0.0-0.022416-94492.91869320281439.108194
iend20.00.00.0-0.00.0-38324.374758225773.255917
iend320.00.0-0.0-0.0-0.0-201401.79501743228573.054944
iend40.00.0-0.0-0.00.151081-51122.39244310972737.091606
iend640.00.0-0.0-0.0-0.00.159280.199631
iend80.00.0-0.0-0.0-0.08907-94492.88192320281426.580972
middle1.1015431.303471.2906991.3259161.4667331.5887071.541218
rbegin0.00.0-0.0-0.020369-0.2829536.95843849.406192
rdiscont0.00.00.0-0.0-0.0663850.3754210.761464
redit0.00.00.00.00.00.0711890.086227
rend0.00.0-0.0-0.007655-0.0075934424.263222661.575013
rend160.00.0-0.0-0.003393-0.010514-7664.018684345.917179
rend320.00.0-0.0-0.005349-0.0131723202.681647-1057.05651
rev0.0267570.1899090.0139920.0975850.1427910.082880.153721
rmiddle0.00.0-0.0-0.0-0.324716-0.207068-0.456371
rot0.0092220.1856870.1084680.1970210.108146-0.0956430.077719
size1.1005321.3002221.0994631.1835531.223290.9385971.295707
namemedposposLasso-0.200000Lasso-0.100000Lasso-0.010000medreg
intercept_0.00.0NaNNaNNaN0.00.0
r20.752690.7905450.7669110.7825410.8075460.792450.815741
mae0.6468340.6768660.6803440.6764990.6298860.5679390.636887
model(StandardScaler(with_mean=False), QuantileLine...(StandardScaler(with_mean=False), LinearRegres...(StandardScaler(with_mean=False), Lasso(alpha=...(StandardScaler(with_mean=False), Lasso(alpha=...(StandardScaler(with_mean=False), Lasso(alpha=...(StandardScaler(with_mean=False), QuantileLine...(StandardScaler(with_mean=False), LinearRegres...
\n", "
"], "text/plain": [" 6 \\\n", "CST_ 0.829482 \n", "begin 0.0 \n", "dbegin 0.0 \n", "dend 0.0 \n", "dim 0.023846 \n", "discont 0.060636 \n", "edit 0.03823 \n", "end 0.0 \n", "end16 0.0 \n", "end32 0.0 \n", "ibegin16 0.0 \n", "ibegin2 0.0 \n", "ibegin32 0.0 \n", "ibegin4 0.0 \n", "ibegin64 0.0 \n", "ibegin8 0.0 \n", "iend16 0.0 \n", "iend2 0.0 \n", "iend32 0.0 \n", "iend4 0.0 \n", "iend64 0.0 \n", "iend8 0.0 \n", "middle 1.101543 \n", "rbegin 0.0 \n", "rdiscont 0.0 \n", "redit 0.0 \n", "rend 0.0 \n", "rend16 0.0 \n", "rend32 0.0 \n", "rev 0.026757 \n", "rmiddle 0.0 \n", "rot 0.009222 \n", "size 1.100532 \n", "name medpos \n", "intercept_ 0.0 \n", "r2 0.75269 \n", "mae 0.646834 \n", "model (StandardScaler(with_mean=False), QuantileLine... \n", "\n", " 5 \\\n", "CST_ 0.821048 \n", "begin 0.0 \n", "dbegin 0.0 \n", "dend 0.0 \n", "dim 0.0 \n", "discont 0.056297 \n", "edit 0.094856 \n", "end 0.0 \n", "end16 0.0 \n", "end32 0.0 \n", "ibegin16 0.0 \n", "ibegin2 0.0 \n", "ibegin32 0.0 \n", "ibegin4 0.0 \n", "ibegin64 0.0 \n", "ibegin8 0.0 \n", "iend16 0.0 \n", "iend2 0.0 \n", "iend32 0.0 \n", "iend4 0.0 \n", "iend64 0.0 \n", "iend8 0.0 \n", "middle 1.30347 \n", "rbegin 0.0 \n", "rdiscont 0.0 \n", "redit 0.0 \n", "rend 0.0 \n", "rend16 0.0 \n", "rend32 0.0 \n", "rev 0.189909 \n", "rmiddle 0.0 \n", "rot 0.185687 \n", "size 1.300222 \n", "name pos \n", "intercept_ 0.0 \n", "r2 0.790545 \n", "mae 0.676866 \n", "model (StandardScaler(with_mean=False), LinearRegres... 
\n", "\n", " 4 \\\n", "CST_ 0.0 \n", "begin -0.0 \n", "dbegin -0.0 \n", "dend -0.0 \n", "dim -0.014763 \n", "discont 0.0 \n", "edit 0.0 \n", "end -0.0 \n", "end16 -0.0 \n", "end32 -0.0 \n", "ibegin16 -0.0 \n", "ibegin2 -0.0 \n", "ibegin32 -0.0 \n", "ibegin4 -0.0 \n", "ibegin64 -0.0 \n", "ibegin8 -0.0 \n", "iend16 -0.0 \n", "iend2 0.0 \n", "iend32 -0.0 \n", "iend4 -0.0 \n", "iend64 -0.0 \n", "iend8 -0.0 \n", "middle 1.290699 \n", "rbegin -0.0 \n", "rdiscont 0.0 \n", "redit 0.0 \n", "rend -0.0 \n", "rend16 -0.0 \n", "rend32 -0.0 \n", "rev 0.013992 \n", "rmiddle -0.0 \n", "rot 0.108468 \n", "size 1.099463 \n", "name Lasso-0.200000 \n", "intercept_ NaN \n", "r2 0.766911 \n", "mae 0.680344 \n", "model (StandardScaler(with_mean=False), Lasso(alpha=... \n", "\n", " 3 \\\n", "CST_ 0.0 \n", "begin -0.0 \n", "dbegin -0.0 \n", "dend -0.0 \n", "dim -0.030446 \n", "discont 0.0 \n", "edit 0.0418 \n", "end -0.0 \n", "end16 -0.0 \n", "end32 -0.0 \n", "ibegin16 -0.0 \n", "ibegin2 -0.0 \n", "ibegin32 -0.0 \n", "ibegin4 -0.0 \n", "ibegin64 -0.0 \n", "ibegin8 -0.0 \n", "iend16 -0.0 \n", "iend2 -0.0 \n", "iend32 -0.0 \n", "iend4 -0.0 \n", "iend64 -0.0 \n", "iend8 -0.0 \n", "middle 1.325916 \n", "rbegin -0.020369 \n", "rdiscont -0.0 \n", "redit 0.0 \n", "rend -0.007655 \n", "rend16 -0.003393 \n", "rend32 -0.005349 \n", "rev 0.097585 \n", "rmiddle -0.0 \n", "rot 0.197021 \n", "size 1.183553 \n", "name Lasso-0.100000 \n", "intercept_ NaN \n", "r2 0.782541 \n", "mae 0.676499 \n", "model (StandardScaler(with_mean=False), Lasso(alpha=... 
\n", "\n", " 2 \\\n", "CST_ 0.0 \n", "begin -0.03443 \n", "dbegin -0.044705 \n", "dend -0.0 \n", "dim -0.120949 \n", "discont 0.210421 \n", "edit 0.396052 \n", "end -0.007053 \n", "end16 -0.000036 \n", "end32 -0.00004 \n", "ibegin16 0.066669 \n", "ibegin2 -0.02181 \n", "ibegin32 0.0 \n", "ibegin4 0.0 \n", "ibegin64 0.0 \n", "ibegin8 0.0 \n", "iend16 -0.022416 \n", "iend2 0.0 \n", "iend32 -0.0 \n", "iend4 0.151081 \n", "iend64 -0.0 \n", "iend8 -0.08907 \n", "middle 1.466733 \n", "rbegin -0.28295 \n", "rdiscont -0.066385 \n", "redit 0.0 \n", "rend -0.007593 \n", "rend16 -0.010514 \n", "rend32 -0.013172 \n", "rev 0.142791 \n", "rmiddle -0.324716 \n", "rot 0.108146 \n", "size 1.22329 \n", "name Lasso-0.010000 \n", "intercept_ NaN \n", "r2 0.807546 \n", "mae 0.629886 \n", "model (StandardScaler(with_mean=False), Lasso(alpha=... \n", "\n", " 1 \\\n", "CST_ 1433409.249051 \n", "begin 27.13405 \n", "dbegin 0.07931 \n", "dend 0.087576 \n", "dim 0.006919 \n", "discont 0.413378 \n", "edit 0.186032 \n", "end 4876069525.422424 \n", "end16 106134745.367844 \n", "end32 -4982003112.711292 \n", "ibegin16 0.129918 \n", "ibegin2 -0.069604 \n", "ibegin32 -0.221099 \n", "ibegin4 -0.045585 \n", "ibegin64 -0.1085 \n", "ibegin8 0.073031 \n", "iend16 -94492.918693 \n", "iend2 -38324.37475 \n", "iend32 -201401.795017 \n", "iend4 -51122.392443 \n", "iend64 0.15928 \n", "iend8 -94492.881923 \n", "middle 1.588707 \n", "rbegin 36.958438 \n", "rdiscont 0.375421 \n", "redit 0.071189 \n", "rend 4424.263222 \n", "rend16 -7664.018684 \n", "rend32 3202.681647 \n", "rev 0.08288 \n", "rmiddle -0.207068 \n", "rot -0.095643 \n", "size 0.938597 \n", "name med \n", "intercept_ 0.0 \n", "r2 0.79245 \n", "mae 0.567939 \n", "model (StandardScaler(with_mean=False), QuantileLine... 
\n", "\n", " 0 \n", "CST_ -307661768.128088 \n", "begin -29.417247 \n", "dbegin -0.185415 \n", "dend -0.09639 \n", "dim -0.10376 \n", "discont 0.52044 \n", "edit 0.358248 \n", "end -1046583604803.358887 \n", "end16 -22780416305.902706 \n", "end32 1069320839370.567505 \n", "ibegin16 -3.713466 \n", "ibegin2 0.014397 \n", "ibegin32 3.784367 \n", "ibegin4 -6.813416 \n", "ibegin64 -0.075761 \n", "ibegin8 6.927856 \n", "iend16 20281439.108194 \n", "iend2 8225773.255917 \n", "iend32 43228573.054944 \n", "iend4 10972737.091606 \n", "iend64 0.199631 \n", "iend8 20281426.580972 \n", "middle 1.541218 \n", "rbegin 49.406192 \n", "rdiscont 0.761464 \n", "redit 0.086227 \n", "rend 661.575013 \n", "rend16 345.917179 \n", "rend32 -1057.05651 \n", "rev 0.153721 \n", "rmiddle -0.456371 \n", "rot 0.077719 \n", "size 1.295707 \n", "name reg \n", "intercept_ 0.0 \n", "r2 0.815741 \n", "mae 0.636887 \n", "model (StandardScaler(with_mean=False), LinearRegres... "]}, "execution_count": 42, "metadata": {}, "output_type": "execute_result"}], "source": ["dfcoef = DataFrame(coefs)\n", "dfcoef[::-1].T"]}, {"cell_type": "code", "execution_count": 42, "id": "31ce6479", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["dfcoef[[\"name\", \"r2\", \"mae\"]].set_index('name').plot(kind=\"bar\", title=\"performance accross models\");"]}, {"cell_type": "code", "execution_count": 43, "id": "ab8eb8cd", "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "\n", "dfp = dfcoef.drop(['name', 'model'], axis=1).T.drop([0, 1], axis=1).copy()\n", "dfp.columns = dfcoef['name'][2:]\n", "ax = dfp.plot(figsize=(14, 4), kind=\"line\")\n", "ax.set_xticks(numpy.arange(0, dfp.shape[0]))\n", "ax.set_xticklabels(dfp.index)\n", "plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment='right');"]}, {"cell_type": "markdown", "id": "512edd9d", "metadata": {}, "source": ["## Investigation"]}, {"cell_type": "code", "execution_count": 44, "id": "d38e14a6", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execnumberpermshaperatiodimpredicterrabserrrel
280.0001130.0000290.0000610.00014150(1, 0, 2)(43, 44, 45)1.51571130.0001131.251063e-071.251063e-070.001111
550.0008930.0002120.0006460.00147750(2, 3, 1, 0)(3, 244, 244, 1)8.33992640.000893-2.410649e-072.410649e-07-0.000270
260.0000770.0000080.0000690.00010150(0, 2, 1)(43, 44, 45)1.03275930.0000774.172780e-074.172780e-070.005440
390.0001260.0000170.0001070.00018550(2, 0, 1, 3)(3, 244, 244, 1)1.18099640.000115-1.179187e-051.179187e-05-0.093246
660.0001950.0000190.0001630.00025850(0, 3, 1, 2)(1, 244, 244, 3)1.59877040.0002101.510728e-051.510728e-050.077417
500.0006920.0001360.0005290.00102150(0, 3, 2, 1)(3, 244, 244, 1)6.46075940.0007091.714180e-051.714180e-050.024778
760.0008240.0002880.0005150.00187950(1, 3, 2, 0)(1, 244, 244, 3)6.75266340.0008431.902846e-051.902846e-050.023087
540.0008180.0002780.0006250.00152250(2, 1, 3, 0)(3, 244, 244, 1)7.63464640.0008432.572773e-052.572773e-050.031471
10.0000480.0000030.0000450.00005850(0, 1, 3, 2)(12, 13, 15, 18)0.82082140.000000-4.837787e-054.837787e-05-1.000000
20.0000490.0000030.0000450.00006250(3, 0, 1, 2)(12, 13, 15, 18)0.82307040.000000-4.851040e-054.851040e-05-1.000000
1200.0059180.0010850.0045980.00831215(1, 0, 3, 2)(3, 244, 244, 6)1.59735740.0082592.341673e-032.341673e-030.395716
1280.0106770.0014570.0085040.01407015(2, 3, 1, 0)(3, 244, 244, 6)2.88219140.008132-2.545011e-032.545011e-03-0.238356
1210.0061060.0005560.0056190.00730515(2, 0, 3, 1)(3, 244, 244, 6)1.64832540.0087002.593662e-032.593662e-030.424746
1180.0037800.0008820.0027010.00540215(3, 0, 1, 2)(3, 244, 244, 6)1.02043240.0064882.707333e-032.707333e-030.716171
1150.0036530.0008740.0025670.00524415(1, 2, 3, 0)(3, 244, 244, 6)0.98607140.0064882.834624e-032.834624e-030.775972
1290.0124210.0030520.0078180.01810615(3, 2, 1, 0)(3, 244, 244, 6)3.35277040.009386-3.034652e-033.034652e-03-0.244323
1190.0049380.0003670.0045320.00584415(1, 3, 0, 2)(3, 244, 244, 6)1.33306140.0087003.761588e-033.761588e-030.761694
1270.0103390.0027890.0058780.01830115(0, 3, 2, 1)(3, 244, 244, 6)2.79082340.005271-5.068171e-035.068171e-03-0.490205
1300.0134320.0044960.0065360.02125015(3, 2, 0, 1)(3, 244, 244, 6)3.62568040.008132-5.299336e-035.299336e-03-0.394540
1310.0145790.0040260.0071440.02073915(3, 0, 2, 1)(3, 244, 244, 6)3.93548340.008259-6.320138e-036.320138e-03-0.433499
\n", "
"], "text/plain": [" average deviation min_exec max_exec number perm \\\n", "28 0.000113 0.000029 0.000061 0.000141 50 (1, 0, 2) \n", "55 0.000893 0.000212 0.000646 0.001477 50 (2, 3, 1, 0) \n", "26 0.000077 0.000008 0.000069 0.000101 50 (0, 2, 1) \n", "39 0.000126 0.000017 0.000107 0.000185 50 (2, 0, 1, 3) \n", "66 0.000195 0.000019 0.000163 0.000258 50 (0, 3, 1, 2) \n", "50 0.000692 0.000136 0.000529 0.001021 50 (0, 3, 2, 1) \n", "76 0.000824 0.000288 0.000515 0.001879 50 (1, 3, 2, 0) \n", "54 0.000818 0.000278 0.000625 0.001522 50 (2, 1, 3, 0) \n", "1 0.000048 0.000003 0.000045 0.000058 50 (0, 1, 3, 2) \n", "2 0.000049 0.000003 0.000045 0.000062 50 (3, 0, 1, 2) \n", "120 0.005918 0.001085 0.004598 0.008312 15 (1, 0, 3, 2) \n", "128 0.010677 0.001457 0.008504 0.014070 15 (2, 3, 1, 0) \n", "121 0.006106 0.000556 0.005619 0.007305 15 (2, 0, 3, 1) \n", "118 0.003780 0.000882 0.002701 0.005402 15 (3, 0, 1, 2) \n", "115 0.003653 0.000874 0.002567 0.005244 15 (1, 2, 3, 0) \n", "129 0.012421 0.003052 0.007818 0.018106 15 (3, 2, 1, 0) \n", "119 0.004938 0.000367 0.004532 0.005844 15 (1, 3, 0, 2) \n", "127 0.010339 0.002789 0.005878 0.018301 15 (0, 3, 2, 1) \n", "130 0.013432 0.004496 0.006536 0.021250 15 (3, 2, 0, 1) \n", "131 0.014579 0.004026 0.007144 0.020739 15 (3, 0, 2, 1) \n", "\n", " shape ratio dim predict err abserr \\\n", "28 (43, 44, 45) 1.515711 3 0.000113 1.251063e-07 1.251063e-07 \n", "55 (3, 244, 244, 1) 8.339926 4 0.000893 -2.410649e-07 2.410649e-07 \n", "26 (43, 44, 45) 1.032759 3 0.000077 4.172780e-07 4.172780e-07 \n", "39 (3, 244, 244, 1) 1.180996 4 0.000115 -1.179187e-05 1.179187e-05 \n", "66 (1, 244, 244, 3) 1.598770 4 0.000210 1.510728e-05 1.510728e-05 \n", "50 (3, 244, 244, 1) 6.460759 4 0.000709 1.714180e-05 1.714180e-05 \n", "76 (1, 244, 244, 3) 6.752663 4 0.000843 1.902846e-05 1.902846e-05 \n", "54 (3, 244, 244, 1) 7.634646 4 0.000843 2.572773e-05 2.572773e-05 \n", "1 (12, 13, 15, 18) 0.820821 4 0.000000 -4.837787e-05 4.837787e-05 \n", "2 (12, 
13, 15, 18) 0.823070 4 0.000000 -4.851040e-05 4.851040e-05 \n", "120 (3, 244, 244, 6) 1.597357 4 0.008259 2.341673e-03 2.341673e-03 \n", "128 (3, 244, 244, 6) 2.882191 4 0.008132 -2.545011e-03 2.545011e-03 \n", "121 (3, 244, 244, 6) 1.648325 4 0.008700 2.593662e-03 2.593662e-03 \n", "118 (3, 244, 244, 6) 1.020432 4 0.006488 2.707333e-03 2.707333e-03 \n", "115 (3, 244, 244, 6) 0.986071 4 0.006488 2.834624e-03 2.834624e-03 \n", "129 (3, 244, 244, 6) 3.352770 4 0.009386 -3.034652e-03 3.034652e-03 \n", "119 (3, 244, 244, 6) 1.333061 4 0.008700 3.761588e-03 3.761588e-03 \n", "127 (3, 244, 244, 6) 2.790823 4 0.005271 -5.068171e-03 5.068171e-03 \n", "130 (3, 244, 244, 6) 3.625680 4 0.008132 -5.299336e-03 5.299336e-03 \n", "131 (3, 244, 244, 6) 3.935483 4 0.008259 -6.320138e-03 6.320138e-03 \n", "\n", " rel \n", "28 0.001111 \n", "55 -0.000270 \n", "26 0.005440 \n", "39 -0.093246 \n", "66 0.077417 \n", "50 0.024778 \n", "76 0.023087 \n", "54 0.031471 \n", "1 -1.000000 \n", "2 -1.000000 \n", "120 0.395716 \n", "128 -0.238356 \n", "121 0.424746 \n", "118 0.716171 \n", "115 0.775972 \n", "129 -0.244323 \n", "119 0.761694 \n", "127 -0.490205 \n", "130 -0.394540 \n", "131 -0.433499 "]}, "execution_count": 45, "metadata": {}, "output_type": "execute_result"}], "source": ["data_err = data.drop([\"context_size\", \"repeat\"], axis=1).copy()\n", "data_err['predict'] = numpy.maximum(coefs[0]['model'].predict(X), 0) / 1000\n", "data_err['err'] = (data_err['predict'] - data_err['average'])\n", "data_err['abserr'] = numpy.abs(data_err['predict'] - data_err['average'])\n", "data_err['rel'] = (data_err['predict'] - data_err['average']) / data_err['average']\n", "s = data_err.sort_values('abserr')\n", "pandas.concat([s.head(n=10), s.tail(n=10)])"]}, {"cell_type": "markdown", "id": "0e6ccfdb", "metadata": {}, "source": ["The three biggest errors are negative: for these runs the model predicts a much lower value (5-8 ms) than the measured average (10-15 ms). The bias is not systematic though, five of the ten biggest errors are positive. \n", "
These big negative errors may be outliers: their deviation is high and their min_exec is close to the prediction, so the processor may have been busy doing something else during some of the runs."]}, {"cell_type": "code", "execution_count": 45, "id": "2c384543", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
averagedeviationmin_execmax_execnumberpermshaperatiodimpredicterrabserrrel
200.0001580.0000210.0001270.00019250(3, 0, 2, 1)(12, 13, 15, 18)2.68487640.000000-0.0001580.000158-1.000000
420.0001470.0000170.0001060.00017550(0, 1, 3, 2)(3, 244, 244, 1)1.36997840.000000-0.0001470.000147-1.000000
340.0001510.0000160.0001360.00019750(1, 2, 0)(3, 244, 244)1.43844630.000000-0.0001510.000151-1.000000
330.0001240.0000170.0001080.00017150(2, 0, 1)(3, 244, 244)1.18566630.000000-0.0001240.000124-1.000000
440.0001890.0000440.0001420.00026550(1, 2, 3, 0)(3, 244, 244, 1)1.76690540.000000-0.0001890.000189-1.000000
270.0000970.0000040.0000830.00011050(2, 0, 1)(43, 44, 45)1.30091530.000000-0.0000970.000097-1.000000
250.0000740.0000090.0000650.00010950(0, 1, 2)(43, 44, 45)1.00000030.000000-0.0000740.000074-1.000000
240.0000730.0000090.0000620.00009450(1, 2, 0)(43, 44, 45)0.98551330.000000-0.0000730.000073-1.000000
220.0002140.0000600.0001360.00029550(1, 0, 3, 2)(12, 13, 15, 18)3.62724040.000000-0.0002140.000214-1.000000
210.0001640.0000450.0001240.00023150(3, 1, 2, 0)(12, 13, 15, 18)2.77819340.000000-0.0001640.000164-1.000000
1280.0106770.0014570.0085040.01407015(2, 3, 1, 0)(3, 244, 244, 6)2.88219140.008132-0.0025450.002545-0.238356
1300.0134320.0044960.0065360.02125015(3, 2, 0, 1)(3, 244, 244, 6)3.62568040.008132-0.0052990.005299-0.394540
1220.0067220.0018070.0050670.01124515(1, 3, 2, 0)(3, 244, 244, 6)1.81455240.0082590.0015370.0015370.228654
1250.0078150.0017570.0059320.01077915(2, 1, 3, 0)(3, 244, 244, 6)2.10948940.0082590.0004440.0004440.056871
1200.0059180.0010850.0045980.00831215(1, 0, 3, 2)(3, 244, 244, 6)1.59735740.0082590.0023420.0023420.395716
1230.0070710.0009820.0054540.00855915(3, 1, 0, 2)(3, 244, 244, 6)1.90866740.0082590.0011880.0011880.168070
1310.0145790.0040260.0071440.02073915(3, 0, 2, 1)(3, 244, 244, 6)3.93548340.008259-0.0063200.006320-0.433499
1210.0061060.0005560.0056190.00730515(2, 0, 3, 1)(3, 244, 244, 6)1.64832540.0087000.0025940.0025940.424746
1190.0049380.0003670.0045320.00584415(1, 3, 0, 2)(3, 244, 244, 6)1.33306140.0087000.0037620.0037620.761694
1290.0124210.0030520.0078180.01810615(3, 2, 1, 0)(3, 244, 244, 6)3.35277040.009386-0.0030350.003035-0.244323
\n", "
"], "text/plain": [" average deviation min_exec max_exec number perm \\\n", "20 0.000158 0.000021 0.000127 0.000192 50 (3, 0, 2, 1) \n", "42 0.000147 0.000017 0.000106 0.000175 50 (0, 1, 3, 2) \n", "34 0.000151 0.000016 0.000136 0.000197 50 (1, 2, 0) \n", "33 0.000124 0.000017 0.000108 0.000171 50 (2, 0, 1) \n", "44 0.000189 0.000044 0.000142 0.000265 50 (1, 2, 3, 0) \n", "27 0.000097 0.000004 0.000083 0.000110 50 (2, 0, 1) \n", "25 0.000074 0.000009 0.000065 0.000109 50 (0, 1, 2) \n", "24 0.000073 0.000009 0.000062 0.000094 50 (1, 2, 0) \n", "22 0.000214 0.000060 0.000136 0.000295 50 (1, 0, 3, 2) \n", "21 0.000164 0.000045 0.000124 0.000231 50 (3, 1, 2, 0) \n", "128 0.010677 0.001457 0.008504 0.014070 15 (2, 3, 1, 0) \n", "130 0.013432 0.004496 0.006536 0.021250 15 (3, 2, 0, 1) \n", "122 0.006722 0.001807 0.005067 0.011245 15 (1, 3, 2, 0) \n", "125 0.007815 0.001757 0.005932 0.010779 15 (2, 1, 3, 0) \n", "120 0.005918 0.001085 0.004598 0.008312 15 (1, 0, 3, 2) \n", "123 0.007071 0.000982 0.005454 0.008559 15 (3, 1, 0, 2) \n", "131 0.014579 0.004026 0.007144 0.020739 15 (3, 0, 2, 1) \n", "121 0.006106 0.000556 0.005619 0.007305 15 (2, 0, 3, 1) \n", "119 0.004938 0.000367 0.004532 0.005844 15 (1, 3, 0, 2) \n", "129 0.012421 0.003052 0.007818 0.018106 15 (3, 2, 1, 0) \n", "\n", " shape ratio dim predict err abserr rel \n", "20 (12, 13, 15, 18) 2.684876 4 0.000000 -0.000158 0.000158 -1.000000 \n", "42 (3, 244, 244, 1) 1.369978 4 0.000000 -0.000147 0.000147 -1.000000 \n", "34 (3, 244, 244) 1.438446 3 0.000000 -0.000151 0.000151 -1.000000 \n", "33 (3, 244, 244) 1.185666 3 0.000000 -0.000124 0.000124 -1.000000 \n", "44 (3, 244, 244, 1) 1.766905 4 0.000000 -0.000189 0.000189 -1.000000 \n", "27 (43, 44, 45) 1.300915 3 0.000000 -0.000097 0.000097 -1.000000 \n", "25 (43, 44, 45) 1.000000 3 0.000000 -0.000074 0.000074 -1.000000 \n", "24 (43, 44, 45) 0.985513 3 0.000000 -0.000073 0.000073 -1.000000 \n", "22 (12, 13, 15, 18) 3.627240 4 0.000000 -0.000214 0.000214 -1.000000 \n", 
"21 (12, 13, 15, 18) 2.778193 4 0.000000 -0.000164 0.000164 -1.000000 \n", "128 (3, 244, 244, 6) 2.882191 4 0.008132 -0.002545 0.002545 -0.238356 \n", "130 (3, 244, 244, 6) 3.625680 4 0.008132 -0.005299 0.005299 -0.394540 \n", "122 (3, 244, 244, 6) 1.814552 4 0.008259 0.001537 0.001537 0.228654 \n", "125 (3, 244, 244, 6) 2.109489 4 0.008259 0.000444 0.000444 0.056871 \n", "120 (3, 244, 244, 6) 1.597357 4 0.008259 0.002342 0.002342 0.395716 \n", "123 (3, 244, 244, 6) 1.908667 4 0.008259 0.001188 0.001188 0.168070 \n", "131 (3, 244, 244, 6) 3.935483 4 0.008259 -0.006320 0.006320 -0.433499 \n", "121 (3, 244, 244, 6) 1.648325 4 0.008700 0.002594 0.002594 0.424746 \n", "119 (3, 244, 244, 6) 1.333061 4 0.008700 0.003762 0.003762 0.761694 \n", "129 (3, 244, 244, 6) 3.352770 4 0.009386 -0.003035 0.003035 -0.244323 "]}, "execution_count": 46, "metadata": {}, "output_type": "execute_result"}], "source": ["s = data_err.sort_values('predict')\n", "pandas.concat([s.head(n=10), s.tail(n=10)])"]}, {"cell_type": "markdown", "id": "9ea9188b", "metadata": {}, "source": ["### Correlation between predictors"]}, {"cell_type": "code", "execution_count": 46, "id": "46c6e86e", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
regmedLasso-0.010000Lasso-0.100000Lasso-0.200000posmedposyt
00.2987890.0524360.0000000.0000000.0000000.0000000.0000000.044222
10.0000000.0715750.0000000.0000000.0000000.0000000.0000000.048378
20.0000000.0483930.0000000.0000000.0000000.0000000.0000000.048510
30.0000000.0483930.0000000.0000000.0000000.0000000.0000000.048954
40.2480890.0507810.0000000.0000000.0000000.0000000.0000000.050805
...........................
1275.2707004.1770124.9171054.6154904.4644294.8370324.25138110.338870
1288.1323427.3547998.1071917.6469667.3348617.8583636.70654810.677354
1299.3860058.1861908.9912568.0824317.3973008.7710406.89620412.420657
1308.1323427.3547998.1071917.6469667.3348617.8583636.70654813.431679
1318.2592367.5610047.9621607.6056057.3348617.8297286.73897214.579374
\n", "

132 rows \u00d7 8 columns

\n", "
"], "text/plain": [" reg med Lasso-0.010000 Lasso-0.100000 Lasso-0.200000 \\\n", "0 0.298789 0.052436 0.000000 0.000000 0.000000 \n", "1 0.000000 0.071575 0.000000 0.000000 0.000000 \n", "2 0.000000 0.048393 0.000000 0.000000 0.000000 \n", "3 0.000000 0.048393 0.000000 0.000000 0.000000 \n", "4 0.248089 0.050781 0.000000 0.000000 0.000000 \n", ".. ... ... ... ... ... \n", "127 5.270700 4.177012 4.917105 4.615490 4.464429 \n", "128 8.132342 7.354799 8.107191 7.646966 7.334861 \n", "129 9.386005 8.186190 8.991256 8.082431 7.397300 \n", "130 8.132342 7.354799 8.107191 7.646966 7.334861 \n", "131 8.259236 7.561004 7.962160 7.605605 7.334861 \n", "\n", " pos medpos yt \n", "0 0.000000 0.000000 0.044222 \n", "1 0.000000 0.000000 0.048378 \n", "2 0.000000 0.000000 0.048510 \n", "3 0.000000 0.000000 0.048954 \n", "4 0.000000 0.000000 0.050805 \n", ".. ... ... ... \n", "127 4.837032 4.251381 10.338870 \n", "128 7.858363 6.706548 10.677354 \n", "129 8.771040 6.896204 12.420657 \n", "130 7.858363 6.706548 13.431679 \n", "131 7.829728 6.738972 14.579374 \n", "\n", "[132 rows x 8 columns]"]}, "execution_count": 47, "metadata": {}, "output_type": "execute_result"}], "source": ["cc = DataFrame(dict([(c['name'], numpy.maximum(c['model'].predict(X), 0)) for c in coefs]))\n", "cc['yt'] = yt\n", "cc"]}, {"cell_type": "code", "execution_count": 47, "id": "f751d4d4", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
regmedLasso-0.010000Lasso-0.100000Lasso-0.200000posmedposyt
reg1.0000000.9941240.9969220.9857150.9798260.9883230.9804330.903528
med0.9941241.0000000.9958630.9899900.9873740.9903410.9884010.894833
Lasso-0.0100000.9969220.9958631.0000000.9926890.9879300.9944200.9883580.899384
Lasso-0.1000000.9857150.9899900.9926891.0000000.9985640.9987560.9979850.886902
Lasso-0.2000000.9798260.9873740.9879300.9985641.0000000.9950920.9993850.880614
pos0.9883230.9903410.9944200.9987560.9950921.0000000.9951690.890093
medpos0.9804330.9884010.9883580.9979850.9993850.9951691.0000000.881208
yt0.9035280.8948330.8993840.8869020.8806140.8900930.8812081.000000
\n", "
"], "text/plain": [" reg med Lasso-0.010000 Lasso-0.100000 \\\n", "reg 1.000000 0.994124 0.996922 0.985715 \n", "med 0.994124 1.000000 0.995863 0.989990 \n", "Lasso-0.010000 0.996922 0.995863 1.000000 0.992689 \n", "Lasso-0.100000 0.985715 0.989990 0.992689 1.000000 \n", "Lasso-0.200000 0.979826 0.987374 0.987930 0.998564 \n", "pos 0.988323 0.990341 0.994420 0.998756 \n", "medpos 0.980433 0.988401 0.988358 0.997985 \n", "yt 0.903528 0.894833 0.899384 0.886902 \n", "\n", " Lasso-0.200000 pos medpos yt \n", "reg 0.979826 0.988323 0.980433 0.903528 \n", "med 0.987374 0.990341 0.988401 0.894833 \n", "Lasso-0.010000 0.987930 0.994420 0.988358 0.899384 \n", "Lasso-0.100000 0.998564 0.998756 0.997985 0.886902 \n", "Lasso-0.200000 1.000000 0.995092 0.999385 0.880614 \n", "pos 0.995092 1.000000 0.995169 0.890093 \n", "medpos 0.999385 0.995169 1.000000 0.881208 \n", "yt 0.880614 0.890093 0.881208 1.000000 "]}, "execution_count": 48, "metadata": {}, "output_type": "execute_result"}], "source": ["cc.corr()"]}, {"cell_type": "markdown", "id": "ab74c8b9", "metadata": {}, "source": ["## Standalone predictions"]}, {"cell_type": "code", "execution_count": 48, "id": "4ff3905a", "metadata": {}, "outputs": [{"data": {"text/plain": ["{'CST_': 0.829481835464256,\n", " 'begin': 0.0,\n", " 'dbegin': 0.0,\n", " 'dend': 0.0,\n", " 'dim': 0.08294721851224843,\n", " 'discont': 0.07025394222472751,\n", " 'edit': 0.03782977428195987,\n", " 'end': 0.0,\n", " 'end16': 0.0,\n", " 'end32': 0.0,\n", " 'ibegin16': 0.0,\n", " 'ibegin2': 0.0,\n", " 'ibegin32': 0.0,\n", " 'ibegin4': 0.0,\n", " 'ibegin64': 0.0,\n", " 'ibegin8': 0.0,\n", " 'iend16': 0.0,\n", " 'iend2': 0.0,\n", " 'iend32': 0.0,\n", " 'iend4': 0.0,\n", " 'iend64': 0.0,\n", " 'iend8': 0.0,\n", " 'middle': 3.42896339670081e-06,\n", " 'rbegin': 0.0,\n", " 'rdiscont': 0.0,\n", " 'redit': 0.0,\n", " 'rend': 0.0,\n", " 'rend16': 0.0,\n", " 'rend32': 0.0,\n", " 'rev': 0.11940214295823245,\n", " 'rmiddle': 0.0,\n", " 'rot': 
0.023189032947793925,\n", " 'size': 3.021302183272755e-06}"]}, "execution_count": 49, "metadata": {}, "output_type": "execute_result"}], "source": ["def get_coef(pipe, names):\n", " c1 = pipe.steps[0][-1].scale_\n", " c2 = pipe.steps[1][-1].coef_\n", " return dict(zip(names, c2 / c1))\n", "\n", "\n", "get_coef(coefs[-1][\"model\"], X.columns)"]}, {"cell_type": "code", "execution_count": 49, "id": "b52b1aab", "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.005450704959759156, array([0.0054507]))"]}, "execution_count": 50, "metadata": {}, "output_type": "execute_result"}], "source": ["def predict(coefs, shape, perm):\n", " feat = compute_features(shape, perm)\n", " res = 0\n", " for k, v in feat.items():\n", " res += v * coefs[k]\n", " return res / 1000\n", "\n", "\n", "def predict_model(model, shape, perm, names):\n", " feat = compute_features(shape, perm)\n", " a = numpy.zeros((1, len(names)), dtype=numpy.float64)\n", " for i, n in enumerate(names):\n", " a[0, i] = feat[n]\n", " return model.predict(a) / 1000\n", " \n", "\n", "coef = get_coef(coefs[-1][\"model\"], X.columns)\n", "(predict(coef, (3, 224, 224, 6), (3, 0, 1, 2)), \n", " predict_model(coefs[-1][\"model\"], (3, 224, 224, 6), (3, 0, 1, 2), X.columns))"]}, {"cell_type": "code", "execution_count": 50, "id": "7f770721", "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5"}}, "nbformat": 4, "nbformat_minor": 5}