mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
More cuda graph attempts.
This commit is contained in:
@ -17,9 +17,13 @@ fn cuda_graph() -> Result<()> {
|
|||||||
{
|
{
|
||||||
// load_ptx cannot be called while capturing the stream so we need this to happen
|
// load_ptx cannot be called while capturing the stream so we need this to happen
|
||||||
// beforehand.
|
// beforehand.
|
||||||
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
y.slice_set(&x, 0, 0)?;
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
||||||
device.synchronize()?;
|
device.synchronize()?;
|
||||||
}
|
}
|
||||||
unsafe {
|
unsafe {
|
||||||
@ -31,10 +35,15 @@ fn cuda_graph() -> Result<()> {
|
|||||||
.result()?
|
.result()?
|
||||||
};
|
};
|
||||||
{
|
{
|
||||||
let x = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
let y = Tensor::zeros(16, candle_core::DType::F32, &device)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
y.slice_set(&x, 0, 0)?;
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
// let y = x.affine(2., 1.)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
for _i in 0..1 {
|
||||||
|
x = x.mul(&u)?.broadcast_add(&v)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
let cu_graph = unsafe {
|
let cu_graph = unsafe {
|
||||||
let mut cu_graph = std::mem::MaybeUninit::uninit();
|
let mut cu_graph = std::mem::MaybeUninit::uninit();
|
||||||
|
Reference in New Issue
Block a user