mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Compare commits
5 Commits
0.9.1
...
cuda-graph
Author | SHA1 | Date | |
---|---|---|---|
543b5b5898 | |||
c87f0fa5d6 | |||
1bb68854d3 | |||
b2956857ef | |||
9076dee432 |
@ -42,7 +42,7 @@ clap = { workspace = true }
|
|||||||
criterion = { workspace = true }
|
criterion = { workspace = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = ["cuda"]
|
||||||
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
@ -56,3 +56,7 @@ harness = false
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "metal_basics"
|
name = "metal_basics"
|
||||||
required-features = ["metal"]
|
required-features = ["metal"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "cuda_basics"
|
||||||
|
required-features = ["cuda"]
|
||||||
|
@ -7,8 +7,79 @@ extern crate intel_mkl_src;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
|
const USE_CUDA_GRAPH: bool = true;
|
||||||
|
|
||||||
|
fn cuda_graph() -> Result<()> {
|
||||||
|
let device = Device::new_cuda_with_stream(0)?;
|
||||||
|
let cu_device = match &device {
|
||||||
|
Device::Cuda(dev) => dev,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let cu_stream = cu_device.cuda_stream();
|
||||||
|
{
|
||||||
|
// load_ptx cannot be called while capturing the stream so we need this to happen
|
||||||
|
// beforehand.
|
||||||
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
||||||
|
let _x = x.affine(1., 0.5)?;
|
||||||
|
x.slice_set(&u, 0, 0)?;
|
||||||
|
device.synchronize()?;
|
||||||
|
}
|
||||||
|
if USE_CUDA_GRAPH {
|
||||||
|
cu_stream.begin_capture(
|
||||||
|
cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
for _i in 0..100 {
|
||||||
|
// x.slice_set(&u, 0, 0)?;
|
||||||
|
// x.broadcast_add(&v)?;
|
||||||
|
x = x.affine(1., 0.5)?;
|
||||||
|
// x = (&u + &x)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if USE_CUDA_GRAPH {
|
||||||
|
println!("capturing graph");
|
||||||
|
let cu_graph = match cu_stream.end_capture(
|
||||||
|
cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
|
||||||
|
)? {
|
||||||
|
None => anyhow::bail!("no graph captured"),
|
||||||
|
Some(cu_graph) => cu_graph,
|
||||||
|
};
|
||||||
|
println!("graph captured!");
|
||||||
|
for i in 1..100 {
|
||||||
|
println!("graph exec {i}");
|
||||||
|
cu_graph.launch()?;
|
||||||
|
println!("sync");
|
||||||
|
if let Err(err) = device.synchronize() {
|
||||||
|
println!("err: {err:?}")
|
||||||
|
}
|
||||||
|
println!("done syncing");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.synchronize()?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
cuda_graph()?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _matmul() -> Result<()> {
|
||||||
|
let device = Device::new_cuda_with_stream(0)?;
|
||||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||||
|
Reference in New Issue
Block a user