mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op * small fix * add as a method on `Tensor` * implement gradient calculation for sigmoid * add sigmoid tests * we should have a specialized op for this * fix clippy * fix clippy 2 * Revert all previous commits in favor of a `CustomOp` based solution * use `CustomOp1` implementation * fix rustfmt * experimental add metal impl * add cuda kernel impl * fix fmt * Add a test + reduce some cuda duplication. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -43,9 +43,193 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
||||
&xs[0].silu()? * &xs[1]
|
||||
}
|
||||
|
||||
struct Sigmoid;
|
||||
|
||||
impl candle::CustomOp1 for Sigmoid {
|
||||
fn name(&self) -> &'static str {
|
||||
"sigmoid"
|
||||
}
|
||||
|
||||
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
|
||||
fn fwd<T: num_traits::Float>(v: T) -> T {
|
||||
(v.neg().exp() + T::one()).recip()
|
||||
}
|
||||
|
||||
// FIXME: using `candle::map_dtype` causes compilation errors.
|
||||
let storage = match storage {
|
||||
CpuStorage::BF16(slice) => {
|
||||
CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
CpuStorage::F16(slice) => {
|
||||
CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
CpuStorage::F32(slice) => {
|
||||
CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
CpuStorage::F64(slice) => {
|
||||
CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||
}
|
||||
_ => Err(candle::Error::UnsupportedDTypeForOp(
|
||||
storage.dtype(),
|
||||
self.name(),
|
||||
))?,
|
||||
};
|
||||
Ok((storage, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
storage: &candle::CudaStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::CudaStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::cuda_backend::cudarc::driver::{
|
||||
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||
};
|
||||
use candle::cuda_backend::SlicePtrOrNull;
|
||||
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||
use candle::{CudaDevice, WithDType};
|
||||
|
||||
struct S;
|
||||
impl Map1 for S {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let shape = layout.shape();
|
||||
let dims = shape.dims();
|
||||
let el_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||
|
||||
let params = (el_count, dims.len(), &ds, src, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
let dev = storage.device();
|
||||
let slice = S.map(&storage.slice, dev, layout)?;
|
||||
let dst = candle::CudaStorage {
|
||||
slice,
|
||||
device: dev.clone(),
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &candle::MetalStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
use candle::backend::BackendStorage;
|
||||
use candle::MetalError;
|
||||
let device = storage.device();
|
||||
let dtype = storage.dtype();
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
|
||||
let command_buffer = device.command_buffer()?;
|
||||
command_buffer.set_label("sigmoid");
|
||||
let src = candle_metal_kernels::BufferOffset {
|
||||
buffer: storage.buffer(),
|
||||
offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||
};
|
||||
|
||||
match (el_count % 2, dtype, layout.is_contiguous()) {
|
||||
(0, DType::BF16 | DType::F16, true) => {
|
||||
use candle_metal_kernels::unary::contiguous_tiled;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous_tiled::sigmoid::HALF,
|
||||
DType::F32 => contiguous_tiled::sigmoid::FLOAT,
|
||||
DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
|
||||
dtype => {
|
||||
candle::bail!(
|
||||
"Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
|
||||
)
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous_tiled(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, true) => {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => contiguous::sigmoid::HALF,
|
||||
DType::F32 => contiguous::sigmoid::FLOAT,
|
||||
DType::BF16 => contiguous::sigmoid::BFLOAT,
|
||||
dtype => {
|
||||
candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
kernel_name,
|
||||
el_count,
|
||||
src,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
(_, _, false) => {
|
||||
use candle_metal_kernels::unary::strided;
|
||||
let kernel_name = match dtype {
|
||||
DType::F16 => strided::sigmoid::HALF,
|
||||
DType::F32 => strided::sigmoid::FLOAT,
|
||||
DType::BF16 => strided::sigmoid::BFLOAT,
|
||||
dtype => {
|
||||
candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
|
||||
}
|
||||
};
|
||||
let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
device.kernels(),
|
||||
kernel_name,
|
||||
layout.dims(),
|
||||
src,
|
||||
layout.stride(),
|
||||
dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
}
|
||||
|
||||
let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
|
||||
Ok((new_storage, layout.shape().clone()))
|
||||
}
|
||||
|
||||
fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||
// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
|
||||
let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
|
||||
Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: Should we have a specialized op for this?
|
||||
(xs.neg()?.exp()? + 1.0)?.recip()
|
||||
xs.apply_op1(Sigmoid)
|
||||
}
|
||||
|
||||
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user