More cuda graph attempts.

This commit is contained in:
laurent
2024-10-03 12:43:08 +02:00
parent 9076dee432
commit b2956857ef

View File

@ -17,9 +17,13 @@ fn cuda_graph() -> Result<()> {
{
// load_ptx cannot be called while the stream is being captured, so we need it to happen
// beforehand, i.e. before the capture starts.
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
y.slice_set(&x, 0, 0)?;
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let _x = x.mul(&u)?.broadcast_add(&v)?;
device.synchronize()?;
}
unsafe {
@ -31,10 +35,15 @@ fn cuda_graph() -> Result<()> {
.result()?
};
{
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
y.slice_set(&x, 0, 0)?;
// let y = x.affine(2., 1.)?;
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
.to_dtype(candle_core::DType::BF16)?;
for _i in 0..1 {
x = x.mul(&u)?.broadcast_add(&v)?;
}
}
let cu_graph = unsafe {
let mut cu_graph = std::mem::MaybeUninit::uninit();