mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Wrapping code to call the custom op. (#225)
* Wrapping code to call the custom op. * Get the rms example to work. * Get around rustfmt failing in the CI. * Fix the rms computation.
This commit is contained in:
@ -4,6 +4,8 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
mod cuda_kernels;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
use candle::backend::BackendStorage;
|
||||
@ -40,17 +42,30 @@ impl CustomOp1 for LayerNorm {
|
||||
s: &candle::CudaStorage,
|
||||
l: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
let device = s.device().clone();
|
||||
use candle::cuda_backend::{cudarc, WrapErr};
|
||||
use cudarc::driver::{LaunchAsync, LaunchConfig};
|
||||
let (d1, d2) = l.shape().dims2()?;
|
||||
let d1 = d1 as u32;
|
||||
let d2 = d2 as u32;
|
||||
let dev = s.device().clone();
|
||||
let s = s.as_cuda_slice::<f32>()?;
|
||||
let s = match l.contiguous_offsets() {
|
||||
None => Err(Error::Wrapped("input has to be contiguous".into()))?,
|
||||
Some((o1, o2)) => s, // TODO: slice with o1 and o2
|
||||
Some((o1, o2)) => s.slice(o1..o2),
|
||||
};
|
||||
let s: std::result::Result<_, candle::cuda_backend::CudaError> =
|
||||
s.try_clone().map_err(|v| v.into());
|
||||
let s = s?;
|
||||
let s = candle::CudaStorage::wrap_cuda_slice(s, device);
|
||||
Ok((s, l.shape().clone()))
|
||||
let elem_count = l.shape().elem_count();
|
||||
let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
|
||||
let func = dev.get_or_load_func("rms_f32", cuda_kernels::LAYERNORM_KERNELS)?;
|
||||
let params = (&dst, &s, 1e-5f32, d1, d2);
|
||||
let cfg = LaunchConfig {
|
||||
grid_dim: (d1, 1, 1),
|
||||
block_dim: (d2, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
|
||||
let dst = candle::CudaStorage::wrap_cuda_slice(dst, dev);
|
||||
Ok((dst, l.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user