mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Tweaks to the graph experiment.
This commit is contained in:
@ -7,6 +7,8 @@ extern crate intel_mkl_src;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
|
const USE_CUDA_GRAPH: bool = true;
|
||||||
|
|
||||||
fn cuda_graph() -> Result<()> {
|
fn cuda_graph() -> Result<()> {
|
||||||
let device = Device::new_cuda_with_stream(0)?;
|
let device = Device::new_cuda_with_stream(0)?;
|
||||||
let cu_device = match &device {
|
let cu_device = match &device {
|
||||||
@ -24,8 +26,11 @@ fn cuda_graph() -> Result<()> {
|
|||||||
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
||||||
|
let _x = x.affine(1., 0.5)?;
|
||||||
|
x.slice_set(&u, 0, 0)?;
|
||||||
device.synchronize()?;
|
device.synchronize()?;
|
||||||
}
|
}
|
||||||
|
if USE_CUDA_GRAPH {
|
||||||
unsafe {
|
unsafe {
|
||||||
cudarc::driver::sys::lib()
|
cudarc::driver::sys::lib()
|
||||||
.cuStreamBeginCapture_v2(
|
.cuStreamBeginCapture_v2(
|
||||||
@ -33,38 +38,53 @@ fn cuda_graph() -> Result<()> {
|
|||||||
cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||||
)
|
)
|
||||||
.result()?
|
.result()?
|
||||||
};
|
}
|
||||||
|
}
|
||||||
{
|
{
|
||||||
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
for _i in 0..1 {
|
for _i in 0..100 {
|
||||||
x = x.mul(&u)?.broadcast_add(&v)?;
|
// x.slice_set(&u, 0, 0)?;
|
||||||
|
// x.broadcast_add(&v)?;
|
||||||
|
x = x.affine(1., 0.5)?;
|
||||||
|
// x = (&u + &x)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let cu_graph = unsafe {
|
if USE_CUDA_GRAPH {
|
||||||
|
let cu_graph: cudarc::driver::sys::CUgraph = unsafe {
|
||||||
let mut cu_graph = std::mem::MaybeUninit::uninit();
|
let mut cu_graph = std::mem::MaybeUninit::uninit();
|
||||||
cudarc::driver::sys::lib()
|
cudarc::driver::sys::lib()
|
||||||
.cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr())
|
.cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr())
|
||||||
.result()?;
|
.result()?;
|
||||||
cu_graph.assume_init()
|
cu_graph.assume_init()
|
||||||
};
|
};
|
||||||
let cu_graph_e = unsafe {
|
let cu_graph_e: cudarc::driver::sys::CUgraphExec = unsafe {
|
||||||
let mut cu_graph_e = std::mem::MaybeUninit::uninit();
|
let mut cu_graph_e = std::mem::MaybeUninit::uninit();
|
||||||
cudarc::driver::sys::lib()
|
cudarc::driver::sys::lib()
|
||||||
.cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0)
|
.cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0)
|
||||||
.result()?;
|
.result()?;
|
||||||
cu_graph_e.assume_init()
|
cu_graph_e.assume_init()
|
||||||
};
|
};
|
||||||
for _i in 0..100 {
|
println!("graph captured!");
|
||||||
|
for i in 1..100 {
|
||||||
|
println!("graph exec {i}");
|
||||||
unsafe {
|
unsafe {
|
||||||
cudarc::driver::sys::lib()
|
cudarc::driver::sys::lib()
|
||||||
.cuGraphLaunch(cu_graph_e, *cu_stream)
|
.cuGraphLaunch(cu_graph_e, *cu_stream)
|
||||||
.result()?
|
.result()?
|
||||||
}
|
}
|
||||||
|
println!("sync");
|
||||||
|
if let Err(err) = device.synchronize() {
|
||||||
|
println!("err: {err:?}")
|
||||||
|
}
|
||||||
|
println!("done syncing");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.synchronize()?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user