mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add the layernorm specialized op. (#2212)
* Add the layernorm cuda kernels. * Dedicated layer norm op. * Add the slower variant. * Plug the cuda implementation. * Add the metal variant. * Add a dedicated test. * Bugfix.
This commit is contained in:
@ -16,7 +16,7 @@ mod error;
|
||||
mod utils;
|
||||
pub use device::{CudaDevice, DeviceId};
|
||||
pub use error::{CudaError, WrapErr};
|
||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};
|
||||
|
||||
pub enum SlicePtrOrNull<T> {
|
||||
Ptr(CudaSlice<T>),
|
||||
|
@ -54,6 +54,44 @@ pub trait Map2 {
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map3 {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src1: &CudaSlice<T>,
|
||||
layout1: &Layout,
|
||||
src2: &CudaSlice<T>,
|
||||
layout2: &Layout,
|
||||
src3: &CudaSlice<T>,
|
||||
layout3: &Layout,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaSlice<T>>;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn map(
|
||||
&self,
|
||||
s1: &S,
|
||||
l1: &Layout,
|
||||
s2: &S,
|
||||
l2: &Layout,
|
||||
s3: &S,
|
||||
l3: &Layout,
|
||||
d: &CudaDevice,
|
||||
) -> Result<S> {
|
||||
let out = match (s1, s2, s3) {
|
||||
(S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
(S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
|
||||
};
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait Map2InPlace {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
|
Reference in New Issue
Block a user