
Vulkan Compute and Cross-Platform GPU Programming

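Vulkan is the GPU API with the broadest platform reach for compute. It runs on NVIDIA, AMD, Intel, Apple (via MoltenVK) and Android. Browsers reach the same hardware through WebGPU, which is implemented on top of Vulkan, Metal or Direct3D 12. This file covers:
- the Vulkan architecture and the compute pipeline;
- writing compute shaders in GLSL;
- the complete C++ setup for a GPU compute program;
- shared memory and synchronisation;
- WebGPU for the browser;
- practical ML inference examples.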

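  • CUDA dominates ML training on NVIDIA hardware, but not every deployment target has an NVIDIA GPU. Mobile apps run on Qualcomm Adreno or ARM Mali GPUs. Web apps run in the browser. Game engines must support AMD, Intel and NVIDIA at the same time. For all of these, Vulkan is the answer.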

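  • Vulkan is verbose: a "hello world" compute program is roughly 300 lines of C++. That verbosity is the price of explicit control. You manage every GPU resource yourself (memory, pipelines, command buffers). This control allows maximum performance and portability, at the cost of development speed.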

Vulkan Architecture Overview

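  • Vulkan is a low-level GPU API created by the Khronos Group (the same organisation behind OpenGL). CUDA hides much of GPU resource management. Vulkan, by contrast, requires you to manage the following explicitly:

    • Instance and device: create a Vulkan instance, enumerate the available GPUs, pick one and create a logical device for it.
    • Memory: allocate GPU memory explicitly and specify the memory type (device-local for speed, host-visible for CPU access).
    • Buffers: create buffer objects that reference the allocated memory.
    • Descriptor sets: bind buffers to shader inputs (they play the role of a compute shader's function arguments).
    • Compute pipeline: load the compiled shader module and create a pipeline object from it.
    • Command buffers: record a sequence of GPU commands (bind pipeline, bind descriptors, dispatch compute).
    • Queue submission: submit the command buffer to the GPU for execution.
    • Synchronisation: use fences and barriers to guarantee correct ordering.
  • This is completely different from CUDA's cudaMalloc + kernel launch model. In CUDA, the driver does most of this work behind the scenes. In Vulkan, you do it yourself.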

Why So Verbose?

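  • Vulkan is explicit for two reasons:

    1. Driver simplicity: OpenGL drivers are enormously complex, because they must guess the application's intent and optimise accordingly. Vulkan shifts that responsibility to the application. This makes drivers leaner, more predictable, and easier to implement correctly across vendors.

    2. Performance: explicit control over memory layout, synchronisation and command batching lets the application make the best decisions itself. In CUDA, some calls synchronise implicitly; a plain cudaMemcpy, for example, generally blocks the host until the copy completes. In Vulkan, you synchronise only where you ask for it.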

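Compute Shaders in GLSL

  • A compute shader is a program that runs on the GPU, similar to a CUDA kernel. It is written in GLSL (the OpenGL Shading Language) and compiled to SPIR-V bytecode, a portable binary intermediate format. The Vulkan driver then translates SPIR-V into native GPU code.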

Vector Addition

// add.comp — compile with: glslangValidator -V add.comp -o add.spv
#version 450

// Workgroup size: 256 invocations per workgroup (= threads per block in CUDA)
layout(local_size_x = 256) in;

// Buffer bindings (like kernel arguments)
layout(set = 0, binding = 0) buffer InputA { float a[]; };
layout(set = 0, binding = 1) buffer InputB { float b[]; };
layout(set = 0, binding = 2) buffer Output { float c[]; };

// Push constant: small uniform data (like a kernel parameter)
layout(push_constant) uniform PushConstants {
    uint n;  // number of elements
};

void main() {
    uint idx = gl_GlobalInvocationID.x;  // global thread index
    if (idx < n) {
        c[idx] = a[idx] + b[idx];
    }
}
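  • Mapping to CUDA concepts:

| Vulkan | CUDA | Meaning |
| --- | --- | --- |
| Workgroup | Block | Group of threads that can share shared memory |
| Invocation | Thread | Single unit of execution |
| gl_GlobalInvocationID | blockIdx * blockDim + threadIdx | Global thread index |
| gl_LocalInvocationID | threadIdx | Thread index within the workgroup |
| gl_WorkGroupID | blockIdx | Workgroup index |
| local_size_x | blockDim.x | Threads per workgroup |
| Storage buffer | Global memory | Read/write GPU memory |
| Shared memory (shared) | __shared__ | Fast per-workgroup memory |
| Push constants | Kernel parameters | Small uniform data |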

ReLU with Shared Memory

// relu_shared.comp
#version 450

layout(local_size_x = 256) in;

layout(set = 0, binding = 0) buffer Input  { float input_data[]; };
layout(set = 0, binding = 1) buffer Output { float output_data[]; };

layout(push_constant) uniform PushConstants { uint n; };

// Shared memory (equivalent to CUDA __shared__)
shared float tile[256];

void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint lid = gl_LocalInvocationID.x;

    // Load into shared memory
    if (gid < n) {
        tile[lid] = input_data[gid];
    }

    // Barrier: wait for all invocations in workgroup to finish loading
    barrier();  // equivalent to CUDA __syncthreads()

    // Compute ReLU
    if (gid < n) {
        output_data[gid] = max(tile[lid], 0.0);
    }
}
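  • For ReLU, shared memory is not strictly necessary, because the operation is element-wise. The shader is there to demonstrate the pattern: load into shared memory → barrier → compute → store. For operations that need data from neighbouring threads (convolutions, reductions, softmax), shared memory is usually essential.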

Parallel Reduction (Sum)

// reduce_sum.comp
#version 450

layout(local_size_x = 256) in;

layout(set = 0, binding = 0) buffer Input  { float input_data[]; };
layout(set = 0, binding = 1) buffer Output { float partial_sums[]; };

layout(push_constant) uniform PushConstants { uint n; };

shared float sdata[256];

void main() {
    uint gid = gl_GlobalInvocationID.x;
    uint lid = gl_LocalInvocationID.x;
    uint wgid = gl_WorkGroupID.x;

    // Load into shared memory
    sdata[lid] = (gid < n) ? input_data[gid] : 0.0;
    barrier();

    // Tree reduction within the workgroup
    for (uint stride = 128; stride > 0; stride >>= 1) {
        if (lid < stride) {
            sdata[lid] += sdata[lid + stride];
        }
        barrier();
    }

    // Thread 0 writes the workgroup's partial sum
    if (lid == 0) {
        partial_sums[wgid] = sdata[0];
    }
}
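  • This is the classic parallel reduction pattern, the same as in CUDA. Each workgroup produces one partial sum. A second dispatch then reduces the partial sums to the final result; very large inputs may need further passes. The tree reduction halves the number of values still being combined at each step: 256 → 128 → 64 → ... → 1, with 128 active threads in the first step.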

Tiled Matrix Multiplication

// matmul_tiled.comp
#version 450

#define TILE_SIZE 16

layout(local_size_x = TILE_SIZE, local_size_y = TILE_SIZE) in;

layout(set = 0, binding = 0) buffer MatA { float A[]; };
layout(set = 0, binding = 1) buffer MatB { float B[]; };
layout(set = 0, binding = 2) buffer MatC { float C[]; };

layout(push_constant) uniform PushConstants {
    uint M, N, K;
};

shared float tileA[TILE_SIZE][TILE_SIZE];
shared float tileB[TILE_SIZE][TILE_SIZE];

void main() {
    uint row = gl_GlobalInvocationID.y;
    uint col = gl_GlobalInvocationID.x;
    uint lr = gl_LocalInvocationID.y;
    uint lc = gl_LocalInvocationID.x;

    float sum = 0.0;

    for (uint t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
        // Load tile of A and B into shared memory
        uint aCol = t * TILE_SIZE + lc;
        uint bRow = t * TILE_SIZE + lr;

        tileA[lr][lc] = (row < M && aCol < K) ? A[row * K + aCol] : 0.0;
        tileB[lr][lc] = (bRow < K && col < N) ? B[bRow * N + col] : 0.0;

        barrier();

        // Compute partial dot product
        for (uint k = 0; k < TILE_SIZE; k++) {
            sum += tileA[lr][k] * tileB[k][lc];
        }

        barrier();
    }

    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}
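  • This is the same tiling algorithm as the CUDA version (file 04), just in GLSL syntax. The steps are identical: load a tile into shared memory, barrier, compute, barrier, repeat. The second barrier stops fast invocations from overwriting a tile while slower ones are still reading it.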

C++ Vulkan Setup

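  • The compute shader is the easy part. The hard part is the C++ boilerplate that creates the Vulkan instance, allocates memory, binds buffers and submits commands. Here is a condensed version of the complete pipeline: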
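// vulkan_compute.cpp: a minimal but complete Vulkan compute example
// Compile: g++ -O3 -o vulkan_compute vulkan_compute.cpp -lvulkan
// Requires: Vulkan SDK installed, add.spv compiled from add.comp
// Most VkResult return values are left unchecked to keep the listing short.

#include <vulkan/vulkan.h>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Helper: read a SPIR-V binary into 32-bit words (exits if the file is missing)
std::vector<uint32_t> readSPIRV(const std::string& filename) {
    std::ifstream file(filename, std::ios::ate | std::ios::binary);
    if (!file) {
        std::cerr << "Cannot open " << filename << "\n";
        std::exit(EXIT_FAILURE);
    }
    size_t fileSize = static_cast<size_t>(file.tellg());
    std::vector<uint32_t> buffer(fileSize / sizeof(uint32_t));
    file.seekg(0);
    // Read only as many bytes as the buffer holds (SPIR-V is a multiple of 4 bytes)
    file.read(reinterpret_cast<char*>(buffer.data()), buffer.size() * sizeof(uint32_t));
    return buffer;
}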

int main() {
    const uint32_t N = 1024;
    const size_t bufferSize = N * sizeof(float);

    // ========== 1. Create Vulkan Instance ==========
    VkApplicationInfo appInfo{};
    appInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
    appInfo.apiVersion = VK_API_VERSION_1_2;

    VkInstanceCreateInfo instanceInfo{};
    instanceInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
    instanceInfo.pApplicationInfo = &appInfo;

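    // On macOS (MoltenVK), recent Vulkan SDK loaders also require enabling the
    // VK_KHR_portability_enumeration extension and setting
    // VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR in instanceInfo.flags.
    VkInstance instance;
    if (vkCreateInstance(&instanceInfo, nullptr, &instance) != VK_SUCCESS) {
        std::cerr << "vkCreateInstance failed\n";
        return 1;
    }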

    // ========== 2. Select Physical Device (GPU) ==========
    uint32_t deviceCount = 0;
    vkEnumeratePhysicalDevices(instance, &deviceCount, nullptr);
    if (deviceCount == 0) {
        std::cerr << "No Vulkan-capable GPU found\n";
        return 1;
    }
    std::vector<VkPhysicalDevice> devices(deviceCount);
    vkEnumeratePhysicalDevices(instance, &deviceCount, devices.data());
    VkPhysicalDevice physicalDevice = devices[0];  // use first GPU

    // Print GPU name
    VkPhysicalDeviceProperties props;
    vkGetPhysicalDeviceProperties(physicalDevice, &props);
    std::cout << "Using GPU: " << props.deviceName << "\n";

    // ========== 3. Find Compute Queue Family ==========
    uint32_t queueFamilyCount = 0;
    vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, &queueFamilyCount, nullptr);
    std::vector<VkQueueFamilyProperties> queueFamilies(queueFamilyCount);
    vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice, &queueFamilyCount, queueFamilies.data());

    uint32_t computeFamily = 0;
    for (uint32_t i = 0; i < queueFamilyCount; i++) {
        if (queueFamilies[i].queueFlags & VK_QUEUE_COMPUTE_BIT) {
            computeFamily = i;
            break;
        }
    }

    // ========== 4. Create Logical Device and Queue ==========
    float queuePriority = 1.0f;
    VkDeviceQueueCreateInfo queueInfo{};
    queueInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
    queueInfo.queueFamilyIndex = computeFamily;
    queueInfo.queueCount = 1;
    queueInfo.pQueuePriorities = &queuePriority;

    VkDeviceCreateInfo deviceInfo{};
    deviceInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
    deviceInfo.queueCreateInfoCount = 1;
    deviceInfo.pQueueCreateInfos = &queueInfo;

    VkDevice device;
    vkCreateDevice(physicalDevice, &deviceInfo, nullptr, &device);

    VkQueue computeQueue;
    vkGetDeviceQueue(device, computeFamily, 0, &computeQueue);

    // ========== 5. Allocate Buffers (A, B, C) ==========
    // For brevity, this uses host-visible memory (slower but simpler)
    auto createBuffer = [&](VkBuffer& buffer, VkDeviceMemory& memory) {
        VkBufferCreateInfo bufInfo{};
        bufInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
        bufInfo.size = bufferSize;
        bufInfo.usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
        vkCreateBuffer(device, &bufInfo, nullptr, &buffer);

        VkMemoryRequirements memReqs;
        vkGetBufferMemoryRequirements(device, buffer, &memReqs);

        // Find a memory type that is BOTH host-visible AND host-coherent.
        // (For ordinary buffers the spec guarantees such a type is allowed.)
        VkPhysicalDeviceMemoryProperties memProps;
        vkGetPhysicalDeviceMemoryProperties(physicalDevice, &memProps);
        const VkMemoryPropertyFlags wanted =
            VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT;
        uint32_t memType = 0;
        for (uint32_t i = 0; i < memProps.memoryTypeCount; i++) {
            if ((memReqs.memoryTypeBits & (1u << i)) &&
                (memProps.memoryTypes[i].propertyFlags & wanted) == wanted) {
                memType = i;
                break;
            }
        }

        VkMemoryAllocateInfo allocInfo{};
        allocInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
        allocInfo.allocationSize = memReqs.size;
        allocInfo.memoryTypeIndex = memType;
        vkAllocateMemory(device, &allocInfo, nullptr, &memory);
        vkBindBufferMemory(device, buffer, memory, 0);
    };

    VkBuffer bufA, bufB, bufC;
    VkDeviceMemory memA, memB, memC;
    createBuffer(bufA, memA);
    createBuffer(bufB, memB);
    createBuffer(bufC, memC);

    // ========== 6. Fill Input Buffers ==========
    float* ptrA;
    vkMapMemory(device, memA, 0, bufferSize, 0, (void**)&ptrA);
    for (uint32_t i = 0; i < N; i++) ptrA[i] = 1.0f;
    vkUnmapMemory(device, memA);

    float* ptrB;
    vkMapMemory(device, memB, 0, bufferSize, 0, (void**)&ptrB);
    for (uint32_t i = 0; i < N; i++) ptrB[i] = 2.0f;
    vkUnmapMemory(device, memB);

    // ========== 7. Create Compute Pipeline ==========
    auto spirvCode = readSPIRV("add.spv");
    VkShaderModuleCreateInfo shaderInfo{};
    shaderInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
    shaderInfo.codeSize = spirvCode.size() * sizeof(uint32_t);
    shaderInfo.pCode = spirvCode.data();
    VkShaderModule shaderModule;
    vkCreateShaderModule(device, &shaderInfo, nullptr, &shaderModule);

    // Descriptor set layout (tells Vulkan about the buffer bindings)
    VkDescriptorSetLayoutBinding bindings[3] = {};
    for (int i = 0; i < 3; i++) {
        bindings[i].binding = i;
        bindings[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        bindings[i].descriptorCount = 1;
        bindings[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
    }

    VkDescriptorSetLayoutCreateInfo layoutInfo{};
    layoutInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    layoutInfo.bindingCount = 3;
    layoutInfo.pBindings = bindings;
    VkDescriptorSetLayout descLayout;
    vkCreateDescriptorSetLayout(device, &layoutInfo, nullptr, &descLayout);

    // Push constant range
    VkPushConstantRange pushRange{};
    pushRange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
    pushRange.offset = 0;
    pushRange.size = sizeof(uint32_t);

    // Pipeline layout
    VkPipelineLayoutCreateInfo pipeLayoutInfo{};
    pipeLayoutInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
    pipeLayoutInfo.setLayoutCount = 1;
    pipeLayoutInfo.pSetLayouts = &descLayout;
    pipeLayoutInfo.pushConstantRangeCount = 1;
    pipeLayoutInfo.pPushConstantRanges = &pushRange;
    VkPipelineLayout pipelineLayout;
    vkCreatePipelineLayout(device, &pipeLayoutInfo, nullptr, &pipelineLayout);

    // Compute pipeline
    VkComputePipelineCreateInfo pipeInfo{};
    pipeInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
    pipeInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    pipeInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
    pipeInfo.stage.module = shaderModule;
    pipeInfo.stage.pName = "main";
    pipeInfo.layout = pipelineLayout;
    VkPipeline pipeline;
    vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeInfo, nullptr, &pipeline);

    // ========== 8. Descriptor Set (bind buffers to shader) ==========
    VkDescriptorPoolSize poolSize{};
    poolSize.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
    poolSize.descriptorCount = 3;

    VkDescriptorPoolCreateInfo poolInfo{};
    poolInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
    poolInfo.maxSets = 1;
    poolInfo.poolSizeCount = 1;
    poolInfo.pPoolSizes = &poolSize;
    VkDescriptorPool descPool;
    vkCreateDescriptorPool(device, &poolInfo, nullptr, &descPool);

    VkDescriptorSetAllocateInfo descAllocInfo{};
    descAllocInfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
    descAllocInfo.descriptorPool = descPool;
    descAllocInfo.descriptorSetCount = 1;
    descAllocInfo.pSetLayouts = &descLayout;
    VkDescriptorSet descSet;
    vkAllocateDescriptorSets(device, &descAllocInfo, &descSet);

    // Write buffer references into the descriptor set
    VkDescriptorBufferInfo bufInfos[3] = {
        {bufA, 0, bufferSize}, {bufB, 0, bufferSize}, {bufC, 0, bufferSize}
    };
    VkWriteDescriptorSet writes[3] = {};
    for (int i = 0; i < 3; i++) {
        writes[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
        writes[i].dstSet = descSet;
        writes[i].dstBinding = i;
        writes[i].descriptorCount = 1;
        writes[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        writes[i].pBufferInfo = &bufInfos[i];
    }
    vkUpdateDescriptorSets(device, 3, writes, 0, nullptr);

    // ========== 9. Record and Submit Command Buffer ==========
    VkCommandPoolCreateInfo cmdPoolInfo{};
    cmdPoolInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
    cmdPoolInfo.queueFamilyIndex = computeFamily;
    VkCommandPool cmdPool;
    vkCreateCommandPool(device, &cmdPoolInfo, nullptr, &cmdPool);

    VkCommandBufferAllocateInfo cmdAllocInfo{};
    cmdAllocInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
    cmdAllocInfo.commandPool = cmdPool;
    cmdAllocInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
    cmdAllocInfo.commandBufferCount = 1;
    VkCommandBuffer cmdBuf;
    vkAllocateCommandBuffers(device, &cmdAllocInfo, &cmdBuf);

    VkCommandBufferBeginInfo beginInfo{};
    beginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
    vkBeginCommandBuffer(cmdBuf, &beginInfo);

    vkCmdBindPipeline(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
    vkCmdBindDescriptorSets(cmdBuf, VK_PIPELINE_BIND_POINT_COMPUTE,
                            pipelineLayout, 0, 1, &descSet, 0, nullptr);
    vkCmdPushConstants(cmdBuf, pipelineLayout, VK_SHADER_STAGE_COMPUTE_BIT,
                       0, sizeof(uint32_t), &N);
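    vkCmdDispatch(cmdBuf, (N + 255) / 256, 1, 1);  // launch workgroups

    // Make the shader's writes visible to host reads after the fence wait
    // (a fence alone does not guarantee host visibility of device writes).
    VkMemoryBarrier hostBarrier{};
    hostBarrier.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER;
    hostBarrier.srcAccessMask = VK_ACCESS_SHADER_WRITE_BIT;
    hostBarrier.dstAccessMask = VK_ACCESS_HOST_READ_BIT;
    vkCmdPipelineBarrier(cmdBuf, VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                         VK_PIPELINE_STAGE_HOST_BIT, 0,
                         1, &hostBarrier, 0, nullptr, 0, nullptr);
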
    vkEndCommandBuffer(cmdBuf);

    // Submit
    VkFenceCreateInfo fenceInfo{};
    fenceInfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO;
    VkFence fence;
    vkCreateFence(device, &fenceInfo, nullptr, &fence);

    VkSubmitInfo submitInfo{};
    submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    submitInfo.commandBufferCount = 1;
    submitInfo.pCommandBuffers = &cmdBuf;
    vkQueueSubmit(computeQueue, 1, &submitInfo, fence);
    vkWaitForFences(device, 1, &fence, VK_TRUE, UINT64_MAX);

    // ========== 10. Read Results ==========
    float* ptrC;
    vkMapMemory(device, memC, 0, bufferSize, 0, (void**)&ptrC);
    std::cout << "Results: c[0]=" << ptrC[0] << " c[1]=" << ptrC[1]
              << " (expected 3.0)\n";
    bool correct = true;
    for (uint32_t i = 0; i < N; i++) {
        if (ptrC[i] != 3.0f) { correct = false; break; }
    }
    std::cout << (correct ? "ALL CORRECT" : "ERRORS FOUND") << "\n";
    vkUnmapMemory(device, memC);

    // ========== Cleanup (abbreviated) ==========
    vkDestroyFence(device, fence, nullptr);
    vkDestroyCommandPool(device, cmdPool, nullptr);
    vkDestroyPipeline(device, pipeline, nullptr);
    vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
    vkDestroyDescriptorPool(device, descPool, nullptr);
    vkDestroyDescriptorSetLayout(device, descLayout, nullptr);
    vkDestroyShaderModule(device, shaderModule, nullptr);
    vkDestroyBuffer(device, bufA, nullptr); vkFreeMemory(device, memA, nullptr);
    vkDestroyBuffer(device, bufB, nullptr); vkFreeMemory(device, memB, nullptr);
    vkDestroyBuffer(device, bufC, nullptr); vkFreeMemory(device, memC, nullptr);
    vkDestroyDevice(device, nullptr);
    vkDestroyInstance(instance, nullptr);

    return 0;
}
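  • Yes, that is nearly 300 lines for a vector addition, compared with about 30 in CUDA. That is the price of explicitness. But note that every line has a purpose: no hidden driver decisions, no implicit synchronisation, no surprise allocations. You control everything.

  • In practice, you wrap this boilerplate in a helper library or use an existing one. Options include vk-bootstrap for instance and device setup, VMA (Vulkan Memory Allocator) for memory allocation, and Kompute for ML-focused Vulkan compute.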

Kompute: Simplified Vulkan for ML

  • Kompute is an open-source C++ library that wraps Vulkan's GPU compute boilerplate. The same vector addition becomes:
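// Written against the Kompute v0.8-era API (kp::Manager, OpTensorSyncDevice, ...);
// class and operation names have changed between releases, so check your version.
#include <kompute/Kompute.hpp>
#include <cstdlib>
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Kompute consumes SPIR-V, not GLSL source. This helper, similar to the one in
// the Kompute README, shells out to glslangValidator (which must be on PATH).
static std::vector<uint32_t> compileSource(const std::string& source) {
    {
        std::ofstream out("tmp_kp_shader.comp");
        out << source;
    }  // file is flushed and closed here
    if (std::system("glslangValidator -V tmp_kp_shader.comp -o tmp_kp_shader.comp.spv") != 0)
        throw std::runtime_error("glslangValidator failed");
    std::ifstream in("tmp_kp_shader.comp.spv", std::ios::binary);
    std::vector<char> bytes((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
    return std::vector<uint32_t>(reinterpret_cast<const uint32_t*>(bytes.data()),
                                 reinterpret_cast<const uint32_t*>(bytes.data() + bytes.size()));
}

int main() {
    kp::Manager mgr;  // creates the instance, picks a device and a queue

    auto tensorA = mgr.tensor({1, 1, 1, 1, 1});
    auto tensorB = mgr.tensor({2, 2, 2, 2, 2});
    auto tensorC = mgr.tensor({0, 0, 0, 0, 0});

    std::string shader = R"(
        #version 450
        layout(local_size_x = 1) in;
        layout(set=0, binding=0) buffer A { float a[]; };
        layout(set=0, binding=1) buffer B { float b[]; };
        layout(set=0, binding=2) buffer C { float c[]; };
        void main() {
            uint i = gl_GlobalInvocationID.x;
            c[i] = a[i] + b[i];
        }
    )";

    auto algorithm = mgr.algorithm({tensorA, tensorB, tensorC}, compileSource(shader));

    mgr.sequence()
        ->record<kp::OpTensorSyncDevice>({tensorA, tensorB, tensorC})  // host -> GPU
        ->record<kp::OpAlgoDispatch>(algorithm)                        // run the shader
        ->record<kp::OpTensorSyncLocal>({tensorC})                     // GPU -> host
        ->eval();

    // tensorC now contains [3, 3, 3, 3, 3]
}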
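  • Much more readable. Kompute handles instance creation, device selection, memory allocation, descriptor sets and command buffer management, so you can focus on the shader and the data.

WebGPU: GPU Compute in the Browser

  • WebGPU is the successor to WebGL and gives JavaScript modern GPU access, including compute shaders. Browsers implement it on top of Vulkan (Linux/Android), Metal (macOS/iOS) and Direct3D 12 (Windows), hiding the platform differences.

  • WebGPU uses WGSL (the WebGPU Shading Language) instead of GLSL: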

// add.wgsl — WebGPU compute shader
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;

@compute @workgroup_size(256)
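fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let i = id.x;
    // The last workgroup can run past the end of the arrays when N is not a multiple of 256
    if (i >= arrayLength(&c)) {
        return;
    }
    c[i] = a[i] + b[i];
}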
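  • JavaScript setup (condensed):
// Assumes N (the element count) and wgslSource (the WGSL text above) are defined.
const adapter = await navigator.gpu.requestAdapter();
if (!adapter) throw new Error('WebGPU is not available');
const device = await adapter.requestDevice();

// Create buffers
const bufferA = device.createBuffer({ size: N * 4, usage: GPUBufferUsage.STORAGE, mappedAtCreation: true });
new Float32Array(bufferA.getMappedRange()).fill(1.0);
bufferA.unmap();

// ... (similar for B and C)
// (bufferC additionally needs GPUBufferUsage.COPY_SRC so it can be copied out for read-back)

// Create pipeline from WGSL shader
const pipeline = device.createComputePipeline({
    layout: 'auto',
    compute: { module: device.createShaderModule({ code: wgslSource }), entryPoint: 'main' }
});

// Bind the three buffers to @group(0) @binding(0..2)
const bindGroup = device.createBindGroup({
    layout: pipeline.getBindGroupLayout(0),
    entries: [
        { binding: 0, resource: { buffer: bufferA } },
        { binding: 1, resource: { buffer: bufferB } },
        { binding: 2, resource: { buffer: bufferC } },
    ],
});

// Dispatch
const encoder = device.createCommandEncoder();
const pass = encoder.beginComputePass();
pass.setPipeline(pipeline);
pass.setBindGroup(0, bindGroup);
pass.dispatchWorkgroups(Math.ceil(N / 256));
pass.end();

// Copy the result into a mappable buffer, submit, and read it on the CPU
const readback = device.createBuffer({ size: N * 4, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ });
encoder.copyBufferToBuffer(bufferC, 0, readback, 0, N * 4);
device.queue.submit([encoder.finish()]);
await readback.mapAsync(GPUMapMode.READ);
const result = new Float32Array(readback.getMappedRange());  // 3.0 everywhere if B was filled with 2.0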
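  • Why WebGPU matters for ML: running inference in the browser means no server costs, no network round-trips, and user data never leaves the device. Libraries such as ONNX Runtime Web and Transformers.js use WebGPU to run models (including small LLMs) entirely client-side.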

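When to Use Vulkan

| Scenario | Use Vulkan? | Why / alternative |
| --- | --- | --- |
| ML training | No | CUDA/Triton on NVIDIA is simpler and faster |
| Inference on NVIDIA GPUs | No | TensorRT or CUDA is better |
| Inference on AMD/Intel GPUs | Yes | The most widely supported cross-vendor GPU compute option |
| Mobile inference (Android) | Yes | Vulkan is the standard GPU API on Android |
| Mobile inference (iOS) | No | Use Metal directly (MoltenVK adds overhead) |
| Browser inference | Via WebGPU | WebGPU is built on Vulkan/Metal/DX12 |
| Game engine + ML | Yes | The engine already uses Vulkan for rendering |
| Cross-platform library | Yes | One codebase for all GPU vendors |
| Learning GPU programming | Maybe | CUDA is easier to start with; Vulkan teaches more |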

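Coding Tasks (compile with g++ and -lvulkan; requires the Vulkan SDK)

  1. Compile and run the vector addition example above. Modify the shader to compute c[i] = a[i] * b[i] + a[i], a fused multiply-add that GLSL can express directly as fma(a[i], b[i], a[i]). Then verify the results. With the example's inputs (a = 1, b = 2) this also yields 3.0, the same as a + b, so change the input values to confirm that the new shader is actually running.

  2. Write a compute shader that applies softmax to one row of data, using shared memory for the reduction steps (max and sum). Test it with known values.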

// softmax.comp — compile with: glslangValidator -V softmax.comp -o softmax.spv
#version 450

#define WG_SIZE 256

layout(local_size_x = WG_SIZE) in;

layout(set = 0, binding = 0) buffer Input  { float input_data[]; };
layout(set = 0, binding = 1) buffer Output { float output_data[]; };

layout(push_constant) uniform PC { uint n; };

shared float sdata[WG_SIZE];

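// Assumes the whole row fits in ONE workgroup (n <= WG_SIZE) and that exactly one
// workgroup is dispatched; longer rows need a multi-pass (or cross-workgroup) reduction.
void main() {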
    uint gid = gl_GlobalInvocationID.x;
    uint lid = gl_LocalInvocationID.x;

    // Step 1: find max (for numerical stability)
    sdata[lid] = (gid < n) ? input_data[gid] : -1e30;
    barrier();
    for (uint s = WG_SIZE / 2; s > 0; s >>= 1) {
        if (lid < s) sdata[lid] = max(sdata[lid], sdata[lid + s]);
        barrier();
    }
    float maxVal = sdata[0];
    barrier();

    // Step 2: compute exp(x - max)
    float expVal = (gid < n) ? exp(input_data[gid] - maxVal) : 0.0;
    sdata[lid] = expVal;
    barrier();

    // Step 3: sum of exp values
    for (uint s = WG_SIZE / 2; s > 0; s >>= 1) {
        if (lid < s) sdata[lid] += sdata[lid + s];
        barrier();
    }
    float sumExp = sdata[0];

    // Step 4: normalise
    if (gid < n) {
        output_data[gid] = expVal / sumExp;
    }
}
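  3. Modify the C++ host code to benchmark the compute shader. Time the dispatch alone, excluding setup, using either Vulkan timestamp queries or CPU-side timing around the fence wait. Then compute the achieved bandwidth in GB/s.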