0%

从 Rust 开始的 Vulkan Raytracing

既能写 Rust 又能用 Vulkan 还能学日语,这世间怎会有如此美妙的事?

本文是对这篇日文博客的翻译和复现,当然我并不会逐字逐句完整翻译,甚至会跳过其中几个章节,总而言之这是一篇使用 Rust 包装过的 Vulkan API 实现 Ray Tracing in One Weekend 中场景的教程文。

前置条件

通读本文需要掌握的前置知识:

除了这些前置知识以外,还需要一些硬件条件,简而言之就是一台能跑 Vulkan 的电脑,我的运行环境如下:

  • 操作系统:Ubuntu 24.02 LTS
  • GPU:RTX 2070Super Max-Q
  • Vulkan SDK:1.4.304.0

由此,我们完成了所有的前置工作,接下来将先介绍一下总体流程:

  1. 使用 rust-gpu 编写光线追踪所需的几个 shader,并将其编译为 SPIR-V 的文件备用。
  2. 使用 ash 作为 Vulkan API 在 Rust 上的绑定,构建整个 Vulkan 框架。
  3. 使用 Vulkan 的 VKR API 调用之前编写好的 shader 进行光线追踪。

rust-gpu 入门

rust-gpu 是将 Rust 代码编译成 SPIR-V 的代码的库。由于我们需要使用 rust-gpu 编写光线追踪的 shader 代码,因此首先对它进行简单介绍。

设置

我们将创建两个项目,一个是使用 rust-gpu 的 shader 项目,另一个是使用 shader 进行渲染的应用,shader 将在应用的 build.rs 中进行编译。

1
2
3
4
5
# 使用 shdaer 进行渲染的应用
cargo new vulkan-raytracing
cd vulkan-raytracing
# shader 项目
cargo new shader --lib

由于 rust-gpu 需要特定的 Rust 工具链版本,请新建一个 rust-toolchain.toml 文件并从此处复制代码。

1
2
3
4
5
6
7
[toolchain]
channel = "nightly-2024-11-22"
components = ["rust-src", "rustc-dev", "llvm-tools"]
# commit_hash = b19329a37cedf2027517ae22c87cf201f93d776e

# Whenever changing the nightly channel, update the commit hash above, and make
# sure to change `REQUIRED_TOOLCHAIN` in `crates/rustc_codegen_spirv/build.rs` also.

然后我们在 vulkan-raytracing 项目的 Cargo.toml 中添加 spirv 依赖。

1
2
3
...
[build-dependencies]
spirv-builder = { git = "https://github.com/Rust-GPU/rust-gpu" }

接着在 build.rs 中添加编译选项。

1
2
3
4
5
6
7
8
9
10
11
use std::error::Error;

use spirv_builder::{SpirvBuilder, MetadataPrintout};

fn main() -> Result<(), Box<dyn Error>> {
SpirvBuilder::new("shader", "spirv-unknown-vulkan1.2")
.print_metadata(MetadataPrintout::Full)
.build()?;

Ok(())
}

shader 设置

shader 项目的 Cargo.toml 中添加 spirv-std 依赖。

1
2
3
4
5
[lib]
crate-type = ["lib", "dylib"]

[dependencies]
spirv-std = { git = "https://github.com/Rust-GPU/rust-gpu.git" }

然后我们尝试在 shader/src/lib.rs 中编写一个简单的 vertex shader 和一个简单的 fragment shader。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#![cfg_attr(target_arch = "spirv", no_std)]

use spirv_std::spirv;
use spirv_std::arch::IndexUnchecked;
use spirv_std::glam::{vec3a, vec4, Vec3A, Vec4};

// Vertex shader
#[spirv(vertex)]
pub fn main_vs(
// 类似于 gl_VertexIndex
#[spirv(vertex_index)] vertex_index: u32,
// 类似于 gl_Position
#[spirv(position)] out_pos: &mut Vec4,
color: &mut Vec3A,
) {
*out_pos = *unsafe {
[
vec4(1.0, 1.0, 0.0, 1.0),
vec4(0.0, -1.0, 0.0, 1.0),
vec4(-1.0, 1.0, 0.0, 1.0),
]
.index_unchecked(vert_id as usize)
};

*color = *unsafe {
[
vec3a(1.0, 0.0, 0.0),
vec3a(0.0, 1.0, 0.0),
vec3a(0.0, 0.0, 1.0),
]
.index_unchecked(vert_id as usize)
};
}

// Fragment shader
#[spirv(fragment)]
pub fn main_fs(color: &Vec3A, out_color: &mut Vec4) {
*out_color = vec4(color.x, color.y, color.z, 1.0);
}

这里使用带 A 后缀的变量都是依照 SPIR-V 中的规则安装 16 字节对齐的变量,我们在之后都会使用这样的变量。这里还使用了 IndexUnchecked 作为数组的索引,这样会跳过对数组进行边界检查的步骤。

编译 shader

编译的时候可以使用 env!("<shader-name>.spv") 来获得编译后的二进制文件路径。

对于 src/main.rs 我们可以编写一个简单的程序来读取编译后的二进制文件。

1
2
3
4
5
6
7
fn main() {
const shader_path: &str = env!("shader.spv");
const shader: &[u8] = include_bytes!(shader_path);

dbg!(shader_path);
dbg!(shader.len());
}

运行后就可以看到编译后的 shader 路径和大小了。

使用 ash 调用 shader

ash 是一个 Rust 的 Vulkan API 绑定库,我们将使用它来调用 Vulkan API。不过由于 Vulkan 的流程过于复杂,无法在这里一一解释,所以只能写下使用的要点。另外,网上大多关于 Vulkan 的教程都是在窗口中渲染的,而我们在这里直接将结果保存到图像文件中来进行离屏渲染。这是因为尽管 VKR 是可以实时调用的 API,但是由于我们为了使用 Ray Tracing in One Weekend 中同样的处理方法,因此无法进行实时绘制。

ash 的使用方法

由于 ash 只是一个绑定库,因此我们还是需要手动安装 Vulkan SDK

在 ash 当中,各种 struct 的构建是通过 builder 模式来进行的,这样可以避免直接构建 struct 时的繁琐。比如下面的代码就是创建一个 Vulkan 实例的过程。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
let instance = {
let application_name = CString::new("Hello Triangle").unwrap();
let engine_name = CString::new("No Engine").unwrap();

let mut debug_utils_create_info = vk::DebugUtilsMessengerCreateInfoEXT::default()
.message_severity(
vk::DebugUtilsMessageSeverityFlagsEXT::WARNING |
vk::DebugUtilsMessageSeverityFlagsEXT::ERROR,
)
.message_type(
vk::DebugUtilsMessageTypeFlagsEXT::GENERAL
| vk::DebugUtilsMessageTypeFlagsEXT::PERFORMANCE
| vk::DebugUtilsMessageTypeFlagsEXT::VALIDATION,
)
.pfn_user_callback(Some(default_vulkan_debug_utils_callback));

let application_info = vk::ApplicationInfo::default()
.application_name(application_name.as_c_str())
.application_version(vk::make_api_version(0, 1, 0, 0))
.engine_name(engine_name.as_c_str())
.engine_version(vk::make_api_version(0, 1, 0, 0))
.api_version(vk::API_VERSION_1_3);

let enabled_extension_names = [ext::debug_utils::NAME.as_ptr()];

let instance_create_info = vk::InstanceCreateInfo::default()
.application_info(&application_info)
.enabled_layer_names(validation_layers_ptr.as_slice())
.enabled_extension_names(&enabled_extension_names);

let instance_create_info = if ENABLE_VALIDATION_LAYER {
instance_create_info.push_next(&mut debug_utils_create_info)
} else {
instance_create_info
};

unsafe { entry.create_instance(&instance_create_info, None) }
.expect("failed to create instance!")
};

我们可以看到在创建 Vulkan 实例过程中,我们使用了 vk::InstanceCreateInfo::default() 来创建一个默认的实例创建信息,然后通过 builder 模式来设置各种参数。这样的代码风格在 ash 中是非常常见的。

我们在此基础上继续进行以下步骤:

  1. 选择物理设备
  2. 根据物理设备创建逻辑设备并获得队列
  3. 创建 image 和 image view
  4. 创建 render pass
  5. 创建 pipeline
  6. 创建 framebuffer 和 shader module
  7. 创建 command buffer 和 command pool
  8. 创建输出图像的 dst_image

完成了这些步骤后,我们可以调用之前的 shader 代码并录制简单的渲染和拷贝图像的 commands 进行测试,可以看到如下的图像。

Vulkan 光线追踪简介

如果只是想使用 GPU 进行光线追踪,那么我们其实不一定需要 VKR 扩展,而是写一个 compute shader 就可以完成。那么我们为什么还需要一个额外的扩展呢?这或许是因为我们想要调用 GPU 中内置的有关 BVH 构建以及光线相交判定的硬件从而进一步加快光线追踪的过程。

实现过 Ray Tracing: The Next Week 中包含 BVH 的部分的读者都会意识到,光线追踪中耗费大量时间的部分就是对光线求交的检测,而有幸的是目前的 GPU 供应商能够提供这方面的硬件加速。例如英伟达的显卡目前提供 RT Cores 进行加速,而 AMD 的显卡也有被称为 Ray Accelerators 的硬件加速器。

Acceleration Structure

在 Vulkan 中,VKR 提供了一个被称为 acceleration structure 的数据结构。它的工作方式和 BVH 相同,通过调用 API 就可以在 GPU 上构建它并进行光线相交的检测(实际上,AS 的实现是由 GPU 供应商决定的,所以也有可能使用了和 BVH 不一样的神秘算法)。

此外,AS 在满足某些特定条件的时候能够以很低的代价重建它或者将它序列化,从而在一个 GPU 上构建的 AS 可以很方便的在另一个 GPU 上使用,但是我们在这里不做展开。

一般来说,BVH 会支持 AABB 和 OBB 的求交,但是 VKR 仅支持 AABB 和三角形。

Top Level AS 和 Bottom Level AS

AS 具有 Top Level AS(TLAS) 和 Bottom Level AS(BLAS) 这样的两层结构。Shader 只能访问 TLAS,而 TLAS 可以访问多个 BLAS 和它们的变换矩阵。BLAS 中包含了多个三角形或者 AABB。对于 AABB 来说,我们需要定义一个 intersection shader 来计算相交。而 TLAS 本身并不涉及任何的三角形/AABB 计算,而 BLAS 不能包含任何别的 TLAS 或者 BLAS。

VKR 的 shaders

VKR 定义了如下的几个 shaders:

  • ray generation shader:用于生成光线
  • intersection shader:对于 AABB 的物体用于计算光线与 BLAS 的相交
  • any hit shader:用于验证得到的交点是否有效(例如实际上击中了一个透明的部分)
  • closest hit shader:用于处理被验证确实有效的交点的着色
  • miss shader:用于处理没有击中任何物体的光线
  • callable shader:用于被其他 shader 调用,我们在这里并没有使用这个 shader

这一套操作和 DirectX 12 以及 OptiX 当中的光线追踪算法是类似的,整个流程可以用如下的流程图表示:

Ray Generation shader

这个 shader 是光线追踪的入口,它的调用次数和输出像素的数量是一致的。在这个 shader 中,我们可以生成光线并调用 trace_ray 函数来进行光线追踪,然后填充回像素点。

Intersection shader

为了特殊形状的物体的求交而定义的 shader,它只会在光线击中 BLAS 中注册的 AABB 时才会被调用,因此如果使用三角形而非自己定义的物体的话可以省略这个 shader。在本文中,我们在这个 shader 中实现对球体的求交。

Any Hit shader

这个 shader 用于验证交点是否有效,例如当光线击中一个透明的物体时,我们可以略过这个物体。如果不定义这个 shader 的话,那么 VKR 会默认认为所有的交点都是有效的。

Closest Hit shader

在 intersection shader 和 any hit shader 经过处理和判断后最终确定碰撞点时会调用的 shader。在这里,我们将创建要返回 ray generation shader 当中的数据。本文中我们将在这个 shader 中计算碰撞的法线和材质。

Miss shader

当光线没有击中任何物体时会调用的 shader。在这里我们可以定义背景色。如果你只是想检查目标区域是否有任何东西阻挡光源,可以只返回一个真值。

Shader Binding Table

在这些 shader 之间我们可能会需要一些数据传递和共用,而为了方便,驱动开发商设计了一套光线追踪流程,我们按照它们的要求进行填充就能实现数据的传递。这个表被称为 Shader Binding Table(SBT),而它包含了三个 record,分别是由 intersection shader、any hit shader 和 closest hit shader 组成的 hit group record,ray generation shader 对应的 ray generation record 以及 miss shader 对应的 miss record。我们在 SBT 中需要输入不同的 record 所在的内存中的位置和偏移量,由此实现了数据的传递。

编写 VKR 的 shader 代码

至此,我们可以着手于编写光线追踪的代码了,我们首先修改 build.rs 中的代码,将其编译为 VKR 的 shader。

1
2
3
4
5
6
7
8
9
10
11
12
use std::error::Error;
use spirv_builder::{Capability, MetadataPrintout, SpirvBuilder};

fn main() -> Result<(), Box<dyn Error>> {
SpirvBuilder::new("./shader", "spirv-unknown-vulkan1.2")
.capability(Capability::RayTracingKHR)
.extension("SPV_KHR_ray_tracing")
.print_metadata(MetadataPrintout::Full)
.build()?;

Ok(())
}

伪随机数生成

在光线追踪中,我们需要使用随机数来生成采样区域,但是 SPIR-V 并没有提供随机数生成的函数,因此我们需要自己实现一个伪随机数生成器。

首先我们需要一个随机数种子。

我们在这里使用像素坐标和主机创建的随机数进行异或作为种子,当然也可以使用 VK_KHR_shader_clock(3) 之类的方式创建。对于 shader/src/lib.rs 我们可以添加如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
pub struct PushConstants {
seed: u32,
}

// 前文提到的 ray generation shader
#[spirv(ray_generation)]
pub fn main_ray_generation(
// 并行运算的光线追踪 id,表示像素的坐标
#[spirv(launch_id)] launch_id: UVec3,
// 总的像素尺寸
#[spirv(launch_size)] launch_size: UVec3,
// 主机的随机数
#[spirv(push_constant)] push_constants: &PushConstants,
) {
let rand_seed = (launch_id.y * launch_size.x + launch_id.x) ^ push_constants.seed;
}

自然,伪随机数算法有很多种,考虑到 GPU 基本上都是 32 位的架构,因此我们在这里使用 PCG 系列中的 pcg32si 算法。我们期待 pcg32si 算法能够在 GPU 上有较好的性能。

我们新建一个 shader/src/rand.rs 文件,将 pcg32si 算法的实现放在这里。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
pub struct PCG32si {
state: u32,
}

impl PCG32si {
const PCG_DEFAULT_MULTIPLIER_32: u32 = 747796405;
const PCG_DEFAULT_INCREMENT_32: u32 = 2891336453;

fn pcg_oneseq_32_step_r(&mut self) {
self.state = self
.state
.wrapping_mul(Self::PCG_DEFAULT_MULTIPLIER_32)
.wrapping_add(Self::PCG_DEFAULT_INCREMENT_32);
}

fn pcg_output_rxs_m_xs_32_32(state: u32) -> u32 {
let word = ((state >> ((state >> 28).wrapping_add(4))) ^ state).wrapping_mul(277803737);
(word >> 22) ^ word
}

pub fn new(seed: u32) -> Self {
let mut rng = Self { state: seed };
rng.pcg_oneseq_32_step_r();
rng.state = rng.state.wrapping_add(seed);
rng.pcg_oneseq_32_step_r();
rng
}

pub fn next_u32(&mut self) -> u32 {
let old_state = self.state;
self.pcg_oneseq_32_step_r();
Self::pcg_output_rxs_m_xs_32_32(old_state)
}

// 0.0..1.0
pub fn next_f32(&mut self) -> f32 {
// 由于无法使用 std 因此在这里使用 core
let float_size = core::mem::size_of::<f32>() as u32 * 8;
let precision = 23 + 1;
let scale = 1.0 / ((1 << precision) as f32);

let value = self.next_u32();
let value = value >> (float_size - precision);
scale * value as f32
}

pub fn next_f32_range(&mut self, min: f32, max: f32) -> f32 {
min + (max - min) * self.next_f32()
}
}

pub type DefaultRng = PCG32si;

这里的代码直接来源于 PCG 的 C 语言实现。而 next_f32 算法来源于官方的 rand crate。另外为了简单起见,我们并没有用 trait 来抽象随机数生成器,而是直接将其作为默认的对象。

实现相机

相机用于确定像素坐标出发的光线从哪个位置和方向进行发射,我们在这里实现一个 Ray Tracing in One Weekend 中的相机功能。

首先需要定义一个 Ray 的类,由于我们不会实现运动模糊(相机压根不会动),因此只需要位置和方向即可。

我们添加一个 shader/src/camera.rs 文件,并添加如下代码:

1
2
3
4
5
6
use spirv_std::glam::Vec3A;
#[derive(Clone, Copy, Default)]
pub struct Ray {
pub origin: Vec3A,
pub direction: Vec3A,
}

再新建一个 shader/src/math.rs 文件并添加在圆盘上随机采样的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
use crate::rand::DefaultRng;
use spirv_std::glam::{vec3a, Vec3A};

pub fn random_in_unit_disk(rng: &mut DefaultRng) -> Vec3A {
loop {
let p = vec3a(
rng.next_f32_range(-1.0, 1.0),
rng.next_f32_range(-1.0, 1.0),
0.0,
);
if p.length_squared() < 1.0 {
break p;
}
}
}

然后我们就可以在 shader/src/camera.rs 中定义一个相机类并实现 One Weekend 中的焦散模糊算法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
use crate::{math::random_in_unit_disk, rand::DefaultRng};
#[allow(unused_imports)]
use spirv_std::num_traits::Float;

#[derive(Copy, Clone)]
pub struct Camera {
origin: Vec3A,
lower_left_corner: Vec3A,
horizontal: Vec3A,
vertical: Vec3A,
u: Vec3A,
v: Vec3A,
lens_radius: f32,
}

impl Camera {
#[allow(clippy::too_many_arguments)]
pub fn new(
look_from: Vec3A,
look_at: Vec3A,
vup: Vec3A,
vfov: f32,
aspect_ratio: f32,
aperture: f32,
focus_dist: f32,
) -> Self {
let theta = vfov;
let h = (theta * 0.5).tan();
let viewport_height = 2.0 * h;
let viewport_width = aspect_ratio * viewport_height;

let w = (look_from - look_at).normalize();
let u = vup.cross(w).normalize();
let v = w.cross(u);

let origin = look_from;
let horizontal = focus_dist * viewport_width * u;
let vertical = focus_dist * viewport_height * v;
let lower_left_corner = origin - horizontal * 0.5 - vertical * 0.5 - focus_dist * w;

Self {
origin,
lower_left_corner,
horizontal,
vertical,
u,
v,
lens_radius: aperture * 0.5,
}
}

pub fn get_ray(&self, s: f32, t: f32, rng: &mut DefaultRng) -> Ray {
let rd = self.lens_radius * random_in_unit_disk(rng);
let offset = self.u * rd.x + self.v * rd.y;
Ray {
origin: self.origin + offset,
direction: (self.lower_left_corner + s * self.horizontal + t * self.vertical
- self.origin
- offset).normalize(),
}
}
}

实现 RayPayload 类

在 closest hit shader 和 miss shader 中都有一个返回值类型,当光线传播之后,可以在任意的 {closest hit, miss} shader 中返回一个值,而这个值当然需要是同一个类型。

我们将这个类称为 RayPayload,它应该包含什么样的信息?

  • 光线是否击中了物体
  • 如果没有击中
    • 背景的颜色
  • 如果击中了
    • 击中位置
    • 法线方向
    • 材质索引(我们在这里使用 storage buffer 记录材质的列表并通过索引引用)
    • 光线是从物体的内部向外击中还是外部向内击中

仔细思考一下以上内容可以用 enum 来表示,但是博客的原作者当时的选择是使用 struct 来表示,理由是在当时 rust-gpu 还不支持 enum 以及 Option<T> 的语法,当然现在已经支持了,但是我们依然顺从原作者的写法。

同时由于 Bool 不能作为 shader 的输入,因此我们在这里使用 u32 来表示。

我们在 shader/src/lib.rs 中添加如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
#[derive(Clone, Default)]
pub struct RayPayload {
// 光线是否击中了物体
pub is_miss: u32,
// 未击中时表示颜色,击中时表示位置
pub position: Vec3A,
// 法线
pub normal: Vec3A,
// 材质索引
pub material: u32,
// 光线是从内部向外击中还是外部向内击中,外部向内时 front_face 为 true
pub front_face: u32,
}

实现 miss shader

我们之前已经定义了 RayPayload,现在可以着手编写 shader 了,如同 One Weekend 中的代码,我们在这里定义了一个蓝色的背景。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
impl RayPayload {
pub fn new_miss(color: Vec3A) -> Self {
Self {
is_miss: 1,
position: color,
..Default::default()
}
}
}

#[spirv(miss)]
pub fn min_miss(
// 光线方向
#[spirv(world_ray_direction)] world_ray_direction: Vec3A,
#[spirv(incoming_ray_payload)] out: &mut RayPayload,
) {
let unit_dir = world_ray_direction.normalize();
let t = 0.5 * (unit_dir.y + 1.0);
let color = vec3a(1.0, 1.0, 1.0).lerp(vec3a(0.5, 0.7, 1.0), t);

*out = RayPayload::new_miss(color);
}

我们在这里将颜色放在了 RayPayload 当中的 position 属性中。

实现 intersection shader 和 closest hit shader

由于目标场景中只有球体,因此只需要为球体创建这两个 shader 即可。

我们假设 BLAS 是以中心为原点长度为 2 的 AABB(半径为 1 的球),然后从 TLAS 通过变换矩阵得到 BLAS。

我们在 shader/src/lib.rs 中添加如下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
impl RayPayload {
pub fn new_hit(position: Vec3A, normal: Vec3A, ray_dir: Vec3A, material: u32) -> Self {
let front_face = ray_dir.dot(normal) < 0.0;
let normal = if front_face { normal } else { -normal };
Self {
is_miss: 0,
position,
normal,
material,
front_face: front_face as u32,
}
}
}

#[spirv(intersection)]
pub fn main_intersection(
// 局部坐标系下的光线原点
#[spirv(object_ray_origin)] ray_origin: Vec3A,
// 局部坐标系下的光线方向
#[spirv(object_ray_direction)] ray_direction: Vec3A,
// 光线的开始时间
#[spirv(ray_tmin)] t_min: f32,
// 光线的结束时间
#[spirv(ray_tmax)] t_max: f32,
) {
let oc = ray_origin;
let a = ray_direction.length_squared();
let half_b = oc.dot(ray_direction);
let c = oc.length_squared() - 1.0;

let discriminant = half_b * half_b - a * c;
if discriminant < 0.0 {
return;
}

let sqrtd = discriminant.sqrt();
let root0 = (-half_b - sqrtd) / a;
let root1 = (-half_b + sqrtd) / a;

if root0 > t_min && root0 < t_max {
unsafe {
spirv_std::arch::report_intersection(root0, 0);
}
}
if root1 > t_min && root1 < t_max {
unsafe {
spirv_std::arch::report_intersection(root1, 0);
}
}
}

由于 glam 的矩阵和 SPIR-V 的矩阵类不同,因此我们需要手动定义一个 SPIR-V 的矩阵类,我们在 shader/src/lib.rs 中添加如下代码:

1
2
3
4
5
6
7
8
9
#[derive(Clone, Copy)]
#[spirv(matrix)]
#[repr(C)]
pub struct Affine3 {
pub x: Vec3A,
pub y: Vec3A,
pub z: Vec3A,
pub w: Vec3A,
}

然后我们就可以在 shader/src/lib.rs 中定义 closest hit shader 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#[spirv(closest_hit)]
pub fn main_closest_hit(
#[spirv(ray_tmax)] t: f32,
// TLAS 变换矩阵
#[spirv(object_to_world)] object_to_world: Affine3,
#[spirv(world_ray_origin)] world_ray_origin: Vec3A,
#[spirv(world_ray_direction)] world_ray_direction: Vec3A,
#[spirv(incoming_ray_payload)] out: &mut RayPayload,
// TLAS 中注册的自定义索引,用作材质索引
#[spirv(instance_custom_index)] material: u32,
) {
let hit_pos = world_ray_origin + t * world_ray_direction;
// object_to_world.w 作为变换矩阵的平移部分
let normal = (hit_pos - object_to_world.w).normalize();
*out = RayPayload::new_hit(hit_pos, normal, world_ray_direction, material);
}

材质

材质在本文中接受光线和 RayPayload 并返回颜色和反射光线,或者不返回任何内容。

我们在这里实现 One Weekend 中出现的三种材质:

  • Lambertian
    • 只需要返回一个 albedo 即可。
  • Metal
    • 通过 albedo 和 fuzzy 系数进行镜面反射。
  • Dielectric
    • 通过折射率进行折射和反射。

如前面所述,无法使用 enum,因此我们使用 struct 来表示材质。

我们新建一个 shader/src/material.rs 文件:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
use crate::{
camera::Ray,
math::{random_in_unit_sphere, IsNearZero},
rand::DefaultRng,
RayPayload,
};
use spirv_std::glam::{vec3a, vec4, Vec3A, Vec4, Vec4Swizzles};
#[allow(unused_imports)]
use spirv_std::num_traits::Float;

#[derive(Clone, Copy, Default)]
#[repr(transparent)]
pub struct EnumMaterialData {
v0: Vec4,
}

#[derive(Clone, Copy, Default)]
pub struct EnumMaterial {
t: u32,
data: EnumMaterialData,
}

由于指针不应该被类型转换,因此每种材质都被实现为 struct,以 &'a EnumMaterial 作为成员。

首先先创造一个材质的 trait

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#[derive(Clone, Default)]
pub struct Scatter {
pub color: Vec3A,
pub ray: Ray,
}

#[repr(transparent)]
struct Lambertian<'a> {
data: &'a EnumMaterialData,
}

#[repr(transparent)]
struct Metal<'a> {
data: &'a EnumMaterialData,
}

#[repr(transparent)]
struct Dielectric<'a> {
data: &'a EnumMaterialData,
}

pub trait Material {
fn scatter(
&self,
ray: &Ray,
ray_payload: &RayPayload,
rng: &mut DefaultRng,
scatter: &mut Scatter,
) -> bool;
}

其实应该可以改成返回 Option<Scatter>,但是由于前文提到的原因,因此我们在这里通过 &mut Scatter 作为返回值,并通过 bool 判断是否有效。

我们依照 One Weekend 中的代码实现各个材质的 scatter 函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
fn reflect(v: Vec3A, n: Vec3A) -> Vec3A {
v - 2.0 * v.dot(n) * n
}

fn refract(uv: Vec3A, n: Vec3A, etai_over_etat: f32) -> Vec3A {
let cos_theta = (-uv).dot(n).min(1.0);
let r_out_perp = etai_over_etat * (uv + cos_theta * n);
let r_out_parallel = -(1.0 - r_out_perp.length_squared()).abs().sqrt() * n;
r_out_perp + r_out_parallel
}

fn reflectance(cosine: f32, ref_idx: f32) -> f32 {
let r0 = (1.0 - ref_idx) / (1.0 + ref_idx);
let r0 = r0 * r0;
r0 + (1.0 - r0) * (1.0 - cosine).powf(5.0)
}

impl Lambertian<'_> {
fn albedo(&self) -> Vec3A {
self.data.v0.xyz().into()
}
}
impl Material for Lambertian<'_> {
fn scatter(
&self,
_ray: &Ray,
ray_payload: &RayPayload,
rng: &mut DefaultRng,
scatter: &mut Scatter,
) -> bool {
let scatter_direction = ray_payload.normal + random_in_unit_sphere(rng).normalize();

let scatter_direction = if scatter_direction.is_near_zero() {
ray_payload.normal
} else {
scatter_direction
};

let scattered = Ray {
origin: ray_payload.position,
direction: scatter_direction,
};

*scatter = Scatter {
color: self.albedo(),
ray: scattered,
};
true
}
}

impl Metal<'_> {
fn albedo(&self) -> Vec3A {
self.data.v0.xyz().into()
}

fn fuzz(&self) -> f32 {
self.data.v0.w
}
}

impl Material for Metal<'_> {
fn scatter(
&self,
ray: &Ray,
ray_payload: &RayPayload,
rng: &mut DefaultRng,
scatter: &mut Scatter,
) -> bool {
let reflected = reflect(ray.direction, ray_payload.normal);
let scattered = reflected + self.fuzz() * random_in_unit_sphere(rng);
if scattered.dot(ray_payload.normal) > 0.0 {
*scatter = Scatter {
color: self.albedo(),
ray: Ray {
origin: ray_payload.position,
direction: scattered,
},
};
true
} else {
false
}
}
}

impl Dielectric<'_> {
fn ir(&self) -> f32 {
self.data.v0.x
}
}

impl Material for Dielectric<'_> {
fn scatter(
&self,
ray: &Ray,
ray_payload: &RayPayload,
rng: &mut DefaultRng,
scatter: &mut Scatter,
) -> bool {
let refraction_ratio = if ray_payload.front_face != 0 {
1.0 / self.ir()
} else {
self.ir()
};

let unit_direction = ray.direction.normalize();
let cos_theta = (-unit_direction).dot(ray_payload.normal).min(1.0);
let sin_theta = (1.0 - cos_theta * cos_theta).sqrt();
let cannot_refract = refraction_ratio * sin_theta > 1.0;

let direction =
if cannot_refract || reflectance(cos_theta, refraction_ratio) > rng.next_f32() {
reflect(unit_direction, ray_payload.normal)
} else {
refract(unit_direction, ray_payload.normal, refraction_ratio)
};

*scatter = Scatter {
color: vec3a(1.0, 1.0, 1.0),
ray: Ray {
origin: ray_payload.position,
direction,
},
};
true
}
}

其中 Metal 中用到的 random_in_unit_sphere 被定义在 math.rs 中。

1
2
3
4
5
6
7
8
9
10
11
12
13
pub fn random_in_unit_sphere(rng: &mut DefaultRng) -> Vec3A {
loop {
let v = vec3a(
rng.next_f32_range(-1.0, 1.0),
rng.next_f32_range(-1.0, 1.0),
rng.next_f32_range(-1.0, 1.0),
);

if v.length_squared() < 1.0 {
break v;
}
}
}

ray generation shader 编写

现在我们完成了所有的前置要求,可以完成 ray generation shader 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
#[spirv(ray_generation)]
pub fn main_ray_generation(
#[spirv(launch_id)] launch_id: UVec3,
#[spirv(launch_size)] launch_size: UVec3,
#[spirv(push_constant)] constants: &PushConstants,
// TLAS
#[spirv(descriptor_set = 0, binding = 0)] tlas: &spirv_std::ray_tracing::AccelerationStructure,
// 输出图像
#[spirv(descriptor_set = 0, binding = 1)] output: &Image!(
2D,
format = rgba32f,
sampled = false
),
// 材质
#[spirv(storage_buffer, descriptor_set = 0, binding = 2)] materials: &[EnumMaterial],
// RayPayload
#[spirv(ray_payload)] payload: &mut RayPayload,
) {
let rand_seed = (launch_id.y * launch_size.x + launch_id.x) ^ constants.seed;
let mut rng = DefaultRng::new(rand_seed);

// 创建了一个简单的相机
let camera = Camera::new(
vec3a(13.0, 2.0, 3.0),
vec3a(0.0, 0.0, 0.0),
vec3a(0.0, 1.0, 0.0),
20.0 / 180.0 * core::f32::consts::PI,
launch_size.x as f32 / launch_size.y as f32,
0.1,
10.0,
);

let u = (launch_id.x as f32 + rng.next_f32()) / (launch_size.x - 1) as f32;
let v = (launch_id.y as f32 + rng.next_f32()) / (launch_size.y - 1) as f32;

// 不考虑 cull_mask 的功能
let cull_mask = 0xFF;
let tmin = f32::EPSILON;
let tmax = f32::INFINITY;

// Light color
let mut color = vec3a(1.0, 1.0, 1.0);
let mut ray = camera.get_ray(u, v, &mut rng);

// 最大 bounce 次数为 50
for _ in 0..50 {
unsafe {
tlas.trace_ray(
// 使用 OPAQUE 跳过 any hit shader
spirv_std::ray_tracing::RayFlags::OPAQUE,
cull_mask,
// sbt_offset, sbt_stride, miss_index 全为 0
0,
0,
0,
ray.origin,
tmin,
ray.direction,
tmax,
payload,
);
}

if payload.is_miss != 0 {
color *= payload.position;
break;
} else {
let mut scatter = Scatter::default();
match unsafe {
materials
.index_unchecked(payload.material as usize)
.scatter(&ray, payload, &mut rng, &mut scatter)
} {
true => {
color *= scatter.color;
ray = scatter.ray;
}
false => break,
}
}
}

// y 坐标上下取反交换
let pos = uvec2(launch_id.x, launch_size.y - 1 - launch_id.y);
let prev: Vec4 = output.read(pos);

unsafe {
// 加算最后求平均
output.write(pos, prev + color.extend(1.0));
}
}

由此我们就完成了光线追踪 shader 的编码,在下一章中我们将执行这一组 shader。

调用光线追踪 shader

在完成了光线追踪的 shader 之后我们就可以使用 ash 调用了。

为了方便使用 GPU Buffer 而创建的结构体

这一部分我们会先创建一个 struct 用于方便的新建、储存、映射、销毁 GPU 内存。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#[derive(Clone)]
struct BufferResource {
buffer: vk::Buffer,
memory: vk::DeviceMemory,
size: vk::DeviceSize,
}

impl BufferResource {
fn new(
size: vk::DeviceSize,
usage: vk::BufferUsageFlags,
memory_properties: vk::MemoryPropertyFlags,
device: &ash::Device,
device_memory_properties: vk::PhysicalDeviceMemoryProperties,
) -> Self {
unsafe {
let buffer_info = vk::BufferCreateInfo::default()
.size(size)
.usage(usage)
.sharing_mode(vk::SharingMode::EXCLUSIVE);

let buffer = device.create_buffer(&buffer_info, None).unwrap();

let memory_req = device.get_buffer_memory_requirements(buffer);

let memory_index = get_memory_type_index(
device_memory_properties,
memory_req.memory_type_bits,
memory_properties,
);

let mut memory_allocate_flags_info = vk::MemoryAllocateFlagsInfo::default()
.flags(vk::MemoryAllocateFlags::DEVICE_ADDRESS);

let mut allocate_info_default = vk::MemoryAllocateInfo::default();

if usage.contains(vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS) {
allocate_info_default =
allocate_info_default.push_next(&mut memory_allocate_flags_info);
}

let allocate_info = allocate_info_default
.allocation_size(memory_req.size)
.memory_type_index(memory_index);

let memory = device.allocate_memory(&allocate_info, None).unwrap();

device.bind_buffer_memory(buffer, memory, 0).unwrap();

BufferResource {
buffer,
memory,
size,
}
}
}

fn store<T: Copy>(&mut self, data: &[T], device: &ash::Device) {
unsafe {
let size = (std::mem::size_of::<T>() * data.len()) as u64;
assert!(self.size >= size);
let mapped_ptr = self.map(size, device);
let mut mapped_slice = Align::new(mapped_ptr, std::mem::align_of::<T>() as u64, size);
mapped_slice.copy_from_slice(&data);
self.unmap(device);
}
}

fn map(&mut self, size: vk::DeviceSize, device: &ash::Device) -> *mut std::ffi::c_void {
unsafe {
let data: *mut std::ffi::c_void = device
.map_memory(self.memory, 0, size, vk::MemoryMapFlags::empty())
.unwrap();
data
}
}

fn unmap(&mut self, device: &ash::Device) {
unsafe {
device.unmap_memory(self.memory);
}
}

unsafe fn destroy(self, device: &ash::Device) {
device.destroy_buffer(self.buffer, None);
device.free_memory(self.memory, None);
}
}

当然我们也可以尝试使用 gpu-allocator 库来管理内存,但是我们在这里还是使用了自己的实现。

创建 BLAS

在这里我们将使用一个可复用的 BLAS 来处理 TLAS 的变换,因此我们只需要创建一个 BLAS。

注意在创建 AS 的时候我们需要一个额外的 scratch buffer 用于存储 AS 在构建时的数据。至于这个 buffer 将会被如何使用取决于驱动程序的实现。Vulkan 不会保证这块内存区域的存在,因此需要我们自己创建。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
let acceleration_structure = khr::acceleration_structure::Device::new(&instance, &device);

// acceleration structures

// Create bottom-level acceleration structure

let (bottom_as_sphere, bottom_as_sphere_buffer, aabb_buffer) = {
// 2.0^3 的 AABB
let aabb = vk::AabbPositionsKHR::default()
.min_x(-1.0)
.max_x(1.0)
.min_y(-1.0)
.max_y(1.0)
.min_z(-1.0)
.max_z(1.0);

// GPU 用的 AABB buffer
let mut aabb_buffer = BufferResource::new(
std::mem::size_of::<vk::AabbPositionsKHR>() as u64,
vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
| vk::BufferUsageFlags::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_KHR,
vk::MemoryPropertyFlags::HOST_VISIBLE
| vk::MemoryPropertyFlags::HOST_COHERENT
| vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

aabb_buffer.store(&[aabb], &device);

let geometry = vk::AccelerationStructureGeometryKHR::default()
.geometry_type(vk::GeometryTypeKHR::AABBS)
.geometry(vk::AccelerationStructureGeometryDataKHR {
aabbs: vk::AccelerationStructureGeometryAabbsDataKHR::default()
.data(vk::DeviceOrHostAddressConstKHR {
device_address: unsafe {
get_buffer_device_address(&device, aabb_buffer.buffer)
},
})
.stride(std::mem::size_of::<vk::AabbPositionsKHR>() as u64),
})
// 我们并不使用 any hit shader,因此在这里也要设置为 `OPAQUE`
.flags(vk::GeometryFlagsKHR::OPAQUE);

let build_range_info = vk::AccelerationStructureBuildRangeInfoKHR::default()
.first_vertex(0)
.primitive_count(1)
.primitive_offset(0)
.transform_offset(0);

let geometries = [geometry];

let mut build_info = vk::AccelerationStructureBuildGeometryInfoKHR::default()
// 这里可以选择对光线追踪进行彻底的优化并构建,还是适当的优化来缩短构建时间
// 由于我们在这里只构建一次 TLAS,所以选择 `PREFER_FAST_TRACE`
.flags(vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE)
.geometries(&geometries)
.mode(vk::BuildAccelerationStructureModeKHR::BUILD)
.ty(vk::AccelerationStructureTypeKHR::BOTTOM_LEVEL);

// 创建 BLAS 和 scratch buffer 时需要提供大小
// 虽然实际上可能不会用到这么多的内存,但是为了方便我们还是这样设定了大小
let mut size_info = vk::AccelerationStructureBuildSizesInfoKHR::default();
unsafe {
acceleration_structure.get_acceleration_structure_build_sizes(
vk::AccelerationStructureBuildTypeKHR::DEVICE,
&build_info,
&[1],
&mut size_info,
)
};

let bottom_as_buffer = BufferResource::new(
size_info.acceleration_structure_size,
vk::BufferUsageFlags::ACCELERATION_STRUCTURE_STORAGE_KHR
| vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
| vk::BufferUsageFlags::STORAGE_BUFFER,
vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

let as_create_info = vk::AccelerationStructureCreateInfoKHR::default()
.ty(build_info.ty)
.size(size_info.acceleration_structure_size)
.buffer(bottom_as_buffer.buffer)
.offset(0);

let bottom_as =
unsafe { acceleration_structure.create_acceleration_structure(&as_create_info, None) }
.unwrap();

build_info.dst_acceleration_structure = bottom_as;

let scratch_buffer = BufferResource::new(
size_info.build_scratch_size,
vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS | vk::BufferUsageFlags::STORAGE_BUFFER,
vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

build_info.scratch_data = vk::DeviceOrHostAddressKHR {
device_address: unsafe { get_buffer_device_address(&device, scratch_buffer.buffer) },
};

let build_command_buffer = {
let allocate_info = vk::CommandBufferAllocateInfo::default()
.command_buffer_count(1)
.command_pool(command_pool)
.level(vk::CommandBufferLevel::PRIMARY);

let command_buffers =
unsafe { device.allocate_command_buffers(&allocate_info) }.unwrap();
command_buffers[0]
};

unsafe {
device
.begin_command_buffer(
build_command_buffer,
&vk::CommandBufferBeginInfo::default()
.flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
)
.unwrap();

let build_infos = [build_info];
let build_range_infos: &[&[_]] = &[&[build_range_info]];

// AS 可以在 GPU 上构建,也会提供实时的 API
acceleration_structure.cmd_build_acceleration_structures(
build_command_buffer,
&build_infos,
build_range_infos,
);
device.end_command_buffer(build_command_buffer).unwrap();
device
.queue_submit(
graphics_queue,
&[vk::SubmitInfo::default().command_buffers(&[build_command_buffer])],
vk::Fence::null(),
)
.expect("queue submit failed.");

device.queue_wait_idle(graphics_queue).unwrap();
device.free_command_buffers(command_pool, &[build_command_buffer]);
scratch_buffer.destroy(&device);
}
(bottom_as, bottom_as_buffer, aabb_buffer)
};

创建 TLAS

参照上面的代码我们创建 TLAS 和材质 buffer。

首先我们将创建一个和 One Weekend 相同的场景。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// 一个球的 TLAS
fn create_sphere_instance(
pos: glam::Vec3A,
size: f32,
sphere_accel_handle: u64,
) -> vk::AccelerationStructureInstanceKHR {
vk::AccelerationStructureInstanceKHR {
transform: vk::TransformMatrixKHR {
// 变换矩阵的形状为 3x4
// 但是一般来说我们都会当作是 4x4 的矩阵的前 12 项
matrix: [
size, 0.0, 0.0, pos.x, 0.0, size, 0.0, pos.y, 0.0, 0.0, size, pos.z,
],
},
// 从最高位开始的 8 位是一个 mask,如果和 `TraceRay` 当中指定的 mask 不同就会被忽略
// 剩下的 24 位是自定义的 index,我们在这里用作材质的索引,稍后会进行编辑
// `vk::Packed24_8` 可以将一个 32 位的整数分为 24 位和 8 位两部分
instance_custom_index_and_mask: vk::Packed24_8::new(0, 0xff),
// 从最高位开始的 8 位是一个 flag,在这里我们设为 `OPAQUE`
// 剩下 24 位是 SBT 的偏移量,这里设为 0
instance_shader_binding_table_record_offset_and_flags: vk::Packed24_8::new(
0,
vk::GeometryInstanceFlagsKHR::FORCE_OPAQUE.as_raw() as u8,
),
acceleration_structure_reference: vk::AccelerationStructureReferenceKHR {
device_handle: sphere_accel_handle,
},
}
}

// 创建 TLAS 实例和材质
// 和 One Weekend 相同
fn sample_scene(
sphere_accel_handle: u64,
) -> (Vec<vk::AccelerationStructureInstanceKHR>, Vec<EnumMaterial>) {
let mut rng = StdRng::from_entropy();
let mut world = Vec::new();

world.push((
create_sphere_instance(vec3a(0.0, -1000.0, 0.0), 1000.0, sphere_accel_handle),
EnumMaterial::new_lambertian(vec3a(0.5, 0.5, 0.5)),
));

for a in -11..11 {
for b in -11..11 {
let center = vec3a(
a as f32 + 0.9 * rng.gen::<f32>(),
0.2,
b as f32 + 0.9 * rng.gen::<f32>(),
);

let choose_mat: f32 = rng.gen();

if (center - vec3a(4.0, 0.2, 0.0)).length() > 0.9 {
match choose_mat {
x if x < 0.8 => {
let albedo = vec3a(rng.gen(), rng.gen(), rng.gen())
* vec3a(rng.gen(), rng.gen(), rng.gen());

world.push((
create_sphere_instance(center, 0.3, sphere_accel_handle),
EnumMaterial::new_lambertian(albedo),
));
}
x if x < 0.95 => {
let albedo = vec3a(
rng.gen_range(0.5..1.0),
rng.gen_range(0.5..1.0),
rng.gen_range(0.5..1.0),
);
let fuzz = rng.gen_range(0.0..0.5);

world.push((
create_sphere_instance(center, 0.2, sphere_accel_handle),
EnumMaterial::new_metal(albedo, fuzz),
));
}
_ => world.push((
create_sphere_instance(center, 0.2, sphere_accel_handle),
EnumMaterial::new_dielectric(1.5),
)),
}
}
}
}

world.push((
create_sphere_instance(vec3a(0.0, 1.0, 0.0), 1.0, sphere_accel_handle),
EnumMaterial::new_dielectric(1.5),
));

world.push((
create_sphere_instance(vec3a(-4.0, 1.0, 0.0), 1.0, sphere_accel_handle),
EnumMaterial::new_lambertian(vec3a(0.4, 0.2, 0.1)),
));

world.push((
create_sphere_instance(vec3a(4.0, 1.0, 0.0), 1.0, sphere_accel_handle),
EnumMaterial::new_metal(vec3a(0.7, 0.6, 0.5), 0.0),
));

let mut spheres = Vec::new();
let mut materials = Vec::new();

for (i, (mut sphere, material)) in world.into_iter().enumerate() {
sphere.instance_custom_index_and_mask =
vk::Packed24_8::new(i as u32, sphere.instance_custom_index_and_mask.high_8());
spheres.push(sphere);
materials.push(material);
}

(spheres, materials)
}

创建 TLAS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
let sphere_accel_handle = {
let as_addr_info = vk::AccelerationStructureDeviceAddressInfoKHR::default()
.acceleration_structure(bottom_as_sphere);
unsafe { acceleration_structure.get_acceleration_structure_device_address(&as_addr_info) }
};

let (sphere_instances, materials) = sample_scene(sphere_accel_handle);

// 把上面做好的 TLAS 实例放到 GPU 当中
let (instance_count, instance_buffer) = {
let instances = sphere_instances;

let instance_buffer_size =
std::mem::size_of::<vk::AccelerationStructureInstanceKHR>() * instances.len();

let mut instance_buffer = BufferResource::new(
instance_buffer_size as vk::DeviceSize,
vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
| vk::BufferUsageFlags::ACCELERATION_STRUCTURE_BUILD_INPUT_READ_ONLY_KHR,
vk::MemoryPropertyFlags::HOST_VISIBLE
| vk::MemoryPropertyFlags::HOST_COHERENT
| vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

instance_buffer.store(&instances, &device);

(instances.len(), instance_buffer)
};

// 和 BLAS 一样的创建方法
let (top_as, top_as_buffer) = {
let build_range_info = vk::AccelerationStructureBuildRangeInfoKHR::default()
.first_vertex(0)
.primitive_count(instance_count as u32)
.primitive_offset(0)
.transform_offset(0);

let build_command_buffer = {
let allocate_info = vk::CommandBufferAllocateInfo::default()
.command_buffer_count(1)
.command_pool(command_pool)
.level(vk::CommandBufferLevel::PRIMARY);

let command_buffers =
unsafe { device.allocate_command_buffers(&allocate_info) }.unwrap();
command_buffers[0]
};

unsafe {
device
.begin_command_buffer(
build_command_buffer,
&vk::CommandBufferBeginInfo::default()
.flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT),
)
.unwrap();
let memory_barrier = vk::MemoryBarrier::default()
.src_access_mask(vk::AccessFlags::TRANSFER_WRITE)
.dst_access_mask(vk::AccessFlags::ACCELERATION_STRUCTURE_WRITE_KHR);
device.cmd_pipeline_barrier(
build_command_buffer,
vk::PipelineStageFlags::TRANSFER,
vk::PipelineStageFlags::ACCELERATION_STRUCTURE_BUILD_KHR,
vk::DependencyFlags::empty(),
&[memory_barrier],
&[],
&[],
);
}

let instances = vk::AccelerationStructureGeometryInstancesDataKHR::default()
.array_of_pointers(false)
.data(vk::DeviceOrHostAddressConstKHR {
device_address: unsafe {
get_buffer_device_address(&device, instance_buffer.buffer)
},
});

let geometry = vk::AccelerationStructureGeometryKHR::default()
.geometry_type(vk::GeometryTypeKHR::INSTANCES)
.geometry(vk::AccelerationStructureGeometryDataKHR { instances });

let geometries = [geometry];

let mut build_info = vk::AccelerationStructureBuildGeometryInfoKHR::default()
.flags(vk::BuildAccelerationStructureFlagsKHR::PREFER_FAST_TRACE)
.geometries(&geometries)
.mode(vk::BuildAccelerationStructureModeKHR::BUILD)
.ty(vk::AccelerationStructureTypeKHR::TOP_LEVEL);

let mut size_info = vk::AccelerationStructureBuildSizesInfoKHR::default();
unsafe {
acceleration_structure.get_acceleration_structure_build_sizes(
vk::AccelerationStructureBuildTypeKHR::DEVICE,
&build_info,
&[build_range_info.primitive_count],
&mut size_info,
)
};

let top_as_buffer = BufferResource::new(
size_info.acceleration_structure_size,
vk::BufferUsageFlags::ACCELERATION_STRUCTURE_STORAGE_KHR
| vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
| vk::BufferUsageFlags::STORAGE_BUFFER,
vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

let as_create_info = vk::AccelerationStructureCreateInfoKHR::default()
.ty(build_info.ty)
.size(size_info.acceleration_structure_size)
.buffer(top_as_buffer.buffer)
.offset(0);

let top_as =
unsafe { acceleration_structure.create_acceleration_structure(&as_create_info, None) }
.unwrap();

build_info.dst_acceleration_structure = top_as;

let scratch_buffer = BufferResource::new(
size_info.build_scratch_size,
vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS | vk::BufferUsageFlags::STORAGE_BUFFER,
vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

build_info.scratch_data = vk::DeviceOrHostAddressKHR {
device_address: unsafe { get_buffer_device_address(&device, scratch_buffer.buffer) },
};

unsafe {
let build_infos = [build_info];
let build_range_infos: &[&[_]] = &[&[build_range_info]];
acceleration_structure.cmd_build_acceleration_structures(
build_command_buffer,
&build_infos,
build_range_infos,
);
device.end_command_buffer(build_command_buffer).unwrap();
device
.queue_submit(
graphics_queue,
&[vk::SubmitInfo::default().command_buffers(&[build_command_buffer])],
vk::Fence::null(),
)
.expect("queue submit failed.");

device.queue_wait_idle(graphics_queue).unwrap();
device.free_command_buffers(command_pool, &[build_command_buffer]);
scratch_buffer.destroy(&device);
}

(top_as, top_as_buffer)
};

同时我们还需要把材质 buffer 放到 GPU 当中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
let material_buffer = {
let buffer_size = (materials.len() * std::mem::size_of::<EnumMaterial>()) as vk::DeviceSize;

let mut material_buffer = BufferResource::new(
buffer_size,
vk::BufferUsageFlags::STORAGE_BUFFER,
vk::MemoryPropertyFlags::HOST_VISIBLE
| vk::MemoryPropertyFlags::HOST_COHERENT
| vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);
material_buffer.store(&materials, &device);

material_buffer
};

创建光线追踪流水线(Ray Tracing Pipeline)

光线追踪流水线就和图形流水线一样,需要注册一些 shader 并填写 descriptor set 以及 push constant 的信息。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
let (descriptor_set_layout, graphics_pipeline, pipeline_layout, shader_groups_len) = {
let descriptor_set_layout = unsafe {
device.create_descriptor_set_layout(
&vk::DescriptorSetLayoutCreateInfo::default().bindings(&[
// descriptor_set = 0, binding = 0
// TLAS
vk::DescriptorSetLayoutBinding::default()
.descriptor_count(1)
.descriptor_type(vk::DescriptorType::ACCELERATION_STRUCTURE_KHR)
.stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
.binding(0),
// descriptor_set = 1, binding = 1
// 返回的图像
vk::DescriptorSetLayoutBinding::default()
.descriptor_count(1)
.descriptor_type(vk::DescriptorType::STORAGE_IMAGE)
.stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
.binding(1),
// descriptor_set = 2, binding = 2
// 材质
vk::DescriptorSetLayoutBinding::default()
.descriptor_count(1)
.descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
.stage_flags(vk::ShaderStageFlags::RAYGEN_KHR)
.binding(2),
]),
None,
)
}
.unwrap();

// 随机的一个 4byte 数字
let push_constant_range = vk::PushConstantRange::default()
.offset(0)
.size(4)
.stage_flags(vk::ShaderStageFlags::RAYGEN_KHR);

const SHADER: &[u8] = include_bytes!(env!("shader.spv"));

let shader_module = unsafe { create_shader_module(&device, SHADER).unwrap() };

let layouts = [descriptor_set_layout];
let pipeline_layout = unsafe {
device.create_pipeline_layout(
&vk::PipelineLayoutCreateInfo::default()
.set_layouts(&layouts)
.push_constant_ranges(&[push_constant_range]),
None,
)
}
.unwrap();

let shader_groups = vec![
// group0 = [ raygen ]
vk::RayTracingShaderGroupCreateInfoKHR::default()
.ty(vk::RayTracingShaderGroupTypeKHR::GENERAL)
.general_shader(0)
.closest_hit_shader(vk::SHADER_UNUSED_KHR)
.any_hit_shader(vk::SHADER_UNUSED_KHR)
.intersection_shader(vk::SHADER_UNUSED_KHR),
// group1 = [ miss ]
vk::RayTracingShaderGroupCreateInfoKHR::default()
.ty(vk::RayTracingShaderGroupTypeKHR::GENERAL)
.general_shader(1)
.closest_hit_shader(vk::SHADER_UNUSED_KHR)
.any_hit_shader(vk::SHADER_UNUSED_KHR)
.intersection_shader(vk::SHADER_UNUSED_KHR),
// group2 = [ chit ]
vk::RayTracingShaderGroupCreateInfoKHR::default()
.ty(vk::RayTracingShaderGroupTypeKHR::PROCEDURAL_HIT_GROUP)
.general_shader(vk::SHADER_UNUSED_KHR)
.closest_hit_shader(3)
.any_hit_shader(vk::SHADER_UNUSED_KHR)
.intersection_shader(2),
];

let shader_stages = vec![
vk::PipelineShaderStageCreateInfo::default()
.stage(vk::ShaderStageFlags::RAYGEN_KHR)
.module(shader_module)
.name(std::ffi::CStr::from_bytes_with_nul(b"main_ray_generation\0").unwrap()),
vk::PipelineShaderStageCreateInfo::default()
.stage(vk::ShaderStageFlags::MISS_KHR)
.module(shader_module)
.name(std::ffi::CStr::from_bytes_with_nul(b"main_miss\0").unwrap()),
vk::PipelineShaderStageCreateInfo::default()
.stage(vk::ShaderStageFlags::INTERSECTION_KHR)
.module(shader_module)
.name(std::ffi::CStr::from_bytes_with_nul(b"main_intersection\0").unwrap()),
vk::PipelineShaderStageCreateInfo::default()
.stage(vk::ShaderStageFlags::CLOSEST_HIT_KHR)
.module(shader_module)
.name(std::ffi::CStr::from_bytes_with_nul(b"main_closest_hit\0").unwrap()),
];

let pipeline = unsafe {
rt_pipeline.create_ray_tracing_pipelines(
vk::DeferredOperationKHR::null(),
vk::PipelineCache::null(),
&[vk::RayTracingPipelineCreateInfoKHR::default()
.stages(&shader_stages)
.groups(&shader_groups)
.max_pipeline_ray_recursion_depth(0)
.layout(pipeline_layout)],
None,
)
}
.unwrap()[0];

unsafe {
device.destroy_shader_module(shader_module, None);
}

(
descriptor_set_layout,
pipeline,
pipeline_layout,
shader_groups.len(),
)
}

设置 descriptor

我们在这里需要创建一个 descriptor 来传递给 shader,如前面所说,我们需要传递 TLAS、输出图像和材质 buffer。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
let descriptor_sizes = [
vk::DescriptorPoolSize {
ty: vk::DescriptorType::ACCELERATION_STRUCTURE_KHR,
descriptor_count: 1,
},
vk::DescriptorPoolSize {
ty: vk::DescriptorType::STORAGE_IMAGE,
descriptor_count: 1,
},
vk::DescriptorPoolSize {
ty: vk::DescriptorType::STORAGE_BUFFER,
descriptor_count: 1,
},
];

let descriptor_pool_info = vk::DescriptorPoolCreateInfo::default()
.pool_sizes(&descriptor_sizes)
.max_sets(1);

let descriptor_pool =
unsafe { device.create_descriptor_pool(&descriptor_pool_info, None) }.unwrap();

let descriptor_counts = [1];

let mut count_allocate_info = vk::DescriptorSetVariableDescriptorCountAllocateInfo::default()
.descriptor_counts(&descriptor_counts);

let descriptor_sets = unsafe {
device.allocate_descriptor_sets(
&vk::DescriptorSetAllocateInfo::default()
.descriptor_pool(descriptor_pool)
.set_layouts(&[descriptor_set_layout])
.push_next(&mut count_allocate_info),
)
}
.unwrap();

let descriptor_set = descriptor_sets[0];

let accel_structs = [top_as];
let mut accel_info = vk::WriteDescriptorSetAccelerationStructureKHR::default()
.acceleration_structures(&accel_structs);

let mut accel_write = vk::WriteDescriptorSet::default()
.dst_set(descriptor_set)
.dst_binding(0)
.dst_array_element(0)
.descriptor_type(vk::DescriptorType::ACCELERATION_STRUCTURE_KHR)
.push_next(&mut accel_info);

// This is only set by the default for images, buffers, or views; need to set explicitly after
accel_write.descriptor_count = 1;

// 省略了 image_view 的创建
let image_info = [vk::DescriptorImageInfo::default()
.image_layout(vk::ImageLayout::GENERAL)
.image_view(image_view)];

let image_write = vk::WriteDescriptorSet::default()
.dst_set(descriptor_set)
.dst_binding(1)
.dst_array_element(0)
.descriptor_type(vk::DescriptorType::STORAGE_IMAGE)
.image_info(&image_info);

let buffer_info = [vk::DescriptorBufferInfo::default()
.buffer(material_buffer.buffer)
.range(vk::WHOLE_SIZE)];

let buffers_write = vk::WriteDescriptorSet::default()
.dst_set(descriptor_set)
.dst_binding(2)
.dst_array_element(0)
.descriptor_type(vk::DescriptorType::STORAGE_BUFFER)
.buffer_info(&buffer_info);

unsafe {
device.update_descriptor_sets(&[accel_write, image_write, buffers_write], &[]);
}

创建 shader binding table

在流水线中为 SBT 创建一块 buffer,用于储存 shader 的信息。虽说并非所有的 shader record 的信息都需要内存连续,但是为了方便起见,我们在这里将它们放在一起。

为了让 vkGetRayTracingShaderGroupHandlesKHR 能够正确方便地获得 shader 信息,我们在这里要将 shader 进行字节对齐。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
fn aligned_size(value: u32, alignment: u32) -> u32 {
(value + alignment - 1) & !(alignment - 1)
}

let shader_binding_table_buffer = {
let incoming_table_data = unsafe {
rt_pipeline.get_ray_tracing_shader_group_handles(
graphics_pipeline,
0,
shader_groups_len as u32,
shader_groups_len * rt_pipeline_properties.shader_group_handle_size as usize,
)
}
.unwrap();

// 对齐 shader handle
let handle_size_aligned = aligned_size(
rt_pipeline_properties.shader_group_handle_size,
rt_pipeline_properties.shader_group_base_alignment,
);

let table_size = shader_groups_len * handle_size_aligned as usize;
let mut table_data = vec![0u8; table_size];

// 再次配置
for i in 0..shader_groups_len {
table_data[i * handle_size_aligned as usize
..i * handle_size_aligned as usize
+ rt_pipeline_properties.shader_group_handle_size as usize]
.copy_from_slice(
&incoming_table_data[i * rt_pipeline_properties.shader_group_handle_size
as usize
..i * rt_pipeline_properties.shader_group_handle_size as usize
+ rt_pipeline_properties.shader_group_handle_size as usize],
);
}

let mut shader_binding_table_buffer = BufferResource::new(
table_size as u64,
vk::BufferUsageFlags::SHADER_DEVICE_ADDRESS
| vk::BufferUsageFlags::SHADER_BINDING_TABLE_KHR
| vk::BufferUsageFlags::STORAGE_BUFFER,
vk::MemoryPropertyFlags::HOST_VISIBLE
| vk::MemoryPropertyFlags::HOST_COHERENT
| vk::MemoryPropertyFlags::DEVICE_LOCAL,
&device,
device_memory_properties,
);

shader_binding_table_buffer.store(&table_data, &device);

shader_binding_table_buffer
};

调用 vkCmdTraceRaysKHR

至此,我们所有的前置准备都已经完成,终于可以调用光线追踪的命令了,我们在这里连续调用 100 次。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
{
let handle_size_aligned = aligned_size(
rt_pipeline_properties.shader_group_handle_size,
rt_pipeline_properties.shader_group_base_alignment,
) as u64;

// |[ raygen shader ]|[ miss shader ]|[ hit shader ]|
// | | | |
// | 0 | 1 | 2 |

let sbt_address =
unsafe { get_buffer_device_address(&device, shader_binding_table_buffer.buffer) };

let sbt_raygen_region = vk::StridedDeviceAddressRegionKHR::default()
.device_address(sbt_address + 0)
.size(handle_size_aligned)
.stride(handle_size_aligned);

let sbt_miss_region = vk::StridedDeviceAddressRegionKHR::default()
.device_address(sbt_address + 1 * handle_size_aligned)
.size(handle_size_aligned)
.stride(handle_size_aligned);

let sbt_hit_region = vk::StridedDeviceAddressRegionKHR::default()
.device_address(sbt_address + 2 * handle_size_aligned)
.size(handle_size_aligned)
.stride(handle_size_aligned);

let sbt_call_region = vk::StridedDeviceAddressRegionKHR::default();

let command_buffer = {
let command_buffer_allocate_info = vk::CommandBufferAllocateInfo::default()
.command_buffer_count(1)
.command_pool(command_pool)
.level(vk::CommandBufferLevel::PRIMARY);

unsafe { device.allocate_command_buffers(&command_buffer_allocate_info) }
.expect("Failed to allocate Command Buffers!")[0]
};

{
let command_buffer_begin_info = vk::CommandBufferBeginInfo::default()
.flags(vk::CommandBufferUsageFlags::SIMULTANEOUS_USE);

unsafe { device.begin_command_buffer(command_buffer, &command_buffer_begin_info) }
.expect("Failed to begin recording Command Buffer at beginning!");
}
unsafe {
let range = vk::ImageSubresourceRange::default()
.aspect_mask(vk::ImageAspectFlags::COLOR)
.base_mip_level(0)
.level_count(1)
.base_array_layer(0)
.layer_count(1);

device.cmd_clear_color_image(
command_buffer,
image,
vk::ImageLayout::GENERAL,
&vk::ClearColorValue {
float32: [0.0, 0.0, 0.0, 0.0],
},
&[range],
);

let image_barrier = vk::ImageMemoryBarrier::default()
.src_access_mask(vk::AccessFlags::COLOR_ATTACHMENT_WRITE)
.dst_access_mask(vk::AccessFlags::SHADER_WRITE | vk::AccessFlags::SHADER_READ)
.old_layout(vk::ImageLayout::GENERAL)
.new_layout(vk::ImageLayout::GENERAL)
.image(image)
.subresource_range(
vk::ImageSubresourceRange::default()
.aspect_mask(vk::ImageAspectFlags::COLOR)
.base_mip_level(0)
.level_count(1)
.base_array_layer(0)
.layer_count(1),
);

device.cmd_pipeline_barrier(
command_buffer,
vk::PipelineStageFlags::COLOR_ATTACHMENT_OUTPUT,
vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR,
vk::DependencyFlags::empty(),
&[],
&[],
&[image_barrier],
);

device.end_command_buffer(command_buffer).unwrap();
}

let command_buffers = [command_buffer];

let submit_infos = [vk::SubmitInfo::default().command_buffers(&command_buffers)];

unsafe {
device
.queue_submit(graphics_queue, &submit_infos, vk::Fence::null())
.expect("Failed to execute queue submit.");

device.queue_wait_idle(graphics_queue).unwrap();
device.free_command_buffers(command_pool, &[command_buffer]);
}

let image_barrier2 = vk::ImageMemoryBarrier::default()
.src_access_mask(vk::AccessFlags::SHADER_WRITE | vk::AccessFlags::SHADER_READ)
.dst_access_mask(vk::AccessFlags::SHADER_WRITE | vk::AccessFlags::SHADER_READ)
.old_layout(vk::ImageLayout::GENERAL)
.new_layout(vk::ImageLayout::GENERAL)
.image(image)
.subresource_range(
vk::ImageSubresourceRange::default()
.aspect_mask(vk::ImageAspectFlags::COLOR)
.base_mip_level(0)
.level_count(1)
.base_array_layer(0)
.layer_count(1),
);

let mut rng = StdRng::from_entropy();
let mut sampled = 0;

let command_buffer = {
let command_buffer_allocate_info = vk::CommandBufferAllocateInfo::default()
.command_buffer_count(1)
.command_pool(command_pool)
.level(vk::CommandBufferLevel::PRIMARY);

unsafe { device.allocate_command_buffers(&command_buffer_allocate_info) }
.expect("Failed to allocate Command Buffers!")[0]
};

while sampled < N_SAMPLES {
// 进行 `N_SAMPLES`(100) 次光线追踪
let samples = std::cmp::min(N_SAMPLES - sampled, N_SAMPLES_ITER);
sampled += samples;

{
let command_buffer_begin_info = vk::CommandBufferBeginInfo::default()
.flags(vk::CommandBufferUsageFlags::SIMULTANEOUS_USE);

unsafe { device.begin_command_buffer(command_buffer, &command_buffer_begin_info) }
.expect("Failed to begin recording Command Buffer at beginning!");
}

unsafe {
device.cmd_bind_pipeline(
command_buffer,
vk::PipelineBindPoint::RAY_TRACING_KHR,
graphics_pipeline,
);
device.cmd_bind_descriptor_sets(
command_buffer,
vk::PipelineBindPoint::RAY_TRACING_KHR,
pipeline_layout,
0,
&[descriptor_set],
&[],
);
}
for _ in 0..samples {
unsafe {
device.cmd_pipeline_barrier(
command_buffer,
vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR,
vk::PipelineStageFlags::RAY_TRACING_SHADER_KHR,
vk::DependencyFlags::empty(),
&[],
&[],
&[image_barrier2],
);

// 指定 push constant 作为随机数种子的一部分
device.cmd_push_constants(
command_buffer,
pipeline_layout,
vk::ShaderStageFlags::RAYGEN_KHR,
0,
&rng.next_u32().to_le_bytes(),
);

// 按照 WIDTH * HEIGHT 进行并行
rt_pipeline.cmd_trace_rays(
command_buffer,
&sbt_raygen_region,
&sbt_miss_region,
&sbt_hit_region,
&sbt_call_region,
WIDTH,
HEIGHT,
1,
);
}
}
unsafe {
device.end_command_buffer(command_buffer).unwrap();

let command_buffers = [command_buffer];

let submit_infos = [vk::SubmitInfo::default().command_buffers(&command_buffers)];

device
.queue_submit(graphics_queue, &submit_infos, vk::Fence::null())
.expect("Failed to execute queue submit.");

device.queue_wait_idle(graphics_queue).unwrap();
}
eprint!("\rSamples: {} / {} ", sampled, N_SAMPLES);
}
unsafe {
device.free_command_buffers(command_pool, &[command_buffer]);
}
eprint!("\nDone");
}

最终我们得到了一个渲染结果,如下图