Update for the latest cudarc.

This commit is contained in:
laurent
2025-04-11 14:02:41 +02:00
parent c87f0fa5d6
commit 543b5b5898
2 changed files with 15 additions and 28 deletions

View File

@@ -42,7 +42,7 @@ clap = { workspace = true }
criterion = { workspace = true }
[features]
default = []
default = ["cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]

View File

@@ -15,7 +15,7 @@ fn cuda_graph() -> Result<()> {
Device::Cuda(dev) => dev,
_ => unreachable!(),
};
let cu_stream = cu_device.cu_stream();
let cu_stream = cu_device.cuda_stream();
{
// load_ptx cannot be called while capturing the stream so we need this to happen
// beforehand.
@@ -31,14 +31,9 @@ fn cuda_graph() -> Result<()> {
device.synchronize()?;
}
if USE_CUDA_GRAPH {
unsafe {
cudarc::driver::sys::lib()
.cuStreamBeginCapture_v2(
*cu_stream,
cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
)
.result()?
}
cu_stream.begin_capture(
cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
)?;
}
{
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
@@ -55,28 +50,17 @@ fn cuda_graph() -> Result<()> {
}
}
if USE_CUDA_GRAPH {
let cu_graph: cudarc::driver::sys::CUgraph = unsafe {
let mut cu_graph = std::mem::MaybeUninit::uninit();
cudarc::driver::sys::lib()
.cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr())
.result()?;
cu_graph.assume_init()
};
let cu_graph_e: cudarc::driver::sys::CUgraphExec = unsafe {
let mut cu_graph_e = std::mem::MaybeUninit::uninit();
cudarc::driver::sys::lib()
.cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0)
.result()?;
cu_graph_e.assume_init()
println!("capturing graph");
let cu_graph = match cu_stream.end_capture(
cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
)? {
None => anyhow::bail!("no graph captured"),
Some(cu_graph) => cu_graph,
};
println!("graph captured!");
for i in 1..100 {
println!("graph exec {i}");
unsafe {
cudarc::driver::sys::lib()
.cuGraphLaunch(cu_graph_e, *cu_stream)
.result()?
}
cu_graph.launch()?;
println!("sync");
if let Err(err) = device.synchronize() {
println!("err: {err:?}")
@@ -92,6 +76,9 @@ fn cuda_graph() -> Result<()> {
/// Program entry point: runs `cuda_graph` and forwards any error it returns.
fn main() -> Result<()> {
    // Only the CUDA graph example is run here; a failure short-circuits via `?`.
    cuda_graph()?;
    Ok(())
}
fn _matmul() -> Result<()> {
let device = Device::new_cuda_with_stream(0)?;
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
.to_dtype(candle_core::DType::BF16)?;