diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index ebd2c519..f2a4ce81 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -42,7 +42,7 @@ clap = { workspace = true } criterion = { workspace = true } [features] -default = [] +default = ["cuda"] cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 315fe0a2..aea13be0 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -15,7 +15,7 @@ fn cuda_graph() -> Result<()> { Device::Cuda(dev) => dev, _ => unreachable!(), }; - let cu_stream = cu_device.cu_stream(); + let cu_stream = cu_device.cuda_stream(); { // load_ptx cannot be called while capturing the stream so we need this to happen // beforehand. @@ -31,14 +31,9 @@ fn cuda_graph() -> Result<()> { device.synchronize()?; } if USE_CUDA_GRAPH { - unsafe { - cudarc::driver::sys::lib() - .cuStreamBeginCapture_v2( - *cu_stream, - cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL, - ) - .result()? - } + cu_stream.begin_capture( + cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL, + )?; } { let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? @@ -55,28 +50,17 @@ fn cuda_graph() -> Result<()> { } } if USE_CUDA_GRAPH { - let cu_graph: cudarc::driver::sys::CUgraph = unsafe { - let mut cu_graph = std::mem::MaybeUninit::uninit(); - cudarc::driver::sys::lib() - .cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr()) - .result()?; - cu_graph.assume_init() - }; - let cu_graph_e: cudarc::driver::sys::CUgraphExec = unsafe { - let mut cu_graph_e = std::mem::MaybeUninit::uninit(); - cudarc::driver::sys::lib() - .cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0) - .result()?; - cu_graph_e.assume_init() + println!("capturing graph"); + let cu_graph = match cu_stream.end_capture( + cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY + )? { + None => anyhow::bail!("no graph captured"), + Some(cu_graph) => cu_graph, }; println!("graph captured!"); for i in 1..100 { println!("graph exec {i}"); - unsafe { - cudarc::driver::sys::lib() - .cuGraphLaunch(cu_graph_e, *cu_stream) - .result()? - } + cu_graph.launch()?; println!("sync"); if let Err(err) = device.synchronize() { println!("err: {err:?}") @@ -92,6 +76,9 @@ fn cuda_graph() -> Result<()> { fn main() -> Result<()> { cuda_graph()?; return Ok(()); +} + +fn _matmul() -> Result<()> { let device = Device::new_cuda_with_stream(0)?; let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)? .to_dtype(candle_core::DType::BF16)?;