Installation
Binary Distributions
You can install the latest stable release of Triton from pip:
pip install triton
Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
And the latest nightly release:
pip install -U --pre triton
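Either way, a quick sanity check is to import the installed package (a minimal sketch; it assumes the wheel exposes triton.__version__, and a bare import succeeding is already a good sign):

import triton
print(triton.__version__)   # prints the installed Triton version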
From Source
Python Package
You can install the Python package from source by running the following commands:
git clone https://github.com/openai/triton.git
cd triton/python
pip install cmake   # build-time dependency
pip install -e .
Note that if llvm-11 is not found on your system, the setup.py script will automatically download the LLVM 11 static libraries and link against them.
You can then test your installation by running the unit tests:
pytest -vs .
and the benchmarks:
cd bench/
python -m run --with-plots --result-dir /tmp/triton-bench
Tutorials
Below are a few examples covering basic Triton operations. It is recommended that you read them in order and then experiment on your own.
Vector Addition
In this tutorial, you will use Triton to write a simple vector addition. In doing so, you will learn about:
- The basic programming model of Triton.
- The triton.jit decorator, which is used to define Triton kernels.
- Best practices for validating and benchmarking custom ops against native reference implementations.
Compute Kernel
import torch

import triton
import triton.language as tl


@triton.jit
def _add(
    X,  # *Pointer* to first input vector
    Y,  # *Pointer* to second input vector
    Z,  # *Pointer* to output vector
    N,  # Size of the vector
    **meta  # Optional meta-parameters for the kernel
):
    pid = tl.program_id(0)
    # Create an offset for the blocks of pointers to be
    # processed by this program instance
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against
    # out-of-bounds accesses
    mask = offsets < N
    # Load x and y
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    # Write back x + y
    z = x + y
    tl.store(Z + offsets, z)
Let us also declare a helper function that (1) allocates the output tensor z and (2) enqueues the above kernel with appropriate grid/block sizes:
def add(x, y):
    z = torch.empty_like(x)
    N = z.shape[0]
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
    # NOTE:
    #  - each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
    #  - don't forget to pass meta-parameters as keyword arguments
    _add[grid](x, y, z, N, BLOCK=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return z
We can use the above function to compute the element-wise sum of two torch.tensor objects and check the deviation from PyTorch:
torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(za - zb))}')
Out:
tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0
Seems like we're good to go!
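If you prefer a simple pass/fail check over inspecting the maximum difference, torch.allclose can be used instead (a one-line sketch):

print(torch.allclose(za, zb))   # expected to print True for this element-wise sum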
Benchmark
We can now benchmark our custom vector-addition op on inputs of increasing size to get a sense of how it does relative to PyTorch. To make things easier, Triton has built-in utilities that allow us to plot the performance of our op for different sizes:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # argument names to use as an x-axis for the plot
        x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
        x_log=True,  # x axis is logarithmic
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=["Triton", "Torch"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={}  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
    gbps = lambda ms: 12 * size / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)
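A note on the gbps lambda above: each output element involves reading one element of x, one of y, and writing one of z, i.e. three float32 accesses of 4 bytes each, hence 12 bytes of traffic per element. The following standalone sketch reproduces the same arithmetic with an illustrative timing value:

size = 2**27                  # 134217728 elements, the largest size in the sweep above
ms = 1.9                      # hypothetical measured time in milliseconds
bytes_moved = 3 * 4 * size    # read x, read y, write z; 4 bytes per float32 element
gbps = bytes_moved * 1e-9 / (ms * 1e-3)
print(f'{gbps:.1f} GB/s')     # ~848 GB/s, matching 12 * size / ms * 1e-6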
We can now run the decorated function above. Pass print_data=True to see the performance numbers, show_plots=True to plot them, and/or save_path='/path/to/results/' to save them along with the raw CSV data.
benchmark.run(print_data=True, show_plots=True)

Out:
vector-add-performance:
           size      Triton       Torch
0        4096.0    9.600000    9.600000
1        8192.0   19.200000   19.200000
2       16384.0   38.400001   38.400001
3       32768.0   76.800002   76.800002
4       65536.0  127.999995  127.999995
5      131072.0  219.428568  219.428568
6      262144.0  341.333321  384.000001
7      524288.0  472.615390  472.615390
8     1048576.0  614.400016  614.400016
9     2097152.0  722.823517  722.823517
10    4194304.0  780.190482  780.190482
11    8388608.0  812.429770  812.429770
12   16777216.0  833.084721  833.084721
13   33554432.0  843.811163  843.811163
14   67108864.0  849.278610  848.362445
15  134217728.0  851.577704  850.656574
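To also keep the raw CSV data mentioned above, the same run can be pointed at a results directory (a usage sketch; the path is arbitrary):

benchmark.run(print_data=True, show_plots=False, save_path='/tmp/triton-vector-add')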
Total running time of the script: ( 0 minutes 11.032 seconds)
Fused Softmax
In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM. You will learn about:
- The benefits of kernel fusion for bandwidth-bound operations.
- Reduction operators in Triton.
Motivations
Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:
import torch


# Compute the row-wise softmax of x
@torch.jit.script
def naive_softmax(x):
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read 2MN elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read 2MN elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 7MN elements ; wrote 3MN + 2M elements
    return ret
When implemented naively in PyTorch, computing y = naive_softmax(x) requires reading 7MN elements from DRAM and writing back 3MN + 2M elements (see the comments in the code above). This is obviously wasteful; we would prefer a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Such a kernel only needs to read and write roughly 2MN elements, so we could expect a theoretical speed-up of about 5x. The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it is still far from ideal.
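A rough back-of-the-envelope version of that estimate, using the traffic counts from the comments in naive_softmax (the shape is only illustrative):

M, N = 4096, 1024                          # hypothetical matrix shape
naive = (7 * M * N) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused = 2 * M * N                          # an ideal fused kernel reads X once and writes Y once
print(naive / fused)                       # ~5.0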
Compute Kernel
Our softmax kernel works as follows:
- each program loads a row of the input matrix X,
- normalizes it and writes back the result to the output Y.
Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:
import triton
import triton.language as tl


@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
    # row index
    m = tl.program_id(0)
    # col indices
    # here BLOCK is the smallest power of two greater than `N`
    n = tl.arange(0, meta['BLOCK'])
    # the memory address of all the elements
    # that we want to load can be computed as follows
    X = X + m * stride_xm + n
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # Subtract maximum for numerical stability
    z = x - tl.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    tl.store(Y, y, mask=n < N)
Let us create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor:
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def softmax(x):
    M, N = x.shape
    # The block size is the smallest power of two greater than the number of columns in `x`
    BLOCK = next_power_of_2(N)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK >= 2048:
        num_warps = 8
    if BLOCK >= 4096:
        num_warps = 16
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
    return y
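As a quick sanity check of the bit-twiddling helper above (a standalone sketch; the values are only illustrative):

assert next_power_of_2(1) == 1
assert next_power_of_2(781) == 1024    # the column count used in the unit test below
assert next_power_of_2(1024) == 1024   # powers of two are left unchanged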
Unit Test
We make a point of testing the kernel on a matrix with an irregular number of rows and columns, to verify that the padding mechanism above works correctly.
torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, axis=1)
print(torch.allclose(y_tri, y_ref))
Out:
True
As expected, the results are identical.
Benchmark
We will now benchmark our op as a function of the number of columns in the input matrix, assuming 4096 rows, and compare its performance against (1) torch.softmax and (2) the naive_softmax defined above.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch-native', 'torch-jit'],  # possible values for `line_arg`
        line_names=["Triton", "Torch (native)", "Torch (jit)"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    if provider == 'torch-native':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'torch-jit':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=True, print_data=True)

Out:
softmax-performance:
          N      Triton  Torch (native)  Torch (jit)
0     256.0  512.000001      546.133347   273.066674
1     384.0  585.142862      585.142862   267.130429
2     512.0  630.153853      606.814814   264.258068
3     640.0  682.666684      640.000002   269.473696
4     768.0  702.171410      664.216187   273.066663
..      ...         ...             ...          ...
93  12160.0  812.359066      406.179533   329.483481
94  12288.0  812.429770      415.661740   329.602681
95  12416.0  810.840807      412.149375   329.173158
96  12544.0  810.925276      412.546756   329.292871
97  12672.0  811.007961      412.097543   329.410251

[98 rows x 4 columns]
In the plot above, we can see that:
- Triton is 2-3x faster than the Torch JIT.
- Triton is also noticeably faster than torch.softmax. Our guess is that the PyTorch kernel only partially fuses the computation of the softmax, which means that, when the temporary data is too large to fit entirely in the GPU's cache, it transfers almost twice the amount of memory necessary. Note that the Triton kernel is not only faster than PyTorch's CUDA kernel, it is also easier to read, understand and maintain.
Total running time of the script: ( 1 minutes 8.174 seconds)
Matrix Multiplication
In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS. You will specifically learn about:
- Block-level matrix multiplications
- Multi-dimensional pointer arithmetic
- Program re-ordering for improved L2 cache hit rate
- Automatic performance tuning
Motivations
Matrix multiplication is a key building block of most modern high-performance computing systems. It is notoriously hard to optimize, hence its implementation is usually left to hardware vendors in the form of so-called "kernel libraries" (e.g., cuBLAS). Unfortunately, these libraries are often proprietary and cannot easily be customized for the needs of modern deep learning workloads. In this tutorial, you will learn how to implement an efficient matrix multiplication with Triton that is easy to customize and extend.
Roughly speaking, the kernel that we will write will implement the following blocked algorithm:
# do in parallel
for m in range(0, M, BLOCK_M):
  # do in parallel
  for n in range(0, N, BLOCK_N):
    acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
    for k in range(0, K, BLOCK_K):
      a = A[m : m+BLOCK_M, k : k+BLOCK_K]
      b = B[k : k+BLOCK_K, n : n+BLOCK_N]
      acc += dot(a, b)
    C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;
where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.
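To convince yourself that this blocked algorithm computes the same result as an ordinary matrix multiplication, here is a small NumPy sketch of the pseudo-code above (the shapes are hypothetical and chosen to be multiples of the block sizes):

import numpy as np

M, N, K = 8, 8, 8
BLOCK_M, BLOCK_N, BLOCK_K = 4, 4, 4
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(K, N).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

for m in range(0, M, BLOCK_M):        # in the Triton kernel, this loop runs in parallel
    for n in range(0, N, BLOCK_N):    # ... and so does this one
        acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
        for k in range(0, K, BLOCK_K):
            acc += A[m:m + BLOCK_M, k:k + BLOCK_K] @ B[k:k + BLOCK_K, n:n + BLOCK_N]
        C[m:m + BLOCK_M, n:n + BLOCK_N] = acc

assert np.allclose(C, A @ B, atol=1e-5)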
Compute Kernel
The above algorithm is actually fairly straightforward to implement in Triton. The main difficulty comes from computing the memory locations at which blocks of A and B must be read in the inner loop. For that, we need multi-dimensional pointer arithmetic.
Pointer Arithmetics
For a row-major 2D tensor X, the memory location of X[i, j] is given by &X[i, j] = X + i*stride_x_0 + j*stride_x_1. Therefore, blocks of pointers for A[m : m+BLOCK_M, k : k+BLOCK_K] and B[k : k+BLOCK_K, n : n+BLOCK_N] can be defined in pseudo-code as:
&A[m : m+BLOCK_M, k : k+BLOCK_K] = A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_K, n : n+BLOCK_N] = B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
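To build intuition for this stride formula, here is a small PyTorch check on a hypothetical row-major tensor (a sketch, not part of the tutorial code):

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)   # row-major (contiguous) 2D tensor
i, j = 1, 2
flat = x.flatten()
# &X[i, j] = X + i*stride_x_0 + j*stride_x_1, expressed on the flattened storage
assert flat[i * x.stride(0) + j * x.stride(1)] == x[i, j]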
Based on the pseudo-code above, the pointers for blocks of A and B can therefore be initialized in Triton (i.e., at k=0) as:
pid_m = triton.program_id(0)
pid_n = triton.program_id(1)
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
// pointer for A operand
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
// pointer for B operand
pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);
And then updated in the inner loop as follows:
pa += BLOCK_K * stride_a_1;
pb += BLOCK_K * stride_b_0;
L2 Cache Optimizations
As mentioned above, each program instance computes a [BLOCK_M, BLOCK_N] block of C. It is important to remember that the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program. Unfortunately, a simple row-major ordering:
pid = triton.program_id(0);
grid_m = (M + BLOCK_M - 1) // BLOCK_M;
grid_n = (N + BLOCK_N - 1) // BLOCK_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;
is just not going to cut it.
One possible solution is to launch blocks in an order that promotes data reuse. This can be done by "super-grouping" blocks in groups of GROUP_M rows before switching to the next column:
pid = triton.program_id(0);
width = GROUP_M * grid_n;
group_id = pid // width;
# we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
pid_m = group_id * GROUP_M + (pid % group_size);
pid_n = (pid % width) // (group_size);
In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., from 220 to 245 TFLOPS on A100).
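To see what the re-ordering actually does, the following standalone sketch (with a hypothetical tiny grid) prints the (pid_m, pid_n) pairs produced by row-major ordering and by the grouped ordering above:

GROUP_M, grid_m, grid_n = 2, 4, 4   # hypothetical tiny launch grid, for illustration only

def grouped_order(pid):
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# Row-major order walks an entire row of blocks of C before moving to the next row;
# the grouped order cycles through GROUP_M rows at a time, so blocks of B are reused sooner.
print([divmod(pid, grid_n) for pid in range(grid_m * grid_n)])   # row-major
print([grouped_order(pid) for pid in range(grid_m * grid_n)])    # super-grouped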
Final Result
import torch

import triton
import triton.language as tl

# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
#   - An autotuning *key* whose change in values will trigger evaluation of all the provided configs


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
        # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # triton can accept arbitrary activation functions via meta-parameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)


# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01 * x)
We can now create a convenience wrapper function that only takes two input tensors and (1) checks any shape constraints; (2) allocates the output; (3) launches the above kernel.
def matmul(a, b, activation=None):
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # launch kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    pgm = _matmul[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),
        ACTIVATION=activation
    )
    # done; return the output tensor
    return c
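For example, the leaky_relu kernel defined earlier can be fused into the matmul by passing it as the activation argument (a usage sketch):

a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c = matmul(a, b, activation=leaky_relu)   # the activation is applied inside the kernel epilogue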
Unit Test
We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS):
torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=None)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))
Out:
tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3984,  24.4531, -32.3438],
        [  6.3555, -19.6094,  34.0938,  ...,  -5.8945,   5.2891,   6.8867],
        [-32.0625,   5.9492,  15.3984,  ..., -21.3906, -23.9844, -10.1328],
        ...,
        [ -5.7031,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
        [ 25.5000,  24.3281,  -8.4688,  ..., -18.9375,  32.5312, -29.9219],
        [ -5.3477,   4.9844,  11.8906,  ...,   5.5898,   6.4023, -17.3125]],
       device='cuda:0', dtype=torch.float16)
tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3906,  24.4531, -32.3438],
        [  6.3516, -19.6094,  34.0938,  ...,  -5.8906,   5.2812,   6.8828],
        [-32.0625,   5.9531,  15.3984,  ..., -21.4062, -23.9844, -10.1328],
        ...,
        [ -5.7070,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
        [ 25.5000,  24.3438,  -8.4609,  ..., -18.9375,  32.5312, -29.9219],
        [ -5.3477,   4.9805,  11.8828,  ...,   5.5859,   6.4023, -17.3125]],
       device='cuda:0', dtype=torch.float16)
tensor(True, device='cuda:0')
Benchmark
Square Matrix Performance
We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script to benchmark other matrix shapes.
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(1, 33)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],  # possible values for `line_arg`
        line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],  # label name for the lines
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],  # line styles
        ylabel="TFLOPS",  # label name for the y-axis
        plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={}
    )
)
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
    if provider == 'cublas + relu':
        torch_relu = torch.nn.ReLU(inplace=True)
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
    if provider == 'triton + relu':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)

Out:
matmul-performance:
         M      cuBLAS  ...      Triton  Triton (+ LeakyReLU)
0    128.0    0.455111  ...    0.512000              0.512000
1    256.0    2.978909  ...    2.978909              2.978909
2    384.0    7.372800  ...    7.899428              7.899428
3    512.0   14.563555  ...   16.384000             15.420235
4    640.0   22.260869  ...   24.380953             24.380953
5    768.0   32.768000  ...   34.028308             34.028308
6    896.0   39.025776  ...   39.025776             39.025776
7   1024.0   51.150050  ...   52.428801             52.428801
8   1152.0   44.566925  ...   46.656000             46.656000
9   1280.0   51.200001  ...   56.109587             56.109587
10  1408.0   64.138541  ...   65.684049             65.684049
11  1536.0   80.430545  ...   76.106321             75.296679
12  1664.0   63.372618  ...   62.061463             61.636381
13  1792.0   72.983276  ...   68.953520             68.533074
14  1920.0   69.120002  ...   68.435645             68.435645
15  2048.0   73.908442  ...   75.573044             75.234154
16  2176.0   83.500614  ...   80.173899             79.855747
17  2304.0   68.446623  ...   73.051599             72.607513
18  2432.0   71.125224  ...   81.197876             80.963875
19  2560.0   77.649287  ...   76.027843             76.740048
20  2688.0   83.552988  ...   83.186525             82.823267
21  2816.0   84.035084  ...   76.921000             79.733474
22  2944.0   82.102191  ...   80.122235             78.729910
23  3072.0   82.540970  ...   82.661468             82.661468
24  3200.0   84.432717  ...   89.385477             84.432717
25  3328.0   83.905938  ...   86.113988             86.528001
26  3456.0   82.015834  ...   83.545665             84.156124
27  3584.0   87.466332  ...   92.600816             84.988707
28  3712.0   85.163978  ...   82.902362             83.666116
29  3840.0   84.292684  ...   84.550462             85.070769
30  3968.0   89.921841  ...   87.472354             87.409694
31  4096.0   93.792965  ...   89.478485             90.260743

[32 rows x 5 columns]
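The same harness can be adapted to non-square shapes, as suggested above. Here is a minimal variation (a sketch; the fixed M and K values and the plot name are arbitrary) that keeps M and K fixed and sweeps only N:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],                             # sweep the number of columns of B only
        x_vals=[256 * i for i in range(1, 17)],
        line_arg='provider',
        line_vals=['cublas', 'triton'],
        line_names=['cuBLAS', 'Triton'],
        styles=[('green', '-'), ('blue', '-')],
        ylabel='TFLOPS',
        plot_name='matmul-rectangular-performance',
        args={'M': 4096, 'K': 1024},               # hypothetical fixed dimensions
    )
)
def benchmark_rect(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    fn = (lambda: torch.matmul(a, b)) if provider == 'cublas' else (lambda: matmul(a, b))
    ms, min_ms, max_ms = triton.testing.do_bench(fn)
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark_rect.run(show_plots=True, print_data=True)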
Total running time of the script: ( 2 minutes 2.376 seconds)