mirror of https://github.com/huggingface/candle.git

Cuda implementation for copying data around.

The commit adds ucopy_* unary CUDA kernels, implements CudaStorage::copy_strided_src on top of them (with a stub in the non-CUDA dummy backend), routes the CUDA branch of the storage-level strided copy to the new method, and adds doc comments to Tensor::reshape.
@@ -18,11 +18,14 @@ extern "C" __global__ void FN_NAME( \
 } \
 
 #if __CUDA_ARCH__ >= 530
+UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
 UNARY_OP(__half, usqr_f16, x*x)
 UNARY_OP(__half, usqrt_f16, sqrtg(x))
 #endif
 
+UNARY_OP(float, ucopy_f32, x)
+UNARY_OP(float, ucopy_f64, x)
 UNARY_OP(float, uneg_f32, -x)
 UNARY_OP(float, uneg_f64, -x)
 UNARY_OP(float, usqr_f32, x*x)
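The UNARY_OP macro (its opening appears in the hunk header above) instantiates one __global__ kernel per element type, with the last macro argument as the per-element expression; for the new ucopy_* kernels that expression is just x, so the kernels copy values unchanged and only the indexing does any work. Below is a small Rust-side sketch of the naming convention these kernels follow, which the backend code further down relies on when looking kernels up by name; the enum and helper are illustrative, not candle's DType or API.

// Illustrative helper only: maps an element type to the kernel name added above.
enum CopyDType {
    F16, // the device side is guarded by #if __CUDA_ARCH__ >= 530
    F32,
    F64,
}

fn ucopy_kernel_name(dtype: CopyDType) -> &'static str {
    match dtype {
        CopyDType::F16 => "ucopy_f16",
        CopyDType::F32 => "ucopy_f32",
        CopyDType::F64 => "ucopy_f64",
    }
}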
@@ -381,4 +381,42 @@ impl CudaStorage {
         let device = dev.clone();
         Ok(Self { slice, device })
     }
+
+    pub(crate) fn copy_strided_src(
+        &self,
+        dst: &mut Self,
+        src_shape: &Shape,
+        src_stride: &[usize],
+        dst_offset: usize,
+    ) -> Result<()> {
+        let dims = src_shape.dims();
+        let el_count = src_shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el_count as u32);
+        let dev = &self.device;
+        let ds = dev.htod_copy([dims, src_stride].concat())?;
+        match (&self.slice, &mut dst.slice) {
+            (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
+                let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
+                let mut dst = dst.slice_mut(dst_offset..);
+                // SAFETY: Set later by running the kernel.
+                let params = (el_count, dims.len(), &ds, src, &mut dst);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?
+            }
+            (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
+                let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
+                let mut dst = dst.slice_mut(dst_offset..);
+                // SAFETY: Set later by running the kernel.
+                let params = (el_count, dims.len(), &ds, src, &mut dst);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }?;
+            }
+            _ => {
+                return Err(CudaError::InternalError(
+                    "dtype mismatch in copy_strided op",
+                ))
+            }
+        }
+        Ok(())
+    }
 }
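copy_strided_src builds a launch configuration for the full element count, uploads the source dims and strides to the device in a single buffer (htod_copy([dims, src_stride].concat())), fetches the matching ucopy_* kernel from the UNARY module, narrows the destination with slice_mut(dst_offset..), and launches the kernel with (el_count, dims.len(), &ds, src, &mut dst). For intuition, here is a minimal CPU reference of the same computation; this is illustrative code, not candle's: the destination is written in contiguous order while each element is gathered from the source through its shape and strides.

// CPU reference sketch of a strided copy, for intuition only (not candle's code).
// `dims` and `strides` describe the source layout; the destination is written
// contiguously starting at `dst_offset`.
fn copy_strided_ref(
    src: &[f32],
    dst: &mut [f32],
    dims: &[usize],
    strides: &[usize],
    dst_offset: usize,
) {
    let el_count: usize = dims.iter().product();
    for i in 0..el_count {
        // Unravel the contiguous index `i` into coordinates (last dim fastest),
        // then re-ravel those coordinates with the source strides.
        let mut rem = i;
        let mut src_idx = 0;
        for (&dim, &stride) in dims.iter().zip(strides.iter()).rev() {
            src_idx += (rem % dim) * stride;
            rem /= dim;
        }
        dst[dst_offset + i] = src[src_idx];
    }
}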
@@ -85,4 +85,14 @@ impl CudaStorage {
     ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
+
+    pub(crate) fn copy_strided_src(
+        &self,
+        _: &mut Self,
+        _: &Shape,
+        _: &[usize],
+        _: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
 }
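This second impl CudaStorage is the stub used when CUDA support is not compiled in: it mirrors the real method's signature and always returns Error::NotCompiledWithCudaSupport, so call sites compile either way. A minimal sketch of how such a stub is typically selected at compile time follows; the cuda feature name and module names are assumptions for illustration, not taken from this diff.

// Compile-time backend selection sketch; feature and module names are assumed.
#[cfg(feature = "cuda")]
mod cuda_backend {
    pub struct CudaStorage; // real CUDA-backed storage
}
#[cfg(not(feature = "cuda"))]
mod dummy_cuda_backend {
    pub struct CudaStorage; // stub that only returns "not compiled" errors
}

#[cfg(feature = "cuda")]
pub use cuda_backend::CudaStorage;
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::CudaStorage;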
@@ -160,8 +160,8 @@ impl Storage {
             (Self::Cpu(src), Self::Cpu(dst)) => {
                 src.copy_strided_src(dst, src_shape, src_stride, dst_offset)
             }
-            (Self::Cuda(_src), Self::Cuda(_dst)) => {
-                todo!()
+            (Self::Cuda(src), Self::Cuda(dst)) => {
+                Ok(src.copy_strided_src(dst, src_shape, src_stride, dst_offset)?)
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
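The new Cuda arm wraps the call as Ok(src.copy_strided_src(...)?) instead of returning the backend result directly; a plausible reason, which is an assumption about error types this diff does not show, is that the CUDA backend uses its own error type, so ? converts it into the storage-level error via From before the unit value is re-wrapped. A self-contained illustration of that idiom:

// Illustration of the `Ok(expr?)` idiom: `?` converts the inner error via
// `From`, then the success value is re-wrapped in `Ok`. Types are made up.
#[derive(Debug)]
struct BackendError;
#[derive(Debug)]
struct StorageError;

impl From<BackendError> for StorageError {
    fn from(_: BackendError) -> Self {
        StorageError
    }
}

fn backend_copy() -> Result<(), BackendError> {
    Ok(())
}

fn storage_copy() -> Result<(), StorageError> {
    // BackendError is mapped to StorageError by `?`; `()` is re-wrapped in Ok.
    Ok(backend_copy()?)
}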
@@ -596,6 +596,9 @@ impl Tensor {
     }
 
     // TODO: Do we want to allow target shape using -1 on some dimensions?
+    /// Reshape returns a tensor with the target shape provided that the number of elements of the
+    /// original tensor is the same. This uses a new storage and copies the data over, the returned
+    /// tensor is always contiguous.
     pub fn reshape<S: Into<Shape>>(&self, shape: S) -> Result<Tensor> {
         let shape = shape.into();
         if shape.elem_count() != self.elem_count() {
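A short usage sketch of the documented contract: reshape succeeds only when the element counts match, and the result is a fresh contiguous copy. The crate path, constructor, and Result alias below are assumptions about the surrounding candle API, not part of this diff.

use candle::{Device, Tensor};

fn reshape_demo() -> candle::Result<()> {
    let t = Tensor::new(&[1f32, 2., 3., 4., 5., 6.], &Device::Cpu)?;
    let m = t.reshape((2, 3))?; // same element count: ok, result is contiguous
    assert_eq!(m.shape().dims(), &[2usize, 3]);
    assert!(t.reshape((4, 2)).is_err()); // 8 != 6 elements: rejected
    Ok(())
}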