mirror of https://github.com/huggingface/candle.git

Add some first binary op (add).
@@ -5,9 +5,9 @@ fn main() -> Result<()> {
     let device = Device::new_cuda(0)?;
     let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?;
     println!("{:?}", x.to_vec1::<f32>()?);
-    let x = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
-    let y = (x * 3.)?;
-    println!("{:?}", y.to_vec1::<f32>()?);
+    let y = Tensor::new(&[2f32, 7., 1., 8., 2.], &device)?;
+    let z = (y + x * 3.)?;
+    println!("{:?}", z.to_vec1::<f32>()?);
     let x = Tensor::ones((3, 2), DType::F32, &device)?;
     println!("{:?}", x.to_vec2::<f32>()?);
     Ok(())
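As a quick sanity check (not part of the commit), the same arithmetic on the CPU shows what the example should print for z, i.e. y + x * 3 with x = [3, 1, 4, 1, 5] and y = [2, 7, 1, 8, 2]:

fn main() {
    let x = [3f32, 1., 4., 1., 5.];
    let y = [2f32, 7., 1., 8., 2.];
    // Elementwise y + x * 3, mirroring the CUDA example's z.
    let z: Vec<f32> = x.iter().zip(y.iter()).map(|(x, y)| y + x * 3.).collect();
    assert_eq!(z, vec![11., 10., 13., 11., 17.]);
}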
@@ -4,13 +4,14 @@
 extern "C" __global__ void FN_NAME( \
     const size_t numel, \
     const size_t num_dims, \
-    const size_t *dims, \
-    const size_t *lhs_strides, \
-    const size_t *rhs_strides, \
+    const size_t *dims_and_strides, \
     const TYPENAME *lhs, \
     const TYPENAME *rhs, \
     TYPENAME *out \
 ) { \
+    const size_t *dims = dims_and_strides; \
+    const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
+    const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
     for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
         unsigned int tmp_i = i; \
         unsigned int lhs_i = 0; \
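Packing dims, lhs_strides and rhs_strides into the single dims_and_strides argument lets the host side upload one buffer instead of three, which is exactly what the add_impl hunk below does with a single htod_copy. The kernel body is truncated here, but the tmp_i/lhs_i temporaries suggest the usual unravel-and-redot computation of strided offsets. A host-side Rust sketch of that index arithmetic, as an assumption about the elided lines rather than code from the commit:

/// Map a linear element index to an offset into a strided buffer by
/// unravelling `i` over `dims` (last dimension fastest) and re-dotting
/// the per-dimension coordinates with `strides`.
fn strided_offset(i: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut tmp_i = i;
    let mut offset = 0;
    for (&d, &s) in dims.iter().zip(strides.iter()).rev() {
        offset += (tmp_i % d) * s;
        tmp_i /= d;
    }
    offset
}

fn main() {
    // A [2, 3] tensor with contiguous strides [3, 1]: linear index 4 -> (1, 1) -> offset 4.
    assert_eq!(strided_offset(4, &[2, 3], &[3, 1]), 4);
    // A broadcast first dim (stride 0) maps both rows onto the same elements.
    assert_eq!(strided_offset(4, &[2, 3], &[0, 1]), 1);
}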
@@ -16,6 +16,9 @@ pub enum CudaError {
 
     #[error("missing kernel '{module_name}'")]
     MissingKernel { module_name: &'static str },
+
+    #[error("internal error '{0}'")]
+    InternalError(&'static str),
 }
 
 type Result<T> = std::result::Result<T, CudaError>;
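The #[error(...)] attributes are thiserror syntax. A minimal standalone sketch, assuming the enum derives thiserror::Error (the derive sits outside this hunk), of how the new variant renders:

use thiserror::Error;

#[derive(Debug, Error)]
pub enum CudaError {
    #[error("missing kernel '{module_name}'")]
    MissingKernel { module_name: &'static str },

    #[error("internal error '{0}'")]
    InternalError(&'static str),
}

fn main() {
    let e = CudaError::InternalError("dtype mismatch in add");
    assert_eq!(e.to_string(), "internal error 'dtype mismatch in add'");
}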
@@ -163,6 +166,44 @@ impl CudaStorage {
         }
     }
 
+    pub(crate) fn add_impl(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        let elem_count = shape.elem_count();
+        let dims = shape.dims();
+        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
+        let dev = self.device();
+        let dims_and_strides = [dims, lhs_stride, rhs_stride].concat();
+        match (self, rhs) {
+            (Self::F32(lhs), Self::F32(rhs)) => {
+                let func = dev.get_or_load_func("badd_f32", kernels::BINARY_ADD)?;
+                // SAFETY: Set later by running the add kernel.
+                let out = unsafe { dev.0.alloc::<f32>(elem_count) }?;
+                let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
+                let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F32(out))
+            }
+            (Self::F64(lhs), Self::F64(rhs)) => {
+                // SAFETY: Set later by running the add kernel.
+                let func = dev.get_or_load_func("badd_f64", kernels::BINARY_ADD)?;
+                let out = unsafe { dev.0.alloc::<f64>(elem_count) }?;
+                let dims_and_strides = dev.0.htod_copy(dims_and_strides)?;
+                let params = (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out);
+                // SAFETY: ffi
+                unsafe { func.launch(cfg, params) }?;
+                Ok(Self::F64(out))
+            }
+            // The dtypes should have been checked at this point so this is an internal error.
+            _ => Err(CudaError::InternalError("dtype mismatch in add")),
+        }
+    }
 
     pub(crate) fn to_cpu_storage(&self) -> Result<CpuStorage> {
         match self {
             Self::F32(slice) => {
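Note that the params tuple (elem_count, dims.len(), &dims_and_strides, lhs, rhs, &out) lines up positionally with the kernel's parameter list (numel, num_dims, dims_and_strides, lhs, rhs, out) from the earlier hunk. Since add_impl takes explicit strides, a caller with contiguous tensors needs the row-major strides for a shape; a hedged sketch (contiguous_strides is a hypothetical helper, not part of the commit):

/// Row-major strides for a contiguous tensor of the given dims,
/// e.g. dims = [3, 2] gives strides = [2, 1].
fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    // Matches the (3, 2) tensor from the example above.
    assert_eq!(contiguous_strides(&[3, 2]), vec![2, 1]);
}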
@@ -53,4 +53,8 @@ impl CudaStorage {
     pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub(crate) fn add_impl(&self, _: &Self, _: &Shape, _: &[usize], _: &[usize]) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
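This stub backend returns Error::NotCompiledWithCudaSupport for every operation. Presumably the real and dummy modules are swapped at compile time via a cargo feature, along these lines (the feature name and file path are guesses, not from the commit):

// lib.rs (hypothetical wiring)
#[cfg(feature = "cuda")]
mod cuda_backend;
#[cfg(not(feature = "cuda"))]
#[path = "dummy_cuda_backend.rs"]
mod cuda_backend;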
@@ -16,6 +16,13 @@ pub(crate) trait BinaryOp {
     const NAME: &'static str;
     fn f32(v1: f32, v2: f32) -> f32;
     fn f64(v1: f64, v2: f64) -> f64;
+    fn cuda_impl(
+        lhs: &CudaStorage,
+        rhs: &CudaStorage,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<CudaStorage>;
 }
 
 struct Add;
@@ -34,6 +41,15 @@ impl BinaryOp for Add {
     fn f64(v1: f64, v2: f64) -> f64 {
         v1 + v2
     }
+    fn cuda_impl(
+        lhs: &CudaStorage,
+        rhs: &CudaStorage,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<CudaStorage> {
+        Ok(lhs.add_impl(rhs, shape, lhs_stride, rhs_stride)?)
+    }
 }
 
 impl BinaryOp for Sub {
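The Ok(lhs.add_impl(...)?) wrapper is not redundant: add_impl returns the CUDA backend's Result (error type CudaError) while cuda_impl returns the crate-level Result, so the ? performs the error conversion. A minimal self-contained sketch of the idiom with stand-in types:

#[derive(Debug)]
struct CudaError(&'static str);

#[derive(Debug)]
enum Error {
    Cuda(CudaError),
}

impl From<CudaError> for Error {
    fn from(e: CudaError) -> Self {
        Error::Cuda(e)
    }
}

fn add_impl() -> Result<u32, CudaError> {
    Ok(42)
}

// `?` converts the CudaError into Error via the From impl above.
fn cuda_impl() -> Result<u32, Error> {
    Ok(add_impl()?)
}

fn main() {
    assert_eq!(cuda_impl().unwrap(), 42);
}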
@@ -44,6 +60,15 @@ impl BinaryOp for Sub {
     fn f64(v1: f64, v2: f64) -> f64 {
         v1 - v2
     }
+    fn cuda_impl(
+        _: &CudaStorage,
+        _: &CudaStorage,
+        _: &Shape,
+        _: &[usize],
+        _: &[usize],
+    ) -> Result<CudaStorage> {
+        todo!()
+    }
 }
 
 impl BinaryOp for Mul {
@@ -54,6 +79,15 @@ impl BinaryOp for Mul {
     fn f64(v1: f64, v2: f64) -> f64 {
         v1 * v2
     }
+    fn cuda_impl(
+        _: &CudaStorage,
+        _: &CudaStorage,
+        _: &Shape,
+        _: &[usize],
+        _: &[usize],
+    ) -> Result<CudaStorage> {
+        todo!()
+    }
 }
 
 impl BinaryOp for Div {
@@ -64,6 +98,15 @@ impl BinaryOp for Div {
     fn f64(v1: f64, v2: f64) -> f64 {
         v1 / v2
     }
+    fn cuda_impl(
+        _: &CudaStorage,
+        _: &CudaStorage,
+        _: &Shape,
+        _: &[usize],
+        _: &[usize],
+    ) -> Result<CudaStorage> {
+        todo!()
+    }
 }
 
 impl UnaryOp for Neg {
@@ -177,7 +220,10 @@ impl Storage {
                 let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
                 Ok(Self::Cpu(storage))
             }
-            (Self::Cuda { .. }, Self::Cuda { .. }) => todo!(),
+            (Self::Cuda(lhs), Self::Cuda(rhs)) => {
+                let storage = B::cuda_impl(lhs, rhs, shape, lhs_stride, rhs_stride)?;
+                Ok(Self::Cuda(storage))
+            }
             (lhs, rhs) => {
                 // Should not happen because of the same device check above but we're defensive
                 // anyway.
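For context on the dispatch pattern (a self-contained sketch, not candle code): each op is a marker type implementing BinaryOp, and the generic entry point monomorphizes over B, so the CPU arm calls B::f32/B::f64 and the new CUDA arm calls B::cuda_impl without any runtime function pointers:

trait BinaryOp {
    const NAME: &'static str;
    fn f32(v1: f32, v2: f32) -> f32;
}

struct Add;

impl BinaryOp for Add {
    const NAME: &'static str = "add";
    fn f32(v1: f32, v2: f32) -> f32 {
        v1 + v2
    }
}

// Monomorphized per op: apply::<Add> compiles to a direct call to Add::f32.
fn apply<B: BinaryOp>(lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    lhs.iter().zip(rhs.iter()).map(|(l, r)| B::f32(*l, *r)).collect()
}

fn main() {
    println!("{} -> {:?}", Add::NAME, apply::<Add>(&[1., 2.], &[3., 4.])); // add -> [4.0, 6.0]
}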