Tweaks to the graph experiment.

This commit is contained in:
laurent
2024-10-03 17:12:52 +02:00
parent b2956857ef
commit 1bb68854d3

View File

@@ -7,6 +7,8 @@ extern crate intel_mkl_src;
use anyhow::Result; use anyhow::Result;
use candle_core::{Device, Tensor}; use candle_core::{Device, Tensor};
const USE_CUDA_GRAPH: bool = true;
fn cuda_graph() -> Result<()> { fn cuda_graph() -> Result<()> {
let device = Device::new_cuda_with_stream(0)?; let device = Device::new_cuda_with_stream(0)?;
let cu_device = match &device { let cu_device = match &device {
@@ -24,8 +26,11 @@ fn cuda_graph() -> Result<()> {
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?; .to_dtype(candle_core::DType::BF16)?;
let _x = x.mul(&u)?.broadcast_add(&v)?; let _x = x.mul(&u)?.broadcast_add(&v)?;
let _x = x.affine(1., 0.5)?;
x.slice_set(&u, 0, 0)?;
device.synchronize()?; device.synchronize()?;
} }
if USE_CUDA_GRAPH {
unsafe { unsafe {
cudarc::driver::sys::lib() cudarc::driver::sys::lib()
.cuStreamBeginCapture_v2( .cuStreamBeginCapture_v2(
@@ -33,38 +38,53 @@ fn cuda_graph() -> Result<()> {
cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL, cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
) )
.result()? .result()?
}; }
}
{ {
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?; .to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?; .to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?; .to_dtype(candle_core::DType::BF16)?;
for _i in 0..1 { for _i in 0..100 {
x = x.mul(&u)?.broadcast_add(&v)?; // x.slice_set(&u, 0, 0)?;
// x.broadcast_add(&v)?;
x = x.affine(1., 0.5)?;
// x = (&u + &x)?;
} }
} }
let cu_graph = unsafe { if USE_CUDA_GRAPH {
let cu_graph: cudarc::driver::sys::CUgraph = unsafe {
let mut cu_graph = std::mem::MaybeUninit::uninit(); let mut cu_graph = std::mem::MaybeUninit::uninit();
cudarc::driver::sys::lib() cudarc::driver::sys::lib()
.cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr()) .cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr())
.result()?; .result()?;
cu_graph.assume_init() cu_graph.assume_init()
}; };
let cu_graph_e = unsafe { let cu_graph_e: cudarc::driver::sys::CUgraphExec = unsafe {
let mut cu_graph_e = std::mem::MaybeUninit::uninit(); let mut cu_graph_e = std::mem::MaybeUninit::uninit();
cudarc::driver::sys::lib() cudarc::driver::sys::lib()
.cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0) .cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0)
.result()?; .result()?;
cu_graph_e.assume_init() cu_graph_e.assume_init()
}; };
for _i in 0..100 { println!("graph captured!");
for i in 1..100 {
println!("graph exec {i}");
unsafe { unsafe {
cudarc::driver::sys::lib() cudarc::driver::sys::lib()
.cuGraphLaunch(cu_graph_e, *cu_stream) .cuGraphLaunch(cu_graph_e, *cu_stream)
.result()? .result()?
} }
println!("sync");
if let Err(err) = device.synchronize() {
println!("err: {err:?}")
}
println!("done syncing");
}
} else {
device.synchronize()?;
} }
Ok(()) Ok(())
} }