mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Fix sigmoid gradient calculation and move sigmoid into a specialized op (#2114)
* add sigmoid op * small fix * add as a method on `Tensor` * implement gradient calculation for sigmoid * add sigmoid tests * we should have a specialized op for this * fix clippy * fix clippy 2 * Revert all previous commits in favor of a `CustomOp` based solution * use `CustomOp1` implementation * fix rustfmt * experimental add metal impl * add cuda kernel impl * fix fmt * Add a test + reduce some cuda duplication. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -18,7 +18,7 @@ pub use device::{CudaDevice, DeviceId};
|
|||||||
pub use error::{CudaError, WrapErr};
|
pub use error::{CudaError, WrapErr};
|
||||||
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
|
||||||
|
|
||||||
enum SlicePtrOrNull<T> {
|
pub enum SlicePtrOrNull<T> {
|
||||||
Ptr(CudaSlice<T>),
|
Ptr(CudaSlice<T>),
|
||||||
Null,
|
Null,
|
||||||
}
|
}
|
||||||
@ -33,7 +33,7 @@ unsafe impl<T: DeviceRepr> DeviceRepr for &SlicePtrOrNull<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl SlicePtrOrNull<usize> {
|
impl SlicePtrOrNull<usize> {
|
||||||
fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||||
let ds = if l.is_contiguous() {
|
let ds = if l.is_contiguous() {
|
||||||
SlicePtrOrNull::Null
|
SlicePtrOrNull::Null
|
||||||
} else {
|
} else {
|
||||||
|
@ -60,6 +60,11 @@ __device__ __forceinline__ T silu_fwd(T x) {
|
|||||||
return x / (static_cast<T>(1) + expg(-x));
|
return x / (static_cast<T>(1) + expg(-x));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
__device__ __forceinline__ T sigmoid_fwd(T x) {
|
||||||
|
return recipg(static_cast<T>(1) + expg(-x));
|
||||||
|
}
|
||||||
|
|
||||||
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const size_t numel, \
|
const size_t numel, \
|
||||||
@ -116,6 +121,7 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
|
|||||||
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
|
||||||
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
|
||||||
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
|
UNARY_OP(__nv_bfloat16, usign_bf16, sign_(x))
|
||||||
|
UNARY_OP(__nv_bfloat16, usigmoid_bf16, sigmoid_fwd(x))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
@ -142,6 +148,7 @@ UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
|
|||||||
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
UNARY_OP(__half, usilu_f16, silu_fwd(x))
|
||||||
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
UNARY_OP1(__half, upowf_f16, powg(x, param))
|
||||||
UNARY_OP(__half, usign_f16, sign_(x))
|
UNARY_OP(__half, usign_f16, sign_(x))
|
||||||
|
UNARY_OP(__half, usigmoid_f16, sigmoid_fwd(x))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
UNARY_OP(uint8_t, ucopy_u8, x)
|
UNARY_OP(uint8_t, ucopy_u8, x)
|
||||||
@ -193,3 +200,5 @@ UNARY_OP1(float, upowf_f32, powg(x, param))
|
|||||||
UNARY_OP1(double, upowf_f64, powg(x, param))
|
UNARY_OP1(double, upowf_f64, powg(x, param))
|
||||||
UNARY_OP(float, usign_f32, sign_(x))
|
UNARY_OP(float, usign_f32, sign_(x))
|
||||||
UNARY_OP(double, usign_f64, sign_(x))
|
UNARY_OP(double, usign_f64, sign_(x))
|
||||||
|
UNARY_OP(float, usigmoid_f32, sigmoid_fwd(x))
|
||||||
|
UNARY_OP(double, usigmoid_f64, sigmoid_fwd(x))
|
||||||
|
@ -129,7 +129,7 @@ macro_rules! ops{
|
|||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(
|
ops!(
|
||||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||||
tanh, recip, silu, sign
|
tanh, recip, silu, sign, sigmoid
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
|
@ -67,6 +67,9 @@ template <typename T> METAL_FUNC T relu(T in){
|
|||||||
template <typename T> METAL_FUNC T silu(T in){
|
template <typename T> METAL_FUNC T silu(T in){
|
||||||
return in / (static_cast<T>(1) + exp(-in));
|
return in / (static_cast<T>(1) + exp(-in));
|
||||||
}
|
}
|
||||||
|
template <typename T> METAL_FUNC T sigmoid(T in) {
|
||||||
|
return recip(static_cast<T>(1) + exp(-in));
|
||||||
|
}
|
||||||
|
|
||||||
#define TILE_SIZE 2
|
#define TILE_SIZE 2
|
||||||
|
|
||||||
@ -155,6 +158,7 @@ UNARY_OP(tanh)
|
|||||||
UNARY_OP(recip)
|
UNARY_OP(recip)
|
||||||
UNARY_OP(relu)
|
UNARY_OP(relu)
|
||||||
UNARY_OP(sign)
|
UNARY_OP(sign)
|
||||||
|
UNARY_OP(sigmoid)
|
||||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||||
@ -185,6 +189,7 @@ BFLOAT_UNARY_OP(tanh)
|
|||||||
BFLOAT_UNARY_OP(recip)
|
BFLOAT_UNARY_OP(recip)
|
||||||
BFLOAT_UNARY_OP(relu)
|
BFLOAT_UNARY_OP(relu)
|
||||||
BFLOAT_UNARY_OP(sign)
|
BFLOAT_UNARY_OP(sign)
|
||||||
|
BFLOAT_UNARY_OP(sigmoid)
|
||||||
|
|
||||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||||
|
|
||||||
|
@ -43,9 +43,193 @@ pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
|
|||||||
&xs[0].silu()? * &xs[1]
|
&xs[0].silu()? * &xs[1]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct Sigmoid;
|
||||||
|
|
||||||
|
impl candle::CustomOp1 for Sigmoid {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"sigmoid"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
|
||||||
|
use candle::backend::BackendStorage;
|
||||||
|
|
||||||
|
fn fwd<T: num_traits::Float>(v: T) -> T {
|
||||||
|
(v.neg().exp() + T::one()).recip()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FIXME: using `candle::map_dtype` causes compilation errors.
|
||||||
|
let storage = match storage {
|
||||||
|
CpuStorage::BF16(slice) => {
|
||||||
|
CpuStorage::BF16(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||||
|
}
|
||||||
|
CpuStorage::F16(slice) => {
|
||||||
|
CpuStorage::F16(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||||
|
}
|
||||||
|
CpuStorage::F32(slice) => {
|
||||||
|
CpuStorage::F32(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||||
|
}
|
||||||
|
CpuStorage::F64(slice) => {
|
||||||
|
CpuStorage::F64(candle::cpu_backend::unary_map(slice, layout, fwd))
|
||||||
|
}
|
||||||
|
_ => Err(candle::Error::UnsupportedDTypeForOp(
|
||||||
|
storage.dtype(),
|
||||||
|
self.name(),
|
||||||
|
))?,
|
||||||
|
};
|
||||||
|
Ok((storage, layout.shape().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &candle::CudaStorage,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
|
use candle::backend::BackendStorage;
|
||||||
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
|
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits,
|
||||||
|
};
|
||||||
|
use candle::cuda_backend::SlicePtrOrNull;
|
||||||
|
use candle::cuda_backend::{kernel_name, kernels, Map1, WrapErr};
|
||||||
|
use candle::{CudaDevice, WithDType};
|
||||||
|
|
||||||
|
struct S;
|
||||||
|
impl Map1 for S {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let shape = layout.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||||
|
let src = &src.slice(layout.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), kernels::UNARY)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
|
||||||
|
|
||||||
|
let params = (el_count, dims.len(), &ds, src, &out);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dev = storage.device();
|
||||||
|
let slice = S.map(&storage.slice, dev, layout)?;
|
||||||
|
let dst = candle::CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: dev.clone(),
|
||||||
|
};
|
||||||
|
Ok((dst, layout.shape().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
storage: &candle::MetalStorage,
|
||||||
|
layout: &Layout,
|
||||||
|
) -> Result<(candle::MetalStorage, Shape)> {
|
||||||
|
use candle::backend::BackendStorage;
|
||||||
|
use candle::MetalError;
|
||||||
|
let device = storage.device();
|
||||||
|
let dtype = storage.dtype();
|
||||||
|
let shape = layout.shape();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let buffer = device.new_buffer(el_count, dtype, "sigmoid")?;
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
command_buffer.set_label("sigmoid");
|
||||||
|
let src = candle_metal_kernels::BufferOffset {
|
||||||
|
buffer: storage.buffer(),
|
||||||
|
offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match (el_count % 2, dtype, layout.is_contiguous()) {
|
||||||
|
(0, DType::BF16 | DType::F16, true) => {
|
||||||
|
use candle_metal_kernels::unary::contiguous_tiled;
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
DType::F16 => contiguous_tiled::sigmoid::HALF,
|
||||||
|
DType::F32 => contiguous_tiled::sigmoid::FLOAT,
|
||||||
|
DType::BF16 => contiguous_tiled::sigmoid::BFLOAT,
|
||||||
|
dtype => {
|
||||||
|
candle::bail!(
|
||||||
|
"Metal contiguous_tiled unary sigmoid {dtype:?} not implemented"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_unary_contiguous_tiled(
|
||||||
|
device.metal_device(),
|
||||||
|
&command_buffer,
|
||||||
|
device.kernels(),
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
src,
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
(_, _, true) => {
|
||||||
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
DType::F16 => contiguous::sigmoid::HALF,
|
||||||
|
DType::F32 => contiguous::sigmoid::FLOAT,
|
||||||
|
DType::BF16 => contiguous::sigmoid::BFLOAT,
|
||||||
|
dtype => {
|
||||||
|
candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_unary_contiguous(
|
||||||
|
device.metal_device(),
|
||||||
|
&command_buffer,
|
||||||
|
device.kernels(),
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
src,
|
||||||
|
&buffer,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
(_, _, false) => {
|
||||||
|
use candle_metal_kernels::unary::strided;
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
DType::F16 => strided::sigmoid::HALF,
|
||||||
|
DType::F32 => strided::sigmoid::FLOAT,
|
||||||
|
DType::BF16 => strided::sigmoid::BFLOAT,
|
||||||
|
dtype => {
|
||||||
|
candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer);
|
||||||
|
candle_metal_kernels::call_unary_strided(
|
||||||
|
device.metal_device(),
|
||||||
|
&command_buffer,
|
||||||
|
device.kernels(),
|
||||||
|
kernel_name,
|
||||||
|
layout.dims(),
|
||||||
|
src,
|
||||||
|
layout.stride(),
|
||||||
|
dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype);
|
||||||
|
Ok((new_storage, layout.shape().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn bwd(&self, _arg: &Tensor, res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
|
||||||
|
// d/dx sigmoid(x) = (1 - sigmoid(x)) * sigmoid(x)
|
||||||
|
let d_dx_sigmoid = res.ones_like()?.sub(res)?.mul(res)?;
|
||||||
|
Ok(Some(grad_res.mul(&d_dx_sigmoid)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||||
// TODO: Should we have a specialized op for this?
|
xs.apply_op1(Sigmoid)
|
||||||
(xs.neg()?.exp()? + 1.0)?.recip()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
|
pub fn hard_sigmoid(xs: &Tensor) -> Result<Tensor> {
|
||||||
|
@ -170,8 +170,19 @@ fn rope_thd(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn sigmoid(device: &Device) -> Result<()> {
|
||||||
|
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||||
|
let tensor = Tensor::new(data, device)?;
|
||||||
|
let s1 = candle_nn::ops::sigmoid(&tensor)?;
|
||||||
|
let s2 = (1. / (1. + tensor.neg()?.exp()?)?)?;
|
||||||
|
let diff = (s1 - s2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
|
||||||
|
assert_eq!(diff, 0.);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
|
test_device!(ropei, ropei_cpu, ropei_gpu, ropei_metal);
|
||||||
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
|
test_device!(rope, rope_cpu, rope_gpu, rope_metal);
|
||||||
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
|
test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
|
||||||
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
|
||||||
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
|
||||||
|
test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
|
||||||
|
Reference in New Issue
Block a user