Add the elu op. (#113)

Laurent Mazare
2023-07-09 21:56:31 +01:00
committed by GitHub
parent ea5dfa69bc
commit 270997a055
8 changed files with 95 additions and 0 deletions
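In short: the commit threads a new `elu(alpha)` unary op end to end. The `Op` enum gains an `Elu(Tensor, f64)` variant, backprop acknowledges it (the gradient itself is left unimplemented), the CPU backend maps a scalar helper over the float dtypes, the CUDA backend launches a `uelu` kernel, the dummy backend gets the usual stub, and `Tensor::elu` exposes the op publicly.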

candle-core/src/backprop.rs

@@ -76,6 +76,7 @@ impl Tensor {
            | Op::Sqrt(node)
            | Op::Gelu(node)
            | Op::Relu(node)
            | Op::Elu(node, _)
            | Op::Exp(node)
            | Op::Log(node)
            | Op::Sin(node)
@@ -250,6 +251,7 @@ impl Tensor {
                }
                Op::Gelu(_) => return Err(Error::BackwardNotSupported { op: "gelu" }),
                Op::Relu(_) => return Err(Error::BackwardNotSupported { op: "relu" }),
                Op::Elu(..) => return Err(Error::BackwardNotSupported { op: "elu" }),
                Op::Sqr(arg) => {
                    let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
                    let sum_grad = grads.or_insert(arg)?;
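Backward support is deliberately deferred; `elu` joins `gelu` and `relu` in returning `BackwardNotSupported`. For reference, the gradient a future implementation would need is

$$\frac{d}{dx}\,\mathrm{elu}_\alpha(x) = \begin{cases} 1 & x > 0 \\ \alpha\, e^{x} & x \le 0 \end{cases}$$

and since $\alpha e^{x} = \mathrm{elu}_\alpha(x) + \alpha$ on the negative branch, it can be recovered from the forward output alone.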

candle-core/src/cpu_backend.rs

@@ -463,6 +463,14 @@ fn divide_by_sum_over_dim<T: WithDType>(s: &mut [T], shape: &Shape, dim: usize)
    Ok(())
}

fn elu<T: num_traits::Float>(v: T, alpha: T) -> T {
    if v.is_sign_positive() {
        v
    } else {
        (v.exp() - T::one()) * alpha
    }
}
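This helper implements the standard ELU activation (Clevert et al., 2015):

$$\mathrm{elu}_\alpha(x) = \begin{cases} x & x \ge 0 \\ \alpha\,(e^{x} - 1) & x < 0 \end{cases}$$

It is continuous at zero and saturates to $-\alpha$ for large negative inputs.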
impl CpuStorage {
    pub fn dtype(&self) -> DType {
        match self {
@@ -666,6 +674,30 @@ impl CpuStorage {
        Affine(mul, add).map(self, layout)
    }

    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        // TODO: Have some generic map for functions that apply on num_traits::Float elements.
        match self {
            Self::BF16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, bf16::from_f64(alpha)));
                Ok(Self::BF16(data))
            }
            Self::F16(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f16::from_f64(alpha)));
                Ok(Self::F16(data))
            }
            Self::F32(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, f32::from_f64(alpha)));
                Ok(Self::F32(data))
            }
            Self::F64(storage) => {
                let data = unary_map(storage, layout, |v| elu(v, alpha));
                Ok(Self::F64(data))
            }
            Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu")),
            Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu")),
        }
    }
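The TODO above is worth a note: the four-way dtype match will recur for every float-only op. A minimal sketch of one way to factor it out, using a trait with a generic method so each op is written once (all names here are hypothetical, not part of this commit):

use num_traits::Float;

// Hypothetical stand-in for the float variants of CpuStorage.
enum FloatStorage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// A trait object cannot have a generic method, but static dispatch can:
// each op implements `f` once, generically over the element type.
trait FloatMap {
    fn f<T: Float>(&self, v: T) -> T;
}

struct EluOp(f64);
impl FloatMap for EluOp {
    fn f<T: Float>(&self, v: T) -> T {
        // `Float: NumCast`, so the f64 alpha converts to any float dtype.
        let alpha = T::from(self.0).unwrap();
        if v.is_sign_positive() { v } else { (v.exp() - T::one()) * alpha }
    }
}

impl FloatStorage {
    // One dtype match, shared by every FloatMap op.
    fn map(&self, op: &impl FloatMap) -> FloatStorage {
        match self {
            FloatStorage::F32(s) => FloatStorage::F32(s.iter().map(|&v| op.f(v)).collect()),
            FloatStorage::F64(s) => FloatStorage::F64(s.iter().map(|&v| op.f(v)).collect()),
        }
    }
}

This is essentially the `Map1` pattern the CUDA backend below already uses.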
    pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
        match self {
            Self::BF16(storage) => {

candle-core/src/cuda_backend.rs

@@ -357,6 +357,30 @@ impl Map1 for Affine {
    }
}

struct Elu(f64);
impl Map1 for Elu {
    fn f<T: DeviceRepr + WithDType>(
        &self,
        src: &CudaSlice<T>,
        dev: &CudaDevice,
        layout: &Layout,
    ) -> Result<CudaSlice<T>> {
        let shape = layout.shape();
        let dims = shape.dims();
        let el = shape.elem_count();
        let cfg = LaunchConfig::for_num_elems(el as u32);
        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
        let src = &src.slice(layout.start_offset()..);
        let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), kernels::UNARY)?;
        // SAFETY: Set later by running the kernel.
        let out = unsafe { dev.alloc::<T>(el) }?;
        let params = (el, dims.len(), &ds, T::from_f64(self.0), src, &out);
        // SAFETY: ffi.
        unsafe { func.launch(cfg, params) }?;
        Ok(out)
    }
}
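Note that the element-wise kernel itself is not in this hunk: `get_or_load_func` looks up a precompiled function (named `uelu` plus a dtype suffix via `kernel_name::<T>`) in the `kernels::UNARY` PTX that the candle-kernels crate provides, so the CUDA source for `uelu` lives there rather than in this file.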
#[allow(dead_code)]
struct Sum<'a>(&'a [usize]);
impl<'a> Map1 for Sum<'a> {
@@ -815,6 +839,12 @@ impl CudaStorage {
        Ok(Self { slice, device })
    }

    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        let device = self.device().clone();
        let slice = Elu(alpha).map(&self.slice, &device, layout)?;
        Ok(Self { slice, device })
    }

    pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
        let device = self.device().clone();
        let slice = FastSum(sum_dims).map(&self.slice, &device, layout)?;

candle-core/src/dummy_cuda_backend.rs

@@ -64,6 +64,10 @@ impl CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    pub(crate) fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
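These stubs exist so that builds without the CUDA feature still type-check every `Storage` method; calling them simply surfaces `NotCompiledWithCudaSupport` at runtime.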
    pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }

candle-core/src/error.rs

@@ -127,6 +127,9 @@ pub enum Error {
    #[error("unsupported safetensor dtype {0:?}")]
    UnsupportedSafeTensorDtype(safetensors::Dtype),

    #[error("unsupported dtype {0:?} for op {1}")]
    UnsupportedDTypeForOp(DType, &'static str),

    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },

candle-core/src/op.rs

@@ -46,6 +46,7 @@ pub(crate) enum Op {
    Transpose(Tensor, usize, usize),
    Gelu(Tensor),
    Relu(Tensor),
    Elu(Tensor, f64),
    // TODO: Support for custom ops.
}
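Keeping both the input tensor and `alpha` in the variant means the backward pass has everything it needs once an actual `elu` gradient lands.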

candle-core/src/storage.rs

@@ -66,6 +66,19 @@ impl Storage {
        }
    }

    pub(crate) fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {
                let storage = storage.elu(layout, alpha)?;
                Ok(Self::Cpu(storage))
            }
            Storage::Cuda(storage) => {
                let storage = storage.elu(layout, alpha)?;
                Ok(Self::Cuda(storage))
            }
        }
    }

    pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
        match self {
            Storage::Cpu(storage) => {

candle-core/src/tensor.rs

@@ -349,6 +349,16 @@ impl Tensor {
        Ok(from_storage(storage, self.shape(), op, false))
    }

    pub fn elu(&self, alpha: f64) -> Result<Self> {
        let storage = self.storage.elu(self.layout(), alpha)?;
        let op = if self.track_op() {
            Some(Op::Elu(self.clone(), alpha))
        } else {
            None
        };
        Ok(from_storage(storage, self.shape(), op, false))
    }
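Finally, a quick usage sketch of the new method. The import path and the `Tensor::new`/`to_vec1` calls reflect candle's public API and are illustrative here, not part of the diff:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // With alpha = 1.0, negatives saturate towards -1 and positives pass through.
    let t = Tensor::new(&[-2.0f32, -0.5, 0.0, 3.0], &Device::Cpu)?;
    let r = t.elu(1.0)?;
    // Approximately [-0.8647, -0.3935, 0.0, 3.0].
    println!("{:?}", r.to_vec1::<f32>()?);
    Ok(())
}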
    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
        if dim >= self.dims().len() {
            Err(Error::DimOutOfRange {