mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Cuda graph experiments.
This commit is contained in:
@ -52,3 +52,7 @@ harness = false
|
||||
[[example]]
|
||||
name = "metal_basics"
|
||||
required-features = ["metal"]
|
||||
|
||||
[[example]]
|
||||
name = "cuda_basics"
|
||||
required-features = ["cuda"]
|
||||
|
@ -7,8 +7,63 @@ extern crate intel_mkl_src;
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
fn cuda_graph() -> Result<()> {
|
||||
let device = Device::new_cuda_with_stream(0)?;
|
||||
let cu_device = match &device {
|
||||
Device::Cuda(dev) => dev,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let cu_stream = cu_device.cu_stream();
|
||||
{
|
||||
// load_ptx cannot be called while capturing the stream so we need this to happen
|
||||
// beforehand.
|
||||
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
y.slice_set(&x, 0, 0)?;
|
||||
device.synchronize()?;
|
||||
}
|
||||
unsafe {
|
||||
cudarc::driver::sys::lib()
|
||||
.cuStreamBeginCapture_v2(
|
||||
*cu_stream,
|
||||
cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||
)
|
||||
.result()?
|
||||
};
|
||||
{
|
||||
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
y.slice_set(&x, 0, 0)?;
|
||||
// let y = x.affine(2., 1.)?;
|
||||
}
|
||||
let cu_graph = unsafe {
|
||||
let mut cu_graph = std::mem::MaybeUninit::uninit();
|
||||
cudarc::driver::sys::lib()
|
||||
.cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr())
|
||||
.result()?;
|
||||
cu_graph.assume_init()
|
||||
};
|
||||
let cu_graph_e = unsafe {
|
||||
let mut cu_graph_e = std::mem::MaybeUninit::uninit();
|
||||
cudarc::driver::sys::lib()
|
||||
.cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0)
|
||||
.result()?;
|
||||
cu_graph_e.assume_init()
|
||||
};
|
||||
for _i in 0..100 {
|
||||
unsafe {
|
||||
cudarc::driver::sys::lib()
|
||||
.cuGraphLaunch(cu_graph_e, *cu_stream)
|
||||
.result()?
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
cuda_graph()?;
|
||||
return Ok(());
|
||||
let device = Device::new_cuda_with_stream(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
|
Reference in New Issue
Block a user