diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index f50d7cbb..9d9a5f99 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -27,7 +27,7 @@ pub enum CudaError {
     InternalError(&'static str),
     #[error("internal error '{0}'")]
-    WrappedError(Box),
+    WrappedError(Box),
     #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
     MatMulNonContiguous {
@@ -245,13 +245,14 @@ enum CudaStorageSlice {
 fn slice_src_and_dst<'a, T>(
     src: &'a CudaSlice<T>,
-    src_offset: usize,
+    src_l: &Layout,
     dst: &'a mut CudaSlice<T>,
     dst_offset: usize,
 ) -> (
     cudarc::driver::CudaView<'a, T>,
     cudarc::driver::CudaViewMut<'a, T>,
 ) {
+    let src_offset = src_l.start_offset();
     let to_copy = dst
         .len()
         .saturating_sub(dst_offset)
@@ -366,13 +367,18 @@ impl CudaStorage {
         let dev = self.device();
         let ds = dev.htod_copy([dims, layout.stride()].concat())?;
         let start_o = layout.start_offset();
+        // This returns an i64 rather than a &i64, this is useful to get around some temporary
+        // lifetime issue and is safe as long as self.slice does not go out of scope before inp
+        // is used.
         let inp = match &self.slice {
-            CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(),
-            CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(),
+            CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(),
         };
+        let inp = &inp;
+
         let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str());
         let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?;
         let slice = match dtype {
@@ -739,13 +745,14 @@ impl CudaStorage {
         layout_f: &Layout,
     ) -> Result<Self> {
         let ids = match &self.slice {
-            CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
+            CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "where conditions should be u32",
                 expected: DType::U32,
                 got: self.dtype(),
             })?,
         };
+        let ids = &ids;
         let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
@@ -818,13 +825,14 @@ impl CudaStorage {
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
         let ids = match &self.slice {
-            CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..),
+            CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
             _ => Err(CudaError::UnexpectedDType {
                 msg: "embedding ids should be u32",
                 expected: DType::U32,
                 got: self.dtype(),
             })?,
         };
+        let ids = &ids;
         let shape = layout.shape();
         let (v_size, h_size) = rhs_l
             .shape()
@@ -953,15 +961,16 @@ impl CudaStorage {
         dst_offset: usize,
         src_l: &Layout,
     ) -> Result<()> {
+        let src_shape = src_l.shape();
         let dims = src_shape.dims();
         let el_count = src_shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el_count as u32);
         let dev = &self.device;
-        let ds = dev.htod_copy([dims, src_stride].concat())?;
+        let ds = dev.htod_copy([dims, src_l.stride()].concat())?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
-                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
-                if src_shape.is_contiguous(src_stride) {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?;
@@ -972,8 +981,8 @@ impl CudaStorage {
                 }
             }
             (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
-                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
-                if src_shape.is_contiguous(src_stride) {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?;
@@ -984,8 +993,8 @@ impl CudaStorage {
                 }
             }
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
-                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
-                if src_shape.is_contiguous(src_stride) {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?;
@@ -996,8 +1005,8 @@ impl CudaStorage {
                 }
             }
             (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
-                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
-                if src_shape.is_contiguous(src_stride) {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?;
@@ -1008,8 +1017,8 @@ impl CudaStorage {
                 }
             }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
-                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
-                if src_shape.is_contiguous(src_stride) {
+                let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
+                if src_l.is_contiguous() {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
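Note on the `device_ptr()` change in the cast kernel hunk above: the added comment explains that the match now yields an owned pointer value rather than a reference, because each arm builds a temporary slice view that is dropped at the end of the arm. The snippet below is a minimal, self-contained illustration of that borrow-checker pattern; `Slice` and `View` are hypothetical stand-ins, not the cudarc API.

```rust
// Minimal sketch of the temporary-lifetime workaround used in the diff above.
// `Slice` and `View` are hypothetical stand-ins, not the cudarc API.
struct View(u64);

impl View {
    // Stand-in for `device_ptr()`: returns a reference tied to this `View`.
    fn ptr(&self) -> &u64 {
        &self.0
    }
}

struct Slice(u64);

impl Slice {
    // Stand-in for `slice(..)`: builds a fresh temporary view on every call.
    fn view(&self) -> View {
        View(self.0)
    }
}

fn main() {
    let storage = Slice(42);

    // Returning `s.view().ptr()` directly from the arm would not compile: the
    // `View` temporary is dropped at the end of the statement, so the `&u64`
    // would dangle.
    //
    // Copying the pointed-to value out and re-borrowing a local instead
    // mirrors `let inp = *inp.slice(..).device_ptr(); let inp = &inp;`.
    let ptr_value: u64 = match &storage {
        s => *s.view().ptr(),
    };
    let ptr_value = &ptr_value;
    println!("device pointer value: {ptr_value}");
}
```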