mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add cuda support for unary ops.
This commit is contained in:
@ -8,6 +8,7 @@ fn main() -> Result<()> {
|
||||
let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
|
||||
let z = (y + x * 3.)?;
|
||||
println!("{:?}", z.to_vec1::<f32>()?);
|
||||
println!("{:?}", z.sqrt()?.to_vec1::<f32>()?);
|
||||
let x = Tensor::ones((3, 2), DType::F32, &device)?;
|
||||
println!("{:?}", x.to_vec2::<f32>()?);
|
||||
Ok(())
|
||||
|
@ -2,28 +2,16 @@
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, badd_f16, x + y)
|
||||
BINARY_OP(__half, bdiv_f16, x / y)
|
||||
BINARY_OP(__half, bmul_f16, x * y)
|
||||
BINARY_OP(__half, bsub_f16, x - y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, badd_f32, x + y)
|
||||
BINARY_OP(double, badd_fwd_f64, x + y);
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, bdiv_f16, x / y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, bdiv_f32, x / y)
|
||||
BINARY_OP(double, bdiv_f64, x / y);
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, bmul_f16, x * y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, bmul_f32, x * y)
|
||||
BINARY_OP(double, bmul_f64, x * y);
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
BINARY_OP(__half, bsub_f16, x - y)
|
||||
#endif
|
||||
|
||||
BINARY_OP(float, bsub_f32, x - y)
|
||||
BINARY_OP(double, bsub_f64, x - y);
|
||||
|
@ -1,3 +1,4 @@
|
||||
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
|
||||
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
|
26
kernels/src/unary.cu
Normal file
26
kernels/src/unary.cu
Normal file
@ -0,0 +1,26 @@
|
||||
#include "cuda_utils.cuh"
|
||||
|
||||
#define UNARY_OP(TYPENAME, FN_NAME, FUNC) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t numel, \
|
||||
const TYPENAME *inp, \
|
||||
TYPENAME *out \
|
||||
) { \
|
||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||
TYPENAME x = inp ? inp[i] : out[i]; \
|
||||
out[i] = FUNC; \
|
||||
} \
|
||||
} \
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
UNARY_OP(__half, uneg_f16, -x)
|
||||
UNARY_OP(__half, usqr_f16, x*x)
|
||||
UNARY_OP(__half, usqrt_f16, sqrtg(x))
|
||||
#endif
|
||||
|
||||
UNARY_OP(float, uneg_f32, -x)
|
||||
UNARY_OP(float, uneg_f64, -x)
|
||||
UNARY_OP(float, usqr_f32, x*x)
|
||||
UNARY_OP(float, usqr_f64, x*x)
|
||||
UNARY_OP(float, usqrt_f32, sqrtg(x))
|
||||
UNARY_OP(float, usqrt_f64, sqrtg(x))
|
@ -145,21 +145,53 @@ impl CudaStorage {
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?;
|
||||
// SAFETY: if this function returns Ok(..), the kernel has been applied
|
||||
// and has set the initially unset memory.
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out, mul as f32, add as f32);
|
||||
// SAFETY: well, well, well...
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F32(out))
|
||||
}
|
||||
Self::F64(arg) => {
|
||||
let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?;
|
||||
// SAFETY: if this function returns Ok(..), the kernel has been applied
|
||||
// and has set the initially unset memory.
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out, mul, add);
|
||||
// SAFETY: well, well, well...
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F64(out))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<U: crate::storage::UnaryOp>(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
) -> Result<Self> {
|
||||
if !shape.is_contiguous(stride) {
|
||||
return Err(CudaError::RequiresContiguous { op: "affine" });
|
||||
}
|
||||
|
||||
let elem_count = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
||||
let dev = self.device();
|
||||
match self {
|
||||
Self::F32(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F32(out))
|
||||
}
|
||||
Self::F64(arg) => {
|
||||
let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
|
||||
let params = (elem_count, arg, &out);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }?;
|
||||
Ok(Self::F64(out))
|
||||
}
|
||||
|
@ -54,6 +54,14 @@ impl CudaStorage {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn unary_impl<B: crate::storage::UnaryOp>(
|
||||
&self,
|
||||
_: &Shape,
|
||||
_: &[usize],
|
||||
) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub(crate) fn binary_impl<B: crate::storage::BinaryOp>(
|
||||
&self,
|
||||
_: &Self,
|
||||
|
@ -8,12 +8,18 @@ pub enum Storage {
|
||||
|
||||
pub(crate) trait UnaryOp {
|
||||
const NAME: &'static str;
|
||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
fn f32(v1: f32) -> f32;
|
||||
fn f64(v1: f64) -> f64;
|
||||
}
|
||||
|
||||
pub(crate) trait BinaryOp {
|
||||
const NAME: &'static str;
|
||||
// TODO: These kernels are compatible with arbitrary strides. We should also consider the
|
||||
// contiguous case separately as it's easy to optimize things out there.
|
||||
const KERNEL_F32: &'static str;
|
||||
const KERNEL_F64: &'static str;
|
||||
fn f32(v1: f32, v2: f32) -> f32;
|
||||
@ -84,6 +90,8 @@ impl UnaryOp for Neg {
|
||||
fn f64(v1: f64) -> f64 {
|
||||
-v1
|
||||
}
|
||||
const KERNEL_F32: &'static str = "uneg_f32";
|
||||
const KERNEL_F64: &'static str = "uneg_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqr {
|
||||
@ -94,6 +102,8 @@ impl UnaryOp for Sqr {
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1 * v1
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usqr_f32";
|
||||
const KERNEL_F64: &'static str = "usqr_f64";
|
||||
}
|
||||
|
||||
impl UnaryOp for Sqrt {
|
||||
@ -104,6 +114,8 @@ impl UnaryOp for Sqrt {
|
||||
fn f64(v1: f64) -> f64 {
|
||||
v1.sqrt()
|
||||
}
|
||||
const KERNEL_F32: &'static str = "usqrt_f32";
|
||||
const KERNEL_F64: &'static str = "usqrt_f64";
|
||||
}
|
||||
|
||||
impl Storage {
|
||||
@ -168,7 +180,10 @@ impl Storage {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
Self::Cuda { .. } => todo!(),
|
||||
Self::Cuda(storage) => {
|
||||
let storage = storage.unary_impl::<B>(shape, stride)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -269,7 +284,14 @@ impl Storage {
|
||||
let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
_ => todo!(),
|
||||
(Self::Cuda(_), Self::Cuda(_)) => {
|
||||
todo!()
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "matmul",
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user