mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Get the cuda tests to pass.
This commit is contained in:
@ -27,7 +27,7 @@ pub enum CudaError {
|
|||||||
InternalError(&'static str),
|
InternalError(&'static str),
|
||||||
|
|
||||||
#[error("internal error '{0}'")]
|
#[error("internal error '{0}'")]
|
||||||
WrappedError(Box<dyn std::error::Error>),
|
WrappedError(Box<dyn std::error::Error + 'static + std::marker::Send + std::marker::Sync>),
|
||||||
|
|
||||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
MatMulNonContiguous {
|
MatMulNonContiguous {
|
||||||
@ -245,13 +245,14 @@ enum CudaStorageSlice {
|
|||||||
|
|
||||||
fn slice_src_and_dst<'a, T>(
|
fn slice_src_and_dst<'a, T>(
|
||||||
src: &'a CudaSlice<T>,
|
src: &'a CudaSlice<T>,
|
||||||
src_offset: usize,
|
src_l: &Layout,
|
||||||
dst: &'a mut CudaSlice<T>,
|
dst: &'a mut CudaSlice<T>,
|
||||||
dst_offset: usize,
|
dst_offset: usize,
|
||||||
) -> (
|
) -> (
|
||||||
cudarc::driver::CudaView<'a, T>,
|
cudarc::driver::CudaView<'a, T>,
|
||||||
cudarc::driver::CudaViewMut<'a, T>,
|
cudarc::driver::CudaViewMut<'a, T>,
|
||||||
) {
|
) {
|
||||||
|
let src_offset = src_l.start_offset();
|
||||||
let to_copy = dst
|
let to_copy = dst
|
||||||
.len()
|
.len()
|
||||||
.saturating_sub(dst_offset)
|
.saturating_sub(dst_offset)
|
||||||
@ -366,13 +367,18 @@ impl CudaStorage {
|
|||||||
let dev = self.device();
|
let dev = self.device();
|
||||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
||||||
let start_o = layout.start_offset();
|
let start_o = layout.start_offset();
|
||||||
|
// This returns an i64 rather than a &i64, this is useful to get around some temporary
|
||||||
|
// lifetime issue and is safe as long as self.slice does not go out of scope before inp
|
||||||
|
// is used.
|
||||||
let inp = match &self.slice {
|
let inp = match &self.slice {
|
||||||
CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(),
|
CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||||
CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(),
|
CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||||
CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(),
|
CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
|
||||||
CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(),
|
CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
|
||||||
CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(),
|
CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
|
||||||
};
|
};
|
||||||
|
let inp = &inp;
|
||||||
|
|
||||||
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
|
||||||
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
@ -739,13 +745,14 @@ impl CudaStorage {
|
|||||||
layout_f: &Layout,
|
layout_f: &Layout,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let ids = match &self.slice {
|
let ids = match &self.slice {
|
||||||
CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
|
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
msg: "where conditions should be u32",
|
msg: "where conditions should be u32",
|
||||||
expected: DType::U32,
|
expected: DType::U32,
|
||||||
got: self.dtype(),
|
got: self.dtype(),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
let ids = &ids;
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
let el = shape.elem_count();
|
let el = shape.elem_count();
|
||||||
@ -818,13 +825,14 @@ impl CudaStorage {
|
|||||||
|
|
||||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let ids = match &self.slice {
|
let ids = match &self.slice {
|
||||||
CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
|
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
msg: "embedding ids should be u32",
|
msg: "embedding ids should be u32",
|
||||||
expected: DType::U32,
|
expected: DType::U32,
|
||||||
got: self.dtype(),
|
got: self.dtype(),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
let ids = &ids;
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let (v_size, h_size) = rhs_l
|
let (v_size, h_size) = rhs_l
|
||||||
.shape()
|
.shape()
|
||||||
@ -953,15 +961,16 @@ impl CudaStorage {
|
|||||||
dst_offset: usize,
|
dst_offset: usize,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
let src_shape = src_l.shape();
|
||||||
let dims = src_shape.dims();
|
let dims = src_shape.dims();
|
||||||
let el_count = src_shape.elem_count();
|
let el_count = src_shape.elem_count();
|
||||||
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
let dev = &self.device;
|
let dev = &self.device;
|
||||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
|
||||||
match (&self.slice, &mut dst.slice) {
|
match (&self.slice, &mut dst.slice) {
|
||||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_l.is_contiguous() {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
|
||||||
@ -972,8 +981,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_l.is_contiguous() {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
|
||||||
@ -984,8 +993,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_l.is_contiguous() {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
|
||||||
@ -996,8 +1005,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_l.is_contiguous() {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
|
||||||
@ -1008,8 +1017,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
|
||||||
if src_shape.is_contiguous(src_stride) {
|
if src_l.is_contiguous() {
|
||||||
dev.dtod_copy(&src, &mut dst)?
|
dev.dtod_copy(&src, &mut dst)?
|
||||||
} else {
|
} else {
|
||||||
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
|
||||||
|
Reference in New Issue
Block a user