From b2956857efcd7aecc6e53f53d761503ef118d3be Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 3 Oct 2024 12:43:08 +0200 Subject: [PATCH] More cuda graph attempts. --- candle-core/examples/cuda_basics.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 31db9e81..0cd933b5 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -17,9 +17,13 @@ fn cuda_graph() -> Result<()> { { // load_ptx cannot be called while capturing the stream so we need this to happen // beforehand. - let x = Tensor::zeros(16, candle_core::DType::F32, &device)?; - let y = Tensor::zeros(16, candle_core::DType::F32, &device)?; - y.slice_set(&x, 0, 0)?; + let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let _x = x.mul(&u)?.broadcast_add(&v)?; device.synchronize()?; } unsafe { @@ -31,10 +35,15 @@ fn cuda_graph() -> Result<()> { .result()? }; { - let x = Tensor::zeros(16, candle_core::DType::F32, &device)?; - let y = Tensor::zeros(16, candle_core::DType::F32, &device)?; - y.slice_set(&x, 0, 0)?; - // let y = x.affine(2., 1.)?; + let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + for _i in 0..1 { + x = x.mul(&u)?.broadcast_add(&v)?; + } } let cu_graph = unsafe { let mut cu_graph = std::mem::MaybeUninit::uninit();