mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
More cuda graph attempts.
This commit is contained in:
@ -17,9 +17,13 @@ fn cuda_graph() -> Result<()> {
|
||||
{
|
||||
// load_ptx cannot be called while capturing the stream so we need this to happen
|
||||
// beforehand.
|
||||
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
y.slice_set(&x, 0, 0)?;
|
||||
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
||||
device.synchronize()?;
|
||||
}
|
||||
unsafe {
|
||||
@ -31,10 +35,15 @@ fn cuda_graph() -> Result<()> {
|
||||
.result()?
|
||||
};
|
||||
{
|
||||
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
||||
y.slice_set(&x, 0, 0)?;
|
||||
// let y = x.affine(2., 1.)?;
|
||||
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
for _i in 0..1 {
|
||||
x = x.mul(&u)?.broadcast_add(&v)?;
|
||||
}
|
||||
}
|
||||
let cu_graph = unsafe {
|
||||
let mut cu_graph = std::mem::MaybeUninit::uninit();
|
||||
|
Reference in New Issue
Block a user