Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)

* add sigmoid op

* small fix

* add as a method on `Tensor`

* implement gradient calculation for sigmoid

* add sigmoid tests

* we should have a specialized op for this

* fix clippy

* fix clippy 2

* Revert all previous commits in favor of a `CustomOp` based solution

* use `CustomOp1` implementation

* fix rustfmt

* experimental add metal impl

* add cuda kernel impl

* fix fmt

* Add a test + reduce some cuda duplication.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
MilkFather authored 2024-04-29 17:04:43 +08:00, committed by GitHub
parent ed7b99f525
commit 3bbb88fcb4
6 changed files with 214 additions and 5 deletions

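Before the diff, a minimal sketch of how the reworked op is exercised end to end. This is illustrative only: it assumes the `candle` and `candle-nn` crates at this commit, and the `main` function and variable names are made up for the example.

use candle::{Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Forward pass: sigmoid now dispatches to the CustomOp1 kernel for the backend.
    let xs = Var::from_tensor(&Tensor::new(&[-2f32, 0., 2.], &dev)?)?;
    let ys = candle_nn::ops::sigmoid(&xs)?;
    // Backward pass: the custom bwd yields grad_out * sigmoid(x) * (1 - sigmoid(x)).
    let grads = ys.backward()?;
    let grad_xs = grads.get(&xs).expect("no gradient for xs");
    println!("sigmoid: {ys}\ngrad: {grad_xs}");
    Ok(())
}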

@@ -43,9 +43,193 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
&xs[0].silu()? * &xs[1]
}
struct Sigmoid;
impl candle::CustomOp1 for Sigmoid {
fn name(&self) -> &'static str {
"sigmoid"
}
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
use candle::backend::BackendStorage;
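        // Elementwise sigmoid: 1 / (1 + exp(-v)).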
fn fwd<T: num_traits::Float>(v: T) -> T {
(v.neg().exp() + T::one()).recip()
}
// FIXME: using `candle::map_dtype` causes compilation errors.
let storage = match storage {
CpuStorage::BF16(slice) => {
CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
}
CpuStorage::F16(slice) => {
CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
}
CpuStorage::F32(slice) => {
CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
}
CpuStorage::F64(slice) => {
CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
}
_ => Err(candle::Error::UnsupportedDTypeForOp(
storage.dtype(),
self.name(),
))?,
};
Ok((storage, layout.shape().clone()))
}
#[cfg(feature = "cuda")]
fn cuda_fwd(
&self,
storage: &candle::CudaStorage,
layout: &Layout,
) -> Result<(candle::CudaStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::cuda_backend::cudarc::driver::{
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
};
use candle::cuda_backend::SlicePtrOrNull;
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
use candle::{CudaDevice, WithDType};
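        // Map1 lets a single generic launch routine cover every supported dtype.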
struct S;
impl Map1 for S {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
src: &CudaSlice<T>,
dev: &CudaDevice,
layout: &Layout,
) -> Result<CudaSlice<T>> {
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let cfg = LaunchConfig::for_num_elems(el_count as u32);
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
let src = &src.slice(layout.start_offset()..);
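                // kernel_name adds the dtype suffix, e.g. "usigmoid_f32" for T = f32.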
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
let params = (el_count, dims.len(), &ds, src, &out);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
let dev = storage.device();
let slice = S.map(&storage.slice, dev, layout)?;
let dst = candle::CudaStorage {
slice,
device: dev.clone(),
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
use candle::backend::BackendStorage;
use candle::MetalError;
let device = storage.device();
let dtype = storage.dtype();
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
let command_buffer = device.command_buffer()?;
command_buffer.set_label("sigmoid");
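        // View into the input buffer that accounts for the layout's start offset.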
let src = candle_metal_kernels::BufferOffset {
buffer: storage.buffer(),
offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
};
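        // Kernel choice: tiled path for even-length contiguous f16/bf16 buffers,
        // plain contiguous otherwise, strided as the general fallback.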
match (el_count % 2, dtype, layout.is_contiguous()) {
(0, DType::BF16 | DType::F16, true) => {
use candle_metal_kernels::unary::contiguous_tiled;
let kernel_name = match dtype {
DType::F16 => contiguous_tiled::sigmoid::HALF,
DType::F32 => contiguous_tiled::sigmoid::FLOAT,
DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
dtype => {
candle::bail!(
"Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
)
}
};
candle_metal_kernels::call_unary_contiguous_tiled(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, true) => {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match dtype {
DType::F16 => contiguous::sigmoid::HALF,
DType::F32 => contiguous::sigmoid::FLOAT,
DType::BF16 => contiguous::sigmoid::BFLOAT,
dtype => {
candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
}
};
candle_metal_kernels::call_unary_contiguous(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
el_count,
src,
&buffer,
)
.map_err(MetalError::from)?;
}
(_, _, false) => {
use candle_metal_kernels::unary::strided;
let kernel_name = match dtype {
DType::F16 => strided::sigmoid::HALF,
DType::F32 => strided::sigmoid::FLOAT,
DType::BF16 => strided::sigmoid::BFLOAT,
dtype => {
candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
}
};
let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
candle_metal_kernels::call_unary_strided(
device.metal_device(),
&command_buffer,
device.kernels(),
kernel_name,
layout.dims(),
src,
layout.stride(),
dst,
)
.map_err(MetalError::from)?;
}
}
let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
Ok((new_storage, layout.shape().clone()))
}
fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
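        // Reusing the stored forward result `res` avoids recomputing exp on the backward pass.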
let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
}
}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(Sigmoid)
}
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {