Compare commits


5 Commits

SHA1 Message Date
543b5b5898 Update for the latest cudarc. 2025-04-11 14:02:41 +02:00
c87f0fa5d6 Merge remote-tracking branch 'origin/main' into cuda-graph-exp 2025-04-11 13:47:35 +02:00
1bb68854d3 Tweaks to the graph experiment. 2024-10-03 17:12:52 +02:00
b2956857ef More cuda graph attempts. 2024-10-03 12:43:08 +02:00
9076dee432 Cuda graph experiments. 2024-10-03 08:43:00 +02:00
2 changed files with 77 additions and 2 deletions

candle-core/Cargo.toml

@@ -42,7 +42,7 @@ clap = { workspace = true }
criterion = { workspace = true }
[features]
-default = []
+default = ["cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
@@ -56,3 +56,7 @@ harness = false
[[example]]
name = "metal_basics"
required-features = ["metal"]
+[[example]]
+name = "cuda_basics"
+required-features = ["cuda"]

candle-core/examples/cuda_basics.rs

@@ -7,8 +7,79 @@ extern crate intel_mkl_src;
use anyhow::Result;
use candle_core::{Device, Tensor};
const USE_CUDA_GRAPH: bool = true;
fn cuda_graph() -> Result<()> {
    let device = Device::new_cuda_with_stream(0)?;
    let cu_device = match &device {
        Device::Cuda(dev) => dev,
        _ => unreachable!(),
    };
    let cu_stream = cu_device.cuda_stream();
    {
        // load_ptx cannot be called while capturing the stream so we need this to happen
        // beforehand.
        let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
            .to_dtype(candle_core::DType::BF16)?;
        let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
            .to_dtype(candle_core::DType::BF16)?;
        let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
            .to_dtype(candle_core::DType::BF16)?;
        let _x = x.mul(&u)?.broadcast_add(&v)?;
        let _x = x.affine(1., 0.5)?;
        x.slice_set(&u, 0, 0)?;
        device.synchronize()?;
    }
    if USE_CUDA_GRAPH {
        // Start capturing: kernel launches on this stream are now recorded
        // into a graph instead of being executed.
        cu_stream.begin_capture(
            cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
        )?;
    }
    {
        // The workload below is what gets recorded into the graph.
        let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
            .to_dtype(candle_core::DType::BF16)?;
        let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
            .to_dtype(candle_core::DType::BF16)?;
        let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
            .to_dtype(candle_core::DType::BF16)?;
        for _i in 0..100 {
            // x.slice_set(&u, 0, 0)?;
            // x.broadcast_add(&v)?;
            x = x.affine(1., 0.5)?;
            // x = (&u + &x)?;
        }
    }
    if USE_CUDA_GRAPH {
        println!("capturing graph");
        // end_capture stops recording and instantiates the captured graph;
        // None means nothing was recorded on the stream.
        let cu_graph = match cu_stream.end_capture(
            cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY,
        )? {
            None => anyhow::bail!("no graph captured"),
            Some(cu_graph) => cu_graph,
        };
        println!("graph captured!");
        for i in 1..100 {
            println!("graph exec {i}");
            // A single launch replays the whole recorded kernel sequence.
            cu_graph.launch()?;
            println!("sync");
            if let Err(err) = device.synchronize() {
                println!("err: {err:?}")
            }
            println!("done syncing");
        }
    } else {
        device.synchronize()?;
    }
    Ok(())
}
fn main() -> Result<()> {
    let _device = Device::new_cuda(0)?;
    cuda_graph()?;
    Ok(())
}
fn _matmul() -> Result<()> {
    let device = Device::new_cuda_with_stream(0)?;
    let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
        .to_dtype(candle_core::DType::BF16)?;
    candle_core::cuda::set_gemm_reduced_precision_f32(false);
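
For readers skimming the diff: the experiment above boils down to recording a stream's kernel launches once and then replaying them as a unit, which avoids per-kernel launch overhead. A minimal sketch of that capture/replay flow, assuming the same cudarc stream API that cuda_basics.rs uses (capture_then_replay and run_workload are hypothetical names, not part of the diff):

use anyhow::Result;
use cudarc::driver::sys::{CUgraphInstantiate_flags, CUstreamCaptureMode};

// Hypothetical helper condensing the flow above: record the kernel launches
// issued by `run_workload` into a graph, then replay that graph `replays` times.
fn capture_then_replay(
    cu_stream: &cudarc::driver::CudaStream,
    run_workload: impl Fn() -> Result<()>,
    replays: usize,
) -> Result<()> {
    // While capturing, launches on this stream are recorded, not executed.
    cu_stream.begin_capture(CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL)?;
    run_workload()?;
    // end_capture instantiates the recorded graph; None means nothing was captured.
    let cu_graph = cu_stream
        .end_capture(CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY)?
        .ok_or_else(|| anyhow::anyhow!("no graph captured"))?;
    for _ in 0..replays {
        // A single launch replays the entire recorded kernel sequence.
        cu_graph.launch()?;
    }
    Ok(())
}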