#######################################################################
# Copyright (c) 2019-present, Blosc Development Team <blosc@blosc.org>
# All rights reserved.
#
# This source code is licensed under a BSD-style license (found in the
# LICENSE file in the root directory of this source tree)
#######################################################################
# It is important that NumPy is linked against MKL, as this can speed up
# blosc2.matmul (which uses np.matmul as its backend) by about 2x:
#   conda install numpy mkl

import numpy as np
import blosc2
import time
import matplotlib.pyplot as plt
import torch
import pickle


# Figure appearance: no LaTeX rendering, Computer Modern serif, high-DPI output.
plt.rcParams.update({'text.usetex': False, 'font.serif': ['cm'], 'font.size': 16})
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
plt.style.use('seaborn-v0_8-paper')

ndim = 3
filename = f"matmul{ndim}D_bench"

# Side lengths N such that one N**3 float32 operand holds about k * 2**28
# elements, i.e. roughly k GiB, for k in (1, 2, 4, 8, 12, 16, 20).
shapes = np.array([1, 2, 4, 8, 12, 16, 20])**(1/3) * 2**(28/3)

# Number of timed repetitions per measurement; the stored value is their mean.
nreps = 1

# False: run the benchmark and write one pickle per matmul implementation.
# True: read those pickles back and produce the plot.
plotmode = True


def _is_benchmarked(xp_name, lib_name):
    """Return True if `xp_name`.matmul is measured on operands built by `lib_name`.

    blosc2.matmul is measured on operands from every library, numpy.matmul on
    numpy and torch operands, and torch.matmul on torch operands only.
    """
    if xp_name == 'blosc2':
        return True
    if xp_name == 'numpy':
        return lib_name != 'blosc2'
    return xp_name == 'torch' and lib_name == 'torch'


if not plotmode:
    for xp in [blosc2, np, torch]:
        sizes = []
        mean_times = {'blosc2': [], 'torch': [], 'numpy': []}
        for n in shapes:
            N = int(n)
            shape_a = (N,) * ndim
            shape_b = (N,) * ndim
            size_gb = (N ** ndim * 4) / (2 ** 30)

            for lib in [blosc2, torch, np]:
                # Decide before allocating: the operands are several GB each,
                # so building them for a combination that is never timed is
                # pure waste. Skipped combinations record no measurement.
                if not _is_benchmarked(xp.__name__, lib.__name__):
                    continue

                # Generate matrices
                matrix_a = lib.full(shape_a, fill_value=3., dtype=lib.float32)
                matrix_b = lib.full(shape_b, fill_value=2.4, dtype=lib.float32)
                matrix_c = lib.full(shape_b[:1], fill_value=.4, dtype=lib.float32)

                # Multiplication, timed `nreps` times and averaged
                elapsed = 0.0
                for _ in range(nreps):
                    t0 = time.perf_counter()
                    if xp.__name__ == 'blosc2':
                        # The sum with matrix_c is only evaluated by .compute()
                        (xp.matmul(matrix_a, matrix_b) + matrix_c).compute()
                    else:
                        xp.matmul(matrix_a, matrix_b) + matrix_c
                    elapsed += time.perf_counter() - t0
                _time = elapsed / nreps

                mean_times[lib.__name__].append(_time)
                print(f"Size = {np.round(size_gb, 1)} GB, {xp.__name__.upper()}_{lib.__name__} Performance = {_time:.2f} s")

                # Release the operands before the next library allocates its
                # own, so two sets never coexist in memory.
                del matrix_a, matrix_b, matrix_c

            # Working set: the two operands plus the result
            sizes.append(size_gb * 3)

        with open(f"{filename}_{xp.__name__.upper()}.pkl", 'wb') as f:
            pickle.dump(
                {
                    name: {
                        "Matrix Size (GB)": sizes,
                        "Mean Time (s)": mean_times[name],
                    }
                    for name in ('blosc2', 'numpy', 'torch')
                },
                f,
            )

else:
    plt.figure()
    for mkr, xp in zip(('X', 'd', 's'), [blosc2, torch, np]):
        # NOTE: pickle.load can execute arbitrary code; only load result files
        # that were written by the benchmark branch of this script.
        with open(f"{filename}_{xp.__name__.upper()}.pkl", 'rb') as f:
            res_dict = pickle.load(f)

        # Create plots for Numpy vs Blosc vs Torch: one curve per operand
        # library that was measured with this matmul implementation.
        for lib_name, color in (('torch', 'r'), ('numpy', 'g'), ('blosc2', 'b')):
            if not _is_benchmarked(xp.__name__, lib_name):
                continue
            series = res_dict[lib_name]
            x = np.round(series["Matrix Size (GB)"], 1)
            plt.plot(x, series["Mean Time (s)"], color=color,
                     label=f'{xp.__name__.upper()}_{lib_name}', marker=mkr)

    plt.xlabel('Working set size (GB)')
    plt.legend()
    plt.ylabel("Time (s)")
    plt.title(f'matmul(A, B) + c, ndim = {ndim}')
    plt.gca().set_yscale('log')
    plt.savefig(f'{filename}.png', format="png")
0 commit comments