Triton: Installation and Usage Examples

2021/08/02 22:32

Installation

Binary Distributions

You can install the latest stable release of Triton from pip:

pip install triton

Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.

And the latest nightly release:

pip install -U --pre triton

Install from Source

Python Package

You can install the Python package from source by running the following commands:

git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build time dependency
pip install -e .

Note that, if llvm-11 is not present on your system, the setup.py script will download LLVM 11 static libraries and link against them.

You can then test your installation by running the unit tests:

pytest -vs .

and the benchmarks:

cd bench/
python -m run --with-plots --result-dir /tmp/triton-bench

Tutorials

Below are some examples covering basic Triton operations. It is recommended to read them in order and then experiment with them.

Vector Addition

In this tutorial, you will write a simple vector addition using Triton and learn about:

  • The basic programming model of Triton.

  • The triton.jit decorator, which is used to define Triton kernels.

  • The best practices for validating and benchmarking custom ops.

Compute Kernel

import torch
import triton.language as tl
import triton


@triton.jit
def _add(
    X,  # *Pointer* to first input vector
    Y,  # *Pointer* to second input vector
    Z,  # *Pointer* to output vector
    N,  # Size of the vector
    **meta  # Optional meta-parameters for the kernel
):
    pid = tl.program_id(0)
    # Create an offset for the blocks of pointers to be
    # processed by this program instance
    offsets = pid * meta['BLOCK'] + tl.arange(0, meta['BLOCK'])
    # Create a mask to guard memory operations against
    # out-of-bounds accesses
    mask = offsets < N
    # Load x
    x = tl.load(X + offsets, mask=mask)
    y = tl.load(Y + offsets, mask=mask)
    # Write back x + y
    z = x + y
    tl.store(Z + offsets, z, mask=mask)

Let us also declare a helper function that (1) allocates the output tensor z and (2) enqueues the above kernel with appropriate grid/block sizes.

def add(x, y):
    z = torch.empty_like(x)
    N = z.shape[0]
    # The SPMD launch grid denotes the number of kernel instances that run in parallel.
    # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]
    grid = lambda meta: (triton.cdiv(N, meta['BLOCK']), )
    # NOTE:
    #  - each torch.tensor object is implicitly converted into a pointer to its first element.
    #  - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel
    #  - don't forget to pass meta-parameters as keyword arguments
    _add[grid](x, y, z, N, BLOCK=1024)
    # We return a handle to z but, since `torch.cuda.synchronize()` hasn't been called, the kernel is still
    # running asynchronously at this point.
    return z

We can now use the above function to compute the element-wise sum of two torch.tensor objects and check how much it deviates from PyTorch:

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
za = x + y
zb = add(x, y)
print(za)
print(zb)
print(f'The maximum difference between torch and triton is ' f'{torch.max(torch.abs(za - zb))}')

Output:

tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.6724, 1.2141, 0.9733], device='cuda:0')
The maximum difference between torch and triton is 0.0

Looks good!

Benchmark

We can now benchmark our custom op on vectors of increasing size to get a sense of how it performs relative to PyTorch. For convenience, Triton has built-in utilities that let us plot the performance of our op for different input sizes:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # argument names to use as an x-axis for the plot
        x_vals=[2**i for i in range(12, 28, 1)],  # different possible values for `x_name`
        x_log=True,  # x axis is logarithmic
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch'],  # possible values for `line_arg`
        line_names=["Triton", "Torch"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="vector-add-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={}  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y))
    gbps = lambda ms: 12 * size / ms * 1e-6  # 12 bytes moved per element: 3 float32 tensors (x, y, z) * 4 bytes
    return gbps(ms), gbps(max_ms), gbps(min_ms)

We can now run the decorated function above. Pass print_data=True to see the performance numbers, show_plots=True to plot them, and/or save_path='/path/to/results/' to save them to disk along with raw CSV data:

benchmark.run(print_data=True, show_plots=True)
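
If you also want the raw CSV data written to disk, the same call accepts the save_path argument mentioned above (the directory below is the placeholder from the text, not a required location):

benchmark.run(print_data=True, show_plots=False, save_path='/path/to/results/')
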
[Plot: vector-add-performance]

Output:

vector-add-performance:
           size      Triton       Torch
0        4096.0    9.600000    9.600000
1        8192.0   19.200000   19.200000
2       16384.0   38.400001   38.400001
3       32768.0   76.800002   76.800002
4       65536.0  127.999995  127.999995
5      131072.0  219.428568  219.428568
6      262144.0  341.333321  384.000001
7      524288.0  472.615390  472.615390
8     1048576.0  614.400016  614.400016
9     2097152.0  722.823517  722.823517
10    4194304.0  780.190482  780.190482
11    8388608.0  812.429770  812.429770
12   16777216.0  833.084721  833.084721
13   33554432.0  843.811163  843.811163
14   67108864.0  849.278610  848.362445
15  134217728.0  851.577704  850.656574

Total running time of the script: ( 0 minutes 11.032 seconds)

Fused Softmax

In this tutorial, you will write a fused softmax operation that is significantly faster than PyTorch's native op for a particular class of matrices: those whose rows can fit in the GPU's SRAM. You will learn about:

  • The benefits of kernel fusion for bandwidth-bound operations.

  • Reduction operators in Triton.

Motivations

Custom GPU kernels for elementwise additions are educationally valuable but won't get you very far in practice. Let us consider instead the case of a simple (numerically stabilized) softmax operation:

import torch


# Compute the row-wise softmax of x
@torch.jit.script
def naive_softmax(x):
    # read  MN elements ; write M  elements
    x_max = x.max(dim=1)[0]
    # read 2MN elements ; write MN elements
    z = x - x_max[:, None]
    # read  MN elements ; write MN elements
    numerator = torch.exp(z)
    # read  MN elements ; write M  elements
    denominator = numerator.sum(dim=1)
    # read 2MN elements ; write MN elements
    ret = numerator / denominator[:, None]
    # in total: read 7MN elements ; wrote 3MN + 2M elements
    return ret

When implemented naively in PyTorch, computing y = naive_softmax(x) for an M-by-N input requires reading roughly 7MN elements from DRAM and writing back 3MN + 2M elements. This is obviously wasteful; we would prefer a custom "fused" kernel that only reads X once and does all the necessary computations on-chip. Such a kernel would only need to read and write back MN elements each way, so we could expect a theoretical speed-up of about 5x, i.e. (10MN + 2M) / 2MN. The torch.jit.script flag aims to perform this kind of "kernel fusion" automatically but, as we will see later, it still falls far short of ideal.
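
As a quick sanity check of that estimate, here is a minimal sketch (plain Python; the shape is a hypothetical example) that tallies the element reads/writes annotated in naive_softmax above and computes the resulting theoretical speed-up:

M, N = 4096, 1024                                  # hypothetical example shape
naive_traffic = (7 * M * N) + (3 * M * N + 2 * M)  # elements read + written by naive_softmax
fused_traffic = M * N + M * N                      # a fused kernel reads X once and writes Y once
print(naive_traffic / fused_traffic)               # ~5.0, i.e. roughly a 5x theoretical speed-up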

Compute Kernel

The softmax kernel we define works as follows:

  • each program loads a row of the input matrix X,
  • normalizes it and writes back the result to the output Y.

Note that one important limitation of Triton is that each block must have a power-of-two number of elements, so we need to internally “pad” each row and guard the memory operations properly if we want to handle any possible input shapes:

import triton
import triton.language as tl


@triton.jit
def _softmax(Y, X, stride_xm, stride_ym, M, N, **meta):
    # row index
    m = tl.program_id(0)
    # col indices
    # here BLOCK is the smallest power of two greater than `N`
    n = tl.arange(0, meta['BLOCK'])
    # the memory address of all the elements
    # that we want to load can be computed as follows
    X = X + m * stride_xm + n
    x = tl.load(X, mask=n < N, other=-float('inf'))
    # Subtract maximum for numerical stability
    z = x - tl.max(x, axis=0)
    # Note that exponentials in Triton are fast
    # but approximate (i.e., think __expf in CUDA)
    num = tl.exp(z)
    denom = tl.sum(num, axis=0)
    y = num / denom
    # Write back to Y
    Y = Y + m * stride_ym + n
    tl.store(Y, y, mask=n < N)

We can create a helper function that enqueues the kernel and its (meta-)arguments for any given input tensor.

# Returns the smallest power of two greater than or equal to n (bit-smearing trick)
def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def softmax(x):
    M, N = x.shape
    # The block size is the smallest power of two greater than the number of columns in `x`
    BLOCK = next_power_of_2(N)
    # Another trick we can use is to ask the compiler to use more threads per row by
    # increasing the number of warps (`num_warps`) over which each row is distributed.
    # You will see in the next tutorial how to auto-tune this value in a more natural
    # way so you don't have to come up with manual heuristics yourself.
    num_warps = 4
    if BLOCK >= 2048: num_warps = 8
    if BLOCK >= 4096: num_warps = 16
    # Allocate output
    y = torch.empty_like(x)
    # Enqueue kernel. The launch grid is simple: we have one kernel instance per row of the input matrix
    _softmax[(M, )](y, x, x.stride(0), y.stride(0), M, N, num_warps=num_warps, BLOCK=BLOCK)
    return y

Unit Test

We make sure to test the kernel above on a matrix with an irregular number of rows and columns, to verify that the padding mechanism works correctly.

torch.manual_seed(0)
x = torch.randn(1823, 781, device='cuda')
y_tri = softmax(x)
y_ref = torch.softmax(x, axis=1)
print(torch.allclose(y_tri, y_ref))

Output:

True

As expected, the results are identical.

Benchmark

Here we benchmark our operation as a function of the number of columns in the input matrix, assuming 4096 rows, and compare its performance against (1) torch.softmax and (2) the naive_softmax defined above:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['triton', 'torch-native', 'torch-jit'],  # possible values for `line_arg`
        line_names=["Triton", "Torch (native)", "Torch (jit)"],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096}  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device='cuda', dtype=torch.float32)
    if provider == 'torch-native':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: softmax(x))
    if provider == 'torch-jit':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: naive_softmax(x))
    gbps = lambda ms: 2 * x.nelement() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: softmax-performance]

Output:

softmax-performance:
          N      Triton  Torch (native)  Torch (jit)
0     256.0  512.000001      546.133347   273.066674
1     384.0  585.142862      585.142862   267.130429
2     512.0  630.153853      606.814814   264.258068
3     640.0  682.666684      640.000002   269.473696
4     768.0  702.171410      664.216187   273.066663
..      ...         ...             ...          ...
93  12160.0  812.359066      406.179533   329.483481
94  12288.0  812.429770      415.661740   329.602681
95  12416.0  810.840807      412.149375   329.173158
96  12544.0  810.925276      412.546756   329.292871
97  12672.0  811.007961      412.097543   329.410251

[98 rows x 4 columns]

In the plot above, we can see that:

  • Triton is 2-3x faster than the Torch JIT.

  • Triton is also noticeably faster than torch.softmax. Our guess is that the PyTorch kernel only partially fuses the computation of the softmax.

  • This means that, when temporary data is too large to fit entirely in the GPU's cache, it transfers almost twice the amount of memory necessary.

  • Note that our Triton kernel is not only faster than PyTorch's CUDA kernel, it is also easier to read, understand and maintain.

Total running time of the script: ( 1 minutes 8.174 seconds)

Matrix Multiplication

In this tutorial, you will write a 25-line high-performance FP16 matrix multiplication kernel that achieves performance on par with cuBLAS. You will specifically learn about:

  • Block-level matrix multiplications
  • Multi-dimensional pointer arithmetic
  • Program re-ordering for improved L2 cache hit rate
  • Automatic performance tuning

Motivations

Matrix multiplications are a key building block of most modern high-performance computing systems. They are notoriously hard to optimize, hence their implementation is generally left to hardware vendors in the form of so-called "kernel libraries" (e.g., cuBLAS). Unfortunately, these libraries are often proprietary and cannot easily be customized to accommodate the needs of modern deep learning workloads. In this tutorial, you will learn how to implement efficient matrix multiplications with Triton in a way that is easy to customize and extend.

Roughly speaking, the kernel we will write implements the following blocked algorithm:

# do in parallel
for m in range(0, M, BLOCK_M):
  # do in parallel
  for n in range(0, N, BLOCK_N):
    acc = zeros((BLOCK_M, BLOCK_N), dtype=float32)
    for k in range(0, K, BLOCK_K):
      a = A[m : m+BLOCK_M, k : k+BLOCK_K]
      b = B[k : k+BLOCK_K, n : n+BLOCK_N]
      acc += dot(a, b)
    C[m : m+BLOCK_M, n : n+BLOCK_N] = acc;

where each iteration of the doubly-nested for-loop corresponds to a Triton program instance.

Compute Kernel

The above algorithm is actually fairly straightforward to implement in Triton. The main difficulty lies in computing the memory locations at which blocks of A and B must be read in the inner loop, which requires multi-dimensional pointer arithmetic.

Pointer Arithmetics

For a row-major 2D tensor X, the memory location of X[i, j] is given by &X[i, j] = X + i*stride_x_0 + j*stride_x_1. Therefore, blocks of pointers for A[m : m+BLOCK_M, k : k+BLOCK_K] and B[k : k+BLOCK_K, n : n+BLOCK_N] can be defined in pseudo-code as:

&A[m : m+BLOCK_M, k:k+BLOCK_K] =  A + (m : m+BLOCK_M)[:, None]*A.stride(0) + (k : k+BLOCK_K)[None, :]*A.stride(1);
&B[k : k+BLOCK_K, n:n+BLOCK_N] =  B + (k : k+BLOCK_K)[:, None]*B.stride(0) + (n : n+BLOCK_N)[None, :]*B.stride(1);
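
As a small aside, the stride-based address formula above is easy to check numerically with PyTorch (the tensor shape and indices below are arbitrary examples):

import torch

x = torch.arange(24, dtype=torch.float32).reshape(4, 6)  # a contiguous, row-major 4x6 tensor
i, j = 2, 3                                               # arbitrary example indices
flat_offset = i * x.stride(0) + j * x.stride(1)           # &X[i, j] = X + i*stride_x_0 + j*stride_x_1
assert x.flatten()[flat_offset] == x[i, j]
print(x.stride())                                         # (6, 1) for this tensor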

This means that the pointers for blocks of A and B can be initialized in Triton (i.e., at k=0) as:

pid_m = triton.program_id(0)
pid_n = triton.program_id(1)
rm = pid_m * BLOCK_M + triton.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + triton.arange(0, BLOCK_N)
rk = triton.arange(0, BLOCK_K)
// pointer for A operand
pa = A + (rm[:, None] * stride_a_0 + rk[None, :] * stride_a_1);
// pointer for B operand
pb = B + (rk[:, None] * stride_b_0 + rn[None, :] * stride_b_1);

And then updated in the inner loop as follows:

pa += BLOCK_K * stride_a_1;
pb += BLOCK_K * stride_b_0;

L2 Cache Optimizations

As mentioned above, each program instance computes a [BLOCK_M, BLOCK_N] block of C. It is important to remember that the order in which these blocks are computed matters, since it affects the L2 cache hit rate of our program. Unfortunately, a simple row-major ordering:

pid = triton.program_id(0);
grid_m = (M + BLOCK_M - 1) // BLOCK_M;
grid_n = (N + BLOCK_N - 1) // BLOCK_N;
pid_m = pid / grid_n;
pid_n = pid % grid_n;

is just not going to cut it.

One possible solution is to launch blocks in an order that promotes data reuse. This can be done by 'super-grouping' blocks into groups of GROUP_M rows before switching to the next column:

pid = triton.program_id(0);
width = GROUP_M * grid_n;
group_id = pid // width;
# we need to handle the case where M % (GROUP_M*BLOCK_M) != 0
group_size = min(grid_m - group_id * GROUP_M, GROUP_M);
pid_m = group_id * GROUP_M + (pid % group_size);
pid_n = (pid % width) // (group_size);

In practice, this can improve the performance of our matrix multiplication kernel by more than 10% on some hardware architectures (e.g., 220 to 245 TFLOPS on A100).
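
To make the re-ordering concrete, here is a small self-contained sketch (plain Python, with a hypothetical 6x6 grid of blocks and GROUP_M = 2) that prints which (pid_m, pid_n) tile each program ID is mapped to under plain row-major ordering versus the grouped ordering above:

grid_m, grid_n, GROUP_M = 6, 6, 2   # hypothetical grid of blocks

def row_major(pid):
    return pid // grid_n, pid % grid_n

def grouped(pid):
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n

# The first 2*grid_n programs now sweep a GROUP_M-row band column by column,
# so consecutive programs reuse the same blocks of A and B more often.
print([row_major(p) for p in range(2 * grid_n)])
print([grouped(p) for p in range(2 * grid_n)])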

Final Result

import torch
import triton
import triton.language as tl

# %
# :code:`triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of :code:`triton.Config` objects that define different configurations of meta-parameters (e.g., BLOCK_M) and compilation options (e.g., num_warps) to try
#   - A autotuning *key* whose change in values will trigger evaluation of all the provided configs

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64,  'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 256, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
        triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=4, num_warps=4),\
        triton.Config({'BLOCK_M': 64 , 'BLOCK_N': 32 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),\
        triton.Config({'BLOCK_M': 32 , 'BLOCK_N': 64 , 'BLOCK_K': 32, 'GROUP_M': 8}, num_stages=5, num_warps=2),
        #triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'GROUP_M': 8}, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
# %
# We can now define our kernel as normal, using all the techniques presented above
@triton.jit
def _matmul(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, **META):
    # extract meta-parameters
    BLOCK_M = META['BLOCK_M']
    BLOCK_N = META['BLOCK_N']
    BLOCK_K = META['BLOCK_K']
    GROUP_M = 8
    # matrix multiplication
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N
    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # do matrix multiplication
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        a = tl.load(A)
        b = tl.load(B)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk
    # triton can accept arbitrary activation function
    # via metaparameters!
    if META['ACTIVATION']:
        acc = META['ACTIVATION'](acc)
    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm[:, None] < M) & (rn[None, :] < N)
    tl.store(C, acc, mask=mask)


# we can fuse `leaky_relu` by providing it as an `ACTIVATION` meta-parameter in `_matmul`
@triton.jit
def leaky_relu(x):
    return tl.where(x >= 0, x, 0.01*x)

We can now create a convenience wrapper function that takes two input tensors and (1) checks any shape constraints; (2) allocates the output; (3) launches the kernel above.

def matmul(a, b, activation=None):
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    # launch kernel
    grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), )
    pgm = _matmul[grid](
        a, b, c, M, N, K, \
        a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1),\
        ACTIVATION = activation
    )
    # done; return the output tensor
    return c

Unit Test

We can test our custom matrix multiplication operation against a native torch implementation (i.e., cuBLAS):

torch.manual_seed(0)
a = torch.randn((512, 512), device='cuda', dtype=torch.float16)
b = torch.randn((512, 512), device='cuda', dtype=torch.float16)
c_0 = matmul(a, b, activation=None)
c_1 = torch.matmul(a, b)
print(c_0)
print(c_1)
print(triton.testing.allclose(c_0, c_1))

Output:

tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3984,  24.4531, -32.3438],
        [  6.3555, -19.6094,  34.0938,  ...,  -5.8945,   5.2891,   6.8867],
        [-32.0625,   5.9492,  15.3984,  ..., -21.3906, -23.9844, -10.1328],
        ...,
        [ -5.7031,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
        [ 25.5000,  24.3281,  -8.4688,  ..., -18.9375,  32.5312, -29.9219],
        [ -5.3477,   4.9844,  11.8906,  ...,   5.5898,   6.4023, -17.3125]],
       device='cuda:0', dtype=torch.float16)
tensor([[  1.1045, -36.9688,  31.4688,  ..., -11.3906,  24.4531, -32.3438],
        [  6.3516, -19.6094,  34.0938,  ...,  -5.8906,   5.2812,   6.8828],
        [-32.0625,   5.9531,  15.3984,  ..., -21.4062, -23.9844, -10.1328],
        ...,
        [ -5.7070,   7.4492,   8.2656,  ..., -10.6953, -40.0000,  17.7500],
        [ 25.5000,  24.3438,  -8.4609,  ..., -18.9375,  32.5312, -29.9219],
        [ -5.3477,   4.9805,  11.8828,  ...,   5.5859,   6.4023, -17.3125]],
       device='cuda:0', dtype=torch.float16)
tensor(True, device='cuda:0')

Benchmark

Square Matrix Performance

We can now compare the performance of our kernel against that of cuBLAS. Here we focus on square matrices, but feel free to adapt this script to benchmark any other matrix shape.

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(1, 33)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['cublas', 'cublas + relu', 'triton', 'triton + relu'],  # possible values for `line_arg`
        line_names=["cuBLAS", "cuBLAS (+ torch.nn.LeakyReLU)", "Triton", "Triton (+ LeakyReLU)"],  # label name for the lines
        styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')],  # line styles
        ylabel="TFLOPS",  # label name for the y-axis
        plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={}
    )
)
def benchmark(M, N, K, provider):
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)
    if provider == 'cublas':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b))
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b))
    if provider == 'cublas + relu':
        torch_relu = torch.nn.ReLU(inplace=True)
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_relu(torch.matmul(a, b)))
    if provider == 'triton + relu':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, activation=leaky_relu))
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)
[Plot: matmul-performance]

Output:

matmul-performance:
         M     cuBLAS  ...     Triton  Triton (+ LeakyReLU)
0    128.0   0.455111  ...   0.512000              0.512000
1    256.0   2.978909  ...   2.978909              2.978909
2    384.0   7.372800  ...   7.899428              7.899428
3    512.0  14.563555  ...  16.384000             15.420235
4    640.0  22.260869  ...  24.380953             24.380953
5    768.0  32.768000  ...  34.028308             34.028308
6    896.0  39.025776  ...  39.025776             39.025776
7   1024.0  51.150050  ...  52.428801             52.428801
8   1152.0  44.566925  ...  46.656000             46.656000
9   1280.0  51.200001  ...  56.109587             56.109587
10  1408.0  64.138541  ...  65.684049             65.684049
11  1536.0  80.430545  ...  76.106321             75.296679
12  1664.0  63.372618  ...  62.061463             61.636381
13  1792.0  72.983276  ...  68.953520             68.533074
14  1920.0  69.120002  ...  68.435645             68.435645
15  2048.0  73.908442  ...  75.573044             75.234154
16  2176.0  83.500614  ...  80.173899             79.855747
17  2304.0  68.446623  ...  73.051599             72.607513
18  2432.0  71.125224  ...  81.197876             80.963875
19  2560.0  77.649287  ...  76.027843             76.740048
20  2688.0  83.552988  ...  83.186525             82.823267
21  2816.0  84.035084  ...  76.921000             79.733474
22  2944.0  82.102191  ...  80.122235             78.729910
23  3072.0  82.540970  ...  82.661468             82.661468
24  3200.0  84.432717  ...  89.385477             84.432717
25  3328.0  83.905938  ...  86.113988             86.528001
26  3456.0  82.015834  ...  83.545665             84.156124
27  3584.0  87.466332  ...  92.600816             84.988707
28  3712.0  85.163978  ...  82.902362             83.666116
29  3840.0  84.292684  ...  84.550462             85.070769
30  3968.0  89.921841  ...  87.472354             87.409694
31  4096.0  93.792965  ...  89.478485             90.260743

[32 rows x 5 columns]

Total running time of the script: ( 2 minutes 2.376 seconds)

 
