From acc5bd335f6dfdf4ebb10ba76fda5d7c95434282 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 11 Apr 2025 21:43:35 +0200
Subject: [PATCH] Cuda cleanup. (#2880)

* Cuda cleanup.

* More fixes.
---
 candle-core/src/cuda_backend/device.rs      | 150 +++++++++++++-------
 candle-core/src/cuda_backend/mod.rs         | 123 ++++++++--------
 candle-core/src/quantized/cuda.rs           |  52 +++----
 candle-core/src/sort.rs                     |   2 +-
 candle-examples/examples/custom-ops/main.rs |   2 +-
 candle-flash-attn/src/lib.rs                |  11 +-
 candle-nn/src/ops.rs                        |   8 +-
 candle-nn/src/rotary_emb.rs                 |   6 +-
 8 files changed, 193 insertions(+), 161 deletions(-)

diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs
index 8967eb98..a2674d67 100644
--- a/candle-core/src/cuda_backend/device.rs
+++ b/candle-core/src/cuda_backend/device.rs
@@ -46,11 +46,61 @@ impl std::fmt::Debug for CudaDevice {
     }
 }
 
-impl std::ops::Deref for CudaDevice {
-    type Target = Arc<cudarc::driver::CudaStream>;
+impl CudaDevice {
+    #[allow(clippy::missing_safety_doc)]
+    pub unsafe fn alloc<T: cudarc::driver::DeviceRepr>(
+        &self,
+        len: usize,
+    ) -> Result<cudarc::driver::CudaSlice<T>> {
+        self.stream.alloc::<T>(len).w()
+    }
 
-    fn deref(&self) -> &Self::Target {
-        &self.stream
+    pub fn alloc_zeros<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits>(
+        &self,
+        len: usize,
+    ) -> Result<cudarc::driver::CudaSlice<T>> {
+        self.stream.alloc_zeros::<T>(len).w()
+    }
+
+    pub fn memcpy_htod<
+        T: cudarc::driver::DeviceRepr,
+        Src: cudarc::driver::HostSlice<T> + ?Sized,
+        Dst: cudarc::driver::DevicePtrMut<T>,
+    >(
+        &self,
+        src: &Src,
+        dst: &mut Dst,
+    ) -> Result<()> {
+        self.stream.memcpy_htod(src, dst).w()
+    }
+
+    pub fn memcpy_dtov<T: cudarc::driver::DeviceRepr, Src: cudarc::driver::DevicePtr<T>>(
+        &self,
+        src: &Src,
+    ) -> Result<Vec<T>> {
+        self.stream.memcpy_dtov(src).w()
+    }
+
+    pub fn memcpy_dtod<
+        T,
+        Src: cudarc::driver::DevicePtr<T>,
+        Dst: cudarc::driver::DevicePtrMut<T>,
+    >(
+        &self,
+        src: &Src,
+        dst: &mut Dst,
+    ) -> Result<()> {
+        self.stream.memcpy_dtod(src, dst).w()
+    }
+
+    pub fn memcpy_stod<
+        T: cudarc::driver::DeviceRepr,
+        Src: cudarc::driver::HostSlice<T> + ?Sized,
+    >(
+        &self,
+        src: &Src,
+    ) -> Result<cudarc::driver::CudaSlice<T>> {
+        self.stream.memcpy_stod(src).w()
     }
 }
 
@@ -126,7 +176,7 @@ impl CudaDevice {
         let slice = match dtype {
             DType::U8 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u8>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<u8>(elem_count)? };
                 let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 let v = v as u8;
@@ -138,7 +188,7 @@ impl CudaDevice {
             }
             DType::U32 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u32>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<u32>(elem_count)? };
                 let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 let v = v as u32;
@@ -150,7 +200,7 @@ impl CudaDevice {
             }
             DType::I64 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<i64>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<i64>(elem_count)? };
                 let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 let v = v as i64;
@@ -162,7 +212,7 @@ impl CudaDevice {
             }
             DType::BF16 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<bf16>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<bf16>(elem_count)? };
                 let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 let v = bf16::from_f64(v);
@@ -174,7 +224,7 @@ impl CudaDevice {
             }
             DType::F16 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f16>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<f16>(elem_count)? };
                 let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 let v = f16::from_f64(v);
@@ -186,7 +236,7 @@ impl CudaDevice {
             }
             DType::F32 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<f32>(elem_count)? };
                 let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 let v = v as f32;
@@ -198,7 +248,7 @@ impl CudaDevice {
             }
             DType::F64 => {
                 // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                let data = unsafe { self.alloc::<f64>(elem_count) }?;
                 let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
                 let mut builder = self.stream.launch_builder(&func);
                 builder.arg(&data);
@@ -325,31 +375,31 @@ impl BackendDevice for CudaDevice {
         let elem_count = shape.elem_count();
         let slice = match dtype {
             DType::U8 => {
-                let data = self.alloc_zeros::<u8>(elem_count).w()?;
+                let data = self.alloc_zeros::<u8>(elem_count)?;
                 CudaStorageSlice::U8(data)
             }
             DType::U32 => {
-                let data = self.alloc_zeros::<u32>(elem_count).w()?;
+                let data = self.alloc_zeros::<u32>(elem_count)?;
                 CudaStorageSlice::U32(data)
             }
             DType::I64 => {
-                let data = self.alloc_zeros::<i64>(elem_count).w()?;
+                let data = self.alloc_zeros::<i64>(elem_count)?;
                 CudaStorageSlice::I64(data)
             }
             DType::BF16 => {
-                let data = self.alloc_zeros::<bf16>(elem_count).w()?;
+                let data = self.alloc_zeros::<bf16>(elem_count)?;
                 CudaStorageSlice::BF16(data)
             }
             DType::F16 => {
-                let data = self.alloc_zeros::<f16>(elem_count).w()?;
+                let data = self.alloc_zeros::<f16>(elem_count)?;
                 CudaStorageSlice::F16(data)
             }
             DType::F32 => {
-                let data = self.alloc_zeros::<f32>(elem_count).w()?;
+                let data = self.alloc_zeros::<f32>(elem_count)?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let data = self.alloc_zeros::<f64>(elem_count).w()?;
+                let data = self.alloc_zeros::<f64>(elem_count)?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -373,12 +423,12 @@ impl BackendDevice for CudaDevice {
                     .w()?
             }
             DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f32>(elem_count)? };
                 curand.0.fill_with_uniform(&mut data).w()?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count) }.w()?;
+                let mut data = unsafe { self.alloc::<f64>(elem_count)? };
                 curand.0.fill_with_uniform(&mut data).w()?;
                 CudaStorageSlice::F64(data)
             }
@@ -417,7 +467,7 @@ impl BackendDevice for CudaDevice {
                     .w()?
             }
             DType::F32 => {
-                let mut data = unsafe { self.alloc::<f32>(elem_count_round) }.w()?;
+                let mut data = unsafe { self.alloc::<f32>(elem_count_round)? };
                 curand
                     .0
                     .fill_with_normal(&mut data, mean as f32, std as f32)
                     .w()?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let mut data = unsafe { self.alloc::<f64>(elem_count_round) }.w()?;
+                let mut data = unsafe { self.alloc::<f64>(elem_count_round)? };
                 curand.0.fill_with_normal(&mut data, mean, std).w()?;
                 CudaStorageSlice::F64(data)
             }
@@ -444,31 +494,31 @@ impl BackendDevice for CudaDevice {
         let elem_count = shape.elem_count();
         let slice = match dtype {
             DType::U8 => {
-                let data = self.alloc::<u8>(elem_count).w()?;
+                let data = self.alloc::<u8>(elem_count)?;
                 CudaStorageSlice::U8(data)
             }
             DType::U32 => {
-                let data = self.alloc::<u32>(elem_count).w()?;
+                let data = self.alloc::<u32>(elem_count)?;
                 CudaStorageSlice::U32(data)
             }
             DType::I64 => {
-                let data = self.alloc::<i64>(elem_count).w()?;
+                let data = self.alloc::<i64>(elem_count)?;
                 CudaStorageSlice::I64(data)
             }
             DType::BF16 => {
-                let data = self.alloc::<bf16>(elem_count).w()?;
+                let data = self.alloc::<bf16>(elem_count)?;
                 CudaStorageSlice::BF16(data)
             }
             DType::F16 => {
-                let data = self.alloc::<f16>(elem_count).w()?;
+                let data = self.alloc::<f16>(elem_count)?;
                 CudaStorageSlice::F16(data)
             }
             DType::F32 => {
-                let data = self.alloc::<f32>(elem_count).w()?;
+                let data = self.alloc::<f32>(elem_count)?;
                 CudaStorageSlice::F32(data)
             }
             DType::F64 => {
-                let data = self.alloc::<f64>(elem_count).w()?;
+                let data = self.alloc::<f64>(elem_count)?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -481,31 +531,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
         let slice = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorageRef::U32(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorageRef::I64(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorageRef::BF16(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorageRef::F16(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorageRef::F32(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorageRef::F64(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -518,31 +568,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.memcpy_stod(storage).w()?;
+                let data = self.memcpy_stod(storage)?;
                 CudaStorageSlice::F64(data)
             }
         };
@@ -555,31 +605,31 @@ impl BackendDevice for CudaDevice {
     fn storage_from_cpu_storage_owned(&self, storage: CpuStorage) -> Result<Self::Storage> {
         let slice = match storage {
             CpuStorage::U8(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::U8(data)
             }
             CpuStorage::U32(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::U32(data)
             }
             CpuStorage::I64(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::I64(data)
             }
             CpuStorage::BF16(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::BF16(data)
             }
             CpuStorage::F16(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::F16(data)
             }
             CpuStorage::F32(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::F32(data)
             }
             CpuStorage::F64(storage) => {
-                let data = self.memcpy_stod(&storage).w()?;
+                let data = self.memcpy_stod(&storage)?;
                 CudaStorageSlice::F64(data)
             }
         };
diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs
index a509e97a..df1aed29 100644
--- a/candle-core/src/cuda_backend/mod.rs
+++ b/candle-core/src/cuda_backend/mod.rs
@@ -39,7 +39,7 @@ impl SlicePtrOrNull {
         let ds = if l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {
-            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat()).w()?)
+            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[l.dims(), l.stride()].concat())?)
         };
         Ok(ds)
     }
@@ -89,7 +89,7 @@ impl Map1 for Affine {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("affine"), &kernels::AFFINE)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, dims.len());
@@ -120,7 +120,7 @@ impl Map1 for Elu {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("uelu"), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, dims.len());
@@ -159,11 +159,11 @@ impl Map1 for Im2Col1D {
         let l_out = self.l_out(dims[2]);
         let dst_el = dims[0] * l_out * dims[1] * self.l_k;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
+        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("im2col1d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(dst_el)? };
         let mut builder = func.builder();
         barg!(builder, dst_el);
         barg!(builder, l_out);
@@ -210,11 +210,11 @@ impl Map1 for Im2Col {
         let (h_out, w_out) = self.hw_out(dims[2], dims[3]);
         let dst_el = dims[0] * h_out * w_out * dims[1] * self.h_k * self.w_k;
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat()).w()?;
+        let ds = dev.memcpy_stod(&[dims, layout.stride()].concat())?;
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("im2col"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(dst_el)? };
         let mut builder = func.builder();
         barg!(builder, dst_el);
         barg!(builder, h_out);
@@ -249,7 +249,7 @@ impl Map1 for Powf {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("upowf"), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, dims.len());
@@ -302,9 +302,7 @@ impl Map1Any for FastReduce<'_> {
             block_dim: (block_dim as u32, 1, 1),
             shared_mem_bytes: 0,
         };
-        let ds = dev
-            .memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())
-            .w()?;
+        let ds = dev.memcpy_stod(&[dims.as_slice(), stride.as_slice()].concat())?;
         let src = &src.slice(layout.start_offset()..);
         let (name, check_empty, return_index) = match self.1 {
             ReduceOp::Sum => ("fast_sum", false, false),
@@ -319,7 +317,7 @@ impl Map1Any for FastReduce<'_> {
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::REDUCE)?;
         if return_index {
             // SAFETY: filled in by the follow up kernel.
-            let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+            let out = unsafe { dev.alloc::<u32>(dst_el)? };
             let mut builder = func.builder();
             barg!(builder, src_el);
             barg!(builder, el_to_sum_per_block);
@@ -332,7 +330,7 @@ impl Map1Any for FastReduce<'_> {
             Ok(S::U32(out))
         } else {
             // SAFETY: filled in by the follow up kernel.
-            let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+            let out = unsafe { dev.alloc::<T>(dst_el)? };
             let mut builder = func.builder();
             barg!(builder, src_el);
             barg!(builder, el_to_sum_per_block);
@@ -362,7 +360,7 @@ impl Map1 for U {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let mut out = unsafe { dev.alloc::<T>(el_count) }.w()?;
+        let mut out = unsafe { dev.alloc::<T>(el_count)? };
         let mut builder = func.builder();
         barg!(builder, el_count);
         barg!(builder, dims.len());
@@ -403,7 +401,7 @@ impl Map1 for IndexSelect<'_> {
         };
         let ids_shape = ids_l.shape();
         let ids_dims = ids_shape.dims();
-        let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat()).w()?;
+        let ds = dev.memcpy_stod(&[ids_dims, ids_l.stride()].concat())?;
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
@@ -416,7 +414,7 @@ impl Map1 for IndexSelect<'_> {
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
         let mut builder = func.builder();
         barg!(builder, dst_el);
         barg!(builder, ids_dims.len());
@@ -471,7 +469,7 @@ impl Map1 for Gather<'_> {
         let ids_dim_sz = ids_l.dims()[dim];
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, ids);
@@ -608,7 +606,7 @@ impl Map2 for Conv1D<'_> {
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("conv1d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
         let ds = if dims.len() == 3 {
             [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
         } else if dims.len() == 2 {
@@ -616,7 +614,7 @@ impl Map2 for Conv1D<'_> {
         } else {
             crate::bail!("unexpected input shape for conv1d {dims:?}")
         };
-        let ds = dev.memcpy_stod(&ds).w()?;
+        let ds = dev.memcpy_stod(&ds)?;
         let mut builder = func.builder();
         barg!(builder, el, l_out, p.stride, p.padding, p.dilation);
         builder.arg(&ds);
@@ -651,7 +649,7 @@ impl Map2 for Conv2D<'_> {
         let el = shape.elem_count();
 
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("conv2d"), &kernels::CONV)?;
         let ds = if dims.len() == 4 {
@@ -659,7 +657,7 @@ impl Map2 for Conv2D<'_> {
         } else {
             crate::bail!("unexpected input shape for conv2d {dims:?}")
         };
-        let ds = dev.memcpy_stod(&ds).w()?;
+        let ds = dev.memcpy_stod(&ds)?;
         let mut builder = func.builder();
         barg!(builder, el, out_w, out_h, p.stride, p.padding, p.dilation);
         builder.arg(&ds);
@@ -687,7 +685,7 @@ impl Map1 for Col2Im1D {
         let stride = self.stride;
         let l_out = (l_in - 1) * stride + k_size;
         let dst_el = b_size * c_out * l_out;
-        let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let mut im = unsafe { dev.alloc::<T>(dst_el)? };
 
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), &kernels::CONV)?;
@@ -722,7 +720,7 @@ impl Map2 for ConvTranspose1D<'_> {
         let el = shape.elem_count();
 
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), &kernels::CONV)?;
         let ds = if dims.len() == 3 {
@@ -730,7 +728,7 @@ impl Map2 for ConvTranspose1D<'_> {
         } else {
             crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
         };
-        let ds = dev.memcpy_stod(&ds).w()?;
+        let ds = dev.memcpy_stod(&ds)?;
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, l_out);
@@ -770,7 +768,7 @@ impl Map2 for ConvTranspose2D<'_> {
         let el = shape.elem_count();
 
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose2d"), &kernels::CONV)?;
         let ds = if dims.len() == 4 {
@@ -778,7 +776,7 @@ impl Map2 for ConvTranspose2D<'_> {
         } else {
             crate::bail!("unexpected input shape for conv_transpose2d {dims:?}")
         };
-        let ds = dev.memcpy_stod(&ds).w()?;
+        let ds = dev.memcpy_stod(&ds)?;
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, out_w);
@@ -837,8 +835,8 @@ impl Map1 for Pool2D {
         };
         let func = dev.get_or_load_func(&kernel_name::<T>(kname), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let ds = dev.memcpy_stod(&ds).w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
+        let ds = dev.memcpy_stod(&ds)?;
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, self.w_k);
@@ -876,8 +874,8 @@ impl Map1 for UpsampleNearest2D {
         let cfg = LaunchConfig::for_num_elems(dst_el as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("upsample_nearest2d"), &kernels::CONV)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let ds = dev.memcpy_stod(&ds).w()?;
+        let out = unsafe { dev.alloc::<T>(dst_el)? };
+        let ds = dev.memcpy_stod(&ds)?;
         let scale_w = dims[2] as f64 / out_w as f64;
         let scale_h = dims[3] as f64 / out_h as f64;
         let mut builder = func.builder();
@@ -930,13 +928,12 @@ impl Map2 for WhereCond<'_> {
         let el = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(el as u32);
         let ds = dev
-            .memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())
-            .w()?;
+            .memcpy_stod(&[dims, ids_l.stride(), layout_t.stride(), layout_f.stride()].concat())?;
         let t = &t.slice(layout_t.start_offset()..);
         let f = &f.slice(layout_f.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::TERNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el) }.w()?;
+        let out = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         barg!(builder, el);
         barg!(builder, dims.len());
@@ -967,16 +964,13 @@ impl Map2 for U {
         let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {
-            SlicePtrOrNull::Ptr(
-                dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
-                    .w()?,
-            )
+            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
         };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>(U::KERNEL), &kernels::BINARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(elem_count) }.w()?;
+        let out = unsafe { dev.alloc::<T>(elem_count)? };
         let mut builder = func.builder();
         barg!(builder, elem_count);
         barg!(builder, dims.len());
@@ -1007,10 +1001,7 @@ impl Map2Any for Cmp {
         let dims_and_strides = if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
             SlicePtrOrNull::Null
         } else {
-            SlicePtrOrNull::Ptr(
-                dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())
-                    .w()?,
-            )
+            SlicePtrOrNull::Ptr(dev.memcpy_stod(&[dims, lhs_l.stride(), rhs_l.stride()].concat())?)
         };
         let lhs = &lhs.slice(lhs_l.start_offset()..);
         let rhs = &rhs.slice(rhs_l.start_offset()..);
@@ -1024,7 +1015,7 @@ impl Map2Any for Cmp {
         };
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::BINARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<u8>(elem_count) }.w()?;
+        let out = unsafe { dev.alloc::<u8>(elem_count)? };
         let mut builder = func.builder();
         barg!(builder, elem_count);
         barg!(builder, dims.len());
@@ -1269,7 +1260,7 @@ impl BackendStorage for CudaStorage {
         let func = dev.get_or_load_func(&kernel_name, &kernels::CAST)?;
         let slice = match dtype {
             DType::U8 => {
-                let out = unsafe { dev.alloc::<u8>(el) }.w()?;
+                let out = unsafe { dev.alloc::<u8>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1280,7 +1271,7 @@ impl BackendStorage for CudaStorage {
                 CudaStorageSlice::U8(out)
             }
             DType::U32 => {
-                let out = unsafe { dev.alloc::<u32>(el) }.w()?;
+                let out = unsafe { dev.alloc::<u32>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1291,7 +1282,7 @@ impl BackendStorage for CudaStorage {
                 CudaStorageSlice::U32(out)
             }
             DType::I64 => {
-                let out = unsafe { dev.alloc::<i64>(el) }.w()?;
+                let out = unsafe { dev.alloc::<i64>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1302,7 +1293,7 @@ impl BackendStorage for CudaStorage {
                 CudaStorageSlice::I64(out)
             }
             DType::BF16 => {
-                let out = unsafe { dev.alloc::<bf16>(el) }.w()?;
+                let out = unsafe { dev.alloc::<bf16>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1313,7 +1304,7 @@ impl BackendStorage for CudaStorage {
                 CudaStorageSlice::BF16(out)
             }
             DType::F16 => {
-                let out = unsafe { dev.alloc::<f16>(el) }.w()?;
+                let out = unsafe { dev.alloc::<f16>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1324,7 +1315,7 @@ impl BackendStorage for CudaStorage {
                 CudaStorageSlice::F16(out)
             }
             DType::F32 => {
-                let out = unsafe { dev.alloc::<f32>(el) }.w()?;
+                let out = unsafe { dev.alloc::<f32>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1335,7 +1326,7 @@ impl BackendStorage for CudaStorage {
                 CudaStorageSlice::F32(out)
             }
             DType::F64 => {
-                let out = unsafe { dev.alloc::<f64>(el) }.w()?;
+                let out = unsafe { dev.alloc::<f64>(el)? };
                 let mut builder = func.builder();
                 barg!(builder, el);
                 barg!(builder, dims.len());
@@ -1632,7 +1623,7 @@ impl BackendStorage for CudaStorage {
             (S::U8(inp), S::U8(k)) => {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
-                let mut out = unsafe { device.alloc::<u8>(dst_el) }.w()?;
+                let mut out = unsafe { device.alloc::<u8>(dst_el)? };
                 crate::cudnn::launch_conv2d::<u8, u8>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::U8(out)
             }
@@ -1640,7 +1631,7 @@ impl BackendStorage for CudaStorage {
             (S::BF16(inp), S::BF16(k)) => {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
-                let mut out = unsafe { device.alloc::<bf16>(dst_el) }.w()?;
+                let mut out = unsafe { device.alloc::<bf16>(dst_el)? };
                 // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16"
                 // version.
                 // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88
@@ -1651,7 +1642,7 @@ impl BackendStorage for CudaStorage {
             (S::F16(inp), S::F16(k)) => {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
-                let mut out = unsafe { device.alloc::<f16>(dst_el) }.w()?;
+                let mut out = unsafe { device.alloc::<f16>(dst_el)? };
                 crate::cudnn::launch_conv2d::<f16, f16>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::F16(out)
             }
@@ -1659,7 +1650,7 @@ impl BackendStorage for CudaStorage {
             (S::F32(inp), S::F32(k)) => {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
-                let mut out = unsafe { device.alloc::<f32>(dst_el) }.w()?;
+                let mut out = unsafe { device.alloc::<f32>(dst_el)? };
                 crate::cudnn::launch_conv2d::<f32, f32>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::F32(out)
             }
@@ -1667,7 +1658,7 @@ impl BackendStorage for CudaStorage {
             (S::F64(inp), S::F64(k)) => {
                 let inp = &inp.slice(inp_l.start_offset()..);
                 let k = &k.slice(kernel_l.start_offset()..);
-                let mut out = unsafe { device.alloc::<f64>(dst_el) }.w()?;
+                let mut out = unsafe { device.alloc::<f64>(dst_el)? };
                 crate::cudnn::launch_conv2d::<f64, f64>(inp, inp_l, k, &mut out, params, &device)
                     .map_err(crate::Error::wrap)?;
                 S::F64(out)
             }
@@ -1783,7 +1774,7 @@ impl BackendStorage for CudaStorage {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(bf16::ONE, bf16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<bf16>(elem_count) }.w()?;
+                let mut out = unsafe { dev.alloc::<bf16>(elem_count)? };
                 unsafe { gemm_strided_batched_bf16(&self.device.blas, cfg, rhs, lhs, &mut out) }
                     .w()?;
                 CudaStorageSlice::BF16(out)
@@ -1792,7 +1783,7 @@ impl BackendStorage for CudaStorage {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<f16>(elem_count) }.w()?;
+                let mut out = unsafe { dev.alloc::<f16>(elem_count)? };
                 unsafe { gemm_strided_batched_f16(&self.device.blas, cfg, rhs, lhs, &mut out) }
                     .w()?;
                 CudaStorageSlice::F16(out)
@@ -1801,7 +1792,7 @@ impl BackendStorage for CudaStorage {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+                let mut out = unsafe { dev.alloc::<f32>(elem_count)? };
                 unsafe { gemm_strided_batched_f32(&self.device.blas, cfg, rhs, lhs, &mut out) }
                     .w()?;
                 CudaStorageSlice::F32(out)
@@ -1810,7 +1801,7 @@ impl BackendStorage for CudaStorage {
                 let lhs = &lhs.slice(lhs_l.start_offset()..);
                 let rhs = &rhs.slice(rhs_l.start_offset()..);
                 let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?;
-                let mut out = unsafe { dev.alloc::<f64>(elem_count) }.w()?;
+                let mut out = unsafe { dev.alloc::<f64>(elem_count)? };
                 unsafe {
                     self.device
                         .blas
@@ -1883,7 +1874,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_bf16", &kernels::UNARY)?;
                     let mut builder = func.builder();
@@ -1899,7 +1890,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f16", &kernels::UNARY)?;
                     let mut builder = func.builder();
@@ -1915,7 +1906,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f32", &kernels::UNARY)?;
                     let mut builder = func.builder();
@@ -1931,7 +1922,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::U8(src), CudaStorageSlice::U8(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_u8", &kernels::UNARY)?;
                     let mut builder = func.builder();
@@ -1947,7 +1938,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_u32", &kernels::UNARY)?;
                     let mut builder = func.builder();
@@ -1963,7 +1954,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::I64(src), CudaStorageSlice::I64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_i64", &kernels::UNARY)?;
                     let mut builder = func.builder();
@@ -1979,7 +1970,7 @@ impl BackendStorage for CudaStorage {
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
                 let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset);
                 if src_l.is_contiguous() {
-                    dev.memcpy_dtod(&src, &mut dst).w()?
+                    dev.memcpy_dtod(&src, &mut dst)?
                 } else {
                     let func = dev.get_or_load_func("ucopy_f64", &kernels::UNARY)?;
                     let mut builder = func.builder();
diff --git a/candle-core/src/quantized/cuda.rs b/candle-core/src/quantized/cuda.rs
index 21f6ae0c..c8d483a3 100644
--- a/candle-core/src/quantized/cuda.rs
+++ b/candle-core/src/quantized/cuda.rs
@@ -99,7 +99,7 @@ fn dequantize_f32(
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(elem_count).w()? };
+    let dst = unsafe { dev.alloc::<f32>(elem_count)? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
     let cfg = cudarc::driver::LaunchConfig {
@@ -159,7 +159,7 @@ fn dequantize_f16(
         _ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f16>(elem_count).w()? };
+    let dst = unsafe { dev.alloc::<f16>(elem_count)? };
     // See e.g.
     // https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
    let cfg = cudarc::driver::LaunchConfig {
@@ -216,7 +216,7 @@ fn dequantize_mul_mat_vec(
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows).w()? };
+    let dst = unsafe { dev.alloc::<f32>(nrows)? };
     let block_num_y = ceil_div(nrows, GGML_CUDA_MMV_Y);
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (block_num_y as u32, 1, 1),
@@ -256,7 +256,7 @@ fn mul_mat_vec_via_q8_1(
     let ncols_padded = pad(ncols, MATRIX_ROW_PADDING);
     let y_size_in_bytes =
         b_size * ncols_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
-    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
     quantize_q8_1(y, &mut y_q8_1, ncols, b_size, dev)?;
 
     let kernel_name = match dtype {
@@ -274,7 +274,7 @@ fn mul_mat_vec_via_q8_1(
     };
     let kernel_name = format!("{kernel_name}{b_size}");
     let func = dev.get_or_load_func(&kernel_name, &candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(nrows * b_size).w()? };
+    let dst = unsafe { dev.alloc::<f32>(nrows * b_size)? };
     // https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
     let (nblocks, nwarps) = match b_size {
         1 => (nrows as u32, 4),
@@ -329,7 +329,7 @@ fn mul_mat_via_q8_1(
     let k_padded = pad(k, MATRIX_ROW_PADDING);
     let y_size_in_bytes =
         k_padded * y_cols * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
-    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+    let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
     quantize_q8_1(y, &mut y_q8_1, k, y_cols, dev)?;
 
     let (kernel_name, mmq_x, mmq_y) = match dtype {
@@ -346,7 +346,7 @@ fn mul_mat_via_q8_1(
         _ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
     };
     let func = dev.get_or_load_func(kernel_name, &candle_kernels::QUANTIZED)?;
-    let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols).w()? };
+    let dst = unsafe { dev.alloc::<f32>(x_rows * y_cols)? };
     let cfg = cudarc::driver::LaunchConfig {
         grid_dim: (
             ceil_div(x_rows, mmq_y) as u32,
@@ -378,7 +378,7 @@ impl QCudaStorage {
         let size_in_bytes = ceil_div(el_count, dtype.block_size()) * dtype.type_size();
         let padded_size_in_bytes =
             ceil_div(el_count + MATRIX_ROW_PADDING, dtype.block_size()) * dtype.type_size();
-        let inner = device.alloc_zeros::<u8>(padded_size_in_bytes).w()?;
+        let inner = device.alloc_zeros::<u8>(padded_size_in_bytes)?;
         Ok(QCudaStorage {
             data: PaddedCudaSlice {
                 inner,
@@ -425,8 +425,7 @@ impl QCudaStorage {
 
         let buffer = self
             .device
-            .memcpy_dtov(&self.data.inner.slice(..self.data.len))
-            .w()?;
+            .memcpy_dtov(&self.data.inner.slice(..self.data.len))?;
         let mut out = vec![0.0; elem_count];
         let block_len = elem_count / self.dtype.block_size();
         match self.dtype {
@@ -457,9 +456,7 @@ impl QCudaStorage {
     pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
         // Run the quantization on cpu.
         let src = match &src.slice {
-            crate::cuda_backend::CudaStorageSlice::F32(data) => {
-                self.device.memcpy_dtov(data).w()?
-            }
+            crate::cuda_backend::CudaStorageSlice::F32(data) => self.device.memcpy_dtov(data)?,
             _ => crate::bail!("only f32 can be quantized"),
         };
         let src_len = src.len();
@@ -469,10 +466,9 @@ impl QCudaStorage {
         let data = qcpu_storage.data()?;
         let padded_len =
             data.len() + MATRIX_ROW_PADDING * self.dtype.type_size() / self.dtype.block_size();
-        let mut inner = unsafe { self.device.alloc::<u8>(padded_len).w()? };
+        let mut inner = unsafe { self.device.alloc::<u8>(padded_len)? };
         self.device
-            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))
-            .w()?;
+            .memcpy_htod(data.as_ref(), &mut inner.slice_mut(..data.len()))?;
         self.data = PaddedCudaSlice {
             inner,
             len: data.len(),
@@ -606,10 +602,8 @@ pub fn load_quantized(
     };
     let dtype = T::DTYPE;
     let padded_len = data.len() + MATRIX_ROW_PADDING * dtype.type_size() / dtype.block_size();
-    let mut inner = unsafe { device.alloc::<u8>(padded_len).w()? };
-    device
-        .memcpy_htod(data, &mut inner.slice_mut(..data.len()))
-        .w()?;
+    let mut inner = unsafe { device.alloc::<u8>(padded_len)? };
+    device.memcpy_htod(data, &mut inner.slice_mut(..data.len()))?;
     Ok(QStorage::Cuda(QCudaStorage {
         data: PaddedCudaSlice {
             inner,
@@ -631,9 +625,9 @@ mod test {
         let el_padded = pad(el, MATRIX_ROW_PADDING);
         let y_size_in_bytes =
             el_padded * GgmlDType::Q8_1.type_size() / GgmlDType::Q8_1.block_size();
-        let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes).w()? };
+        let mut y_q8_1 = unsafe { dev.alloc::<u8>(y_size_in_bytes)? };
         let vs: Vec<f32> = (0..el).map(|v| v as f32).collect();
-        let y = dev.memcpy_stod(&vs).w()?;
+        let y = dev.memcpy_stod(&vs)?;
         quantize_q8_1(&y.slice(..), &mut y_q8_1, el, 1, &dev)?;
         Ok(())
     }
@@ -643,7 +637,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols).map(|v| v as f32).collect();
-        let y = dev.memcpy_stod(&vs).w()?;
+        let y = dev.memcpy_stod(&vs)?;
         let mut xs = QCudaStorage::zeros(&dev, ncols, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_vec_via_q8_1(
@@ -656,7 +650,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..))?;
         assert_eq!(vs.len(), 1);
         // for n = 255, n.(n+1).(2n+1) / 6 = 5559680
         // Q8 means 1/256 precision.
@@ -671,7 +665,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..))?;
         assert_eq!(vs.len(), 1);
         assert_eq!(vs[0], 5561851.0);
         Ok(())
@@ -682,7 +676,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let ncols = 256;
         let vs: Vec<f32> = (0..ncols * 4).map(|v| v as f32 / 4.).collect();
-        let y = dev.memcpy_stod(&vs).w()?;
+        let y = dev.memcpy_stod(&vs)?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * 4, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -696,7 +690,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
+        let vs = dev.memcpy_dtov(&vs.slice(..))?;
 
         /*
         x = torch.tensor([float(v) for v in range(1024)]).reshape(4, 256)
@@ -723,7 +717,7 @@ mod test {
         let dev = CudaDevice::new(0)?;
         let (x_rows, ncols, y_cols) = (4, 16, 2048);
         let vs: Vec<f32> = (0..ncols * y_cols).map(|v| v as f32 / 256.).collect();
-        let y = dev.memcpy_stod(&vs).w()?;
+        let y = dev.memcpy_stod(&vs)?;
         let mut xs = QCudaStorage::zeros(&dev, ncols * x_rows, GgmlDType::Q4_0)?;
         xs.quantize(&CudaStorage::wrap_cuda_slice(y.clone(), dev.clone()))?;
         let cuda_storage = mul_mat_via_q8_1(
@@ -737,7 +731,7 @@ mod test {
             &dev,
         )?;
         let vs = cuda_storage.as_cuda_slice::<f32>()?;
-        let _vs = dev.memcpy_dtov(&vs.slice(..)).unwrap();
+        let _vs = dev.memcpy_dtov(&vs.slice(..))?;
         Ok(())
     }
 }
diff --git a/candle-core/src/sort.rs b/candle-core/src/sort.rs
index 9a8597d3..af536617 100644
--- a/candle-core/src/sort.rs
+++ b/candle-core/src/sort.rs
@@ -76,7 +76,7 @@ mod cuda {
             Some((o1, o2)) => src.slice(o1..o2),
         };
         let elem_count = layout.shape().elem_count();
-        let dst = unsafe { dev.alloc::<u32>(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::<u32>(elem_count)? };
         let func = if self.asc {
             dev.get_or_load_func(&kernel_name::<T>("asort_asc"), &kernels::SORT)?
         } else {
diff --git a/candle-examples/examples/custom-ops/main.rs b/candle-examples/examples/custom-ops/main.rs
index 9a312cb2..029d3134 100644
--- a/candle-examples/examples/custom-ops/main.rs
+++ b/candle-examples/examples/custom-ops/main.rs
@@ -68,7 +68,7 @@ impl CustomOp1 for LayerNorm {
             Some((o1, o2)) => slice.slice(o1..o2),
         };
         let elem_count = layout.shape().elem_count();
-        let dst = unsafe { dev.alloc::<f32>(elem_count) }.w()?;
+        let dst = unsafe { dev.alloc::<f32>(elem_count) }?;
         let func =
             dev.get_or_load_custom_func("rms_f32", "mymodule", cuda_kernels::LAYERNORM_KERNELS)?;
         let cfg = LaunchConfig {
diff --git a/candle-flash-attn/src/lib.rs b/candle-flash-attn/src/lib.rs
index e84edd14..643783b3 100644
--- a/candle-flash-attn/src/lib.rs
+++ b/candle-flash-attn/src/lib.rs
@@ -2,7 +2,6 @@ mod ffi;
 
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
-use candle::cuda_backend::WrapErr;
 use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use half::{bf16, f16};
 
@@ -142,10 +141,8 @@ impl FlashAttn {
         let seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let softmax_lse = dev
-            .alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)
-            .w()?;
+        let dst = unsafe { dev.alloc::<T>(elem_count)? };
+        let softmax_lse = dev.alloc_zeros::<f32>(b_sz * 128 * num_heads * seqlen_q)?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
@@ -607,8 +604,8 @@ impl FlashAttnVarLen {
         let seqlen_k_rounded = round_multiple(self.max_seqlen_k, 128);
 
         let elem_count = out_shape.elem_count();
-        let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
-        let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q).w()?;
+        let dst = unsafe { dev.alloc::<T>(elem_count)? };
+        let softmax_lse = dev.alloc_zeros::<f32>(num_heads * total_q)?;
 
         let is_bf16 = if is_bf16 { 1 } else { 0 };
 
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 74169190..79affdae 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -112,7 +112,7 @@ impl candle::CustomOp1 for Sigmoid {
         let src = &src.slice(layout.start_offset()..);
         let func = dev.get_or_load_func(&kernel_name::<T>("usigmoid"), &kernels::UNARY)?;
         // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(el_count) }.w()?;
+        let out = unsafe { dev.alloc::<T>(el_count)? };
 
         let mut builder = func.builder();
         candle::builder_arg!(builder, el_count, dims.len());
@@ -373,7 +373,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         let func = dev.get_or_load_func(&kernel_name::<T>("softmax"), &kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         builder.arg(&src);
         builder.arg(&dst);
@@ -561,7 +561,7 @@ impl candle::CustomOp2 for RmsNorm {
         };
         let func = dev.get_or_load_func(&kernel_name::<T>("rmsnorm"), &kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         builder.arg(&src);
         builder.arg(&dst);
@@ -800,7 +800,7 @@ impl candle::CustomOp3 for LayerNorm {
 
         let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), &kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         builder.arg(&src);
         builder.arg(&dst);
diff --git a/candle-nn/src/rotary_emb.rs b/candle-nn/src/rotary_emb.rs
index a1d7cfae..e9fa24ce 100644
--- a/candle-nn/src/rotary_emb.rs
+++ b/candle-nn/src/rotary_emb.rs
@@ -119,7 +119,7 @@ impl candle::CustomOp3 for RotaryEmbI {
         let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("rope_i"), &kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         builder.arg(&src);
         builder.arg(&cos);
@@ -369,7 +369,7 @@ impl candle::CustomOp3 for RotaryEmb {
         let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("rope"), &kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         builder.arg(&src);
         builder.arg(&cos);
@@ -620,7 +620,7 @@ impl candle::CustomOp3 for RotaryEmbThd {
         let cfg = LaunchConfig::for_num_elems((el / 2) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>("rope_thd"), &kernels::REDUCE)?;
         // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+        let dst = unsafe { dev.alloc::<T>(el)? };
         let mut builder = func.builder();
         builder.arg(&src);
         builder.arg(&cos);
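
Note (illustrative, not part of the patch): before this change, CudaDevice
implemented Deref to the underlying cudarc stream, so call sites used the raw
cudarc API and wrapped its error type themselves with `.w()`. The new inherent
methods (`alloc`, `alloc_zeros`, `memcpy_htod`, `memcpy_dtov`, `memcpy_dtod`,
`memcpy_stod`) do that wrapping internally and return candle's `Result`, which
is what the mechanical `.w()?` to `?` rewrite throughout the diff reflects.
A minimal sketch of a call site under the new API follows; the helper name and
import paths are assumptions for illustration, not code from the patch:

    // Hypothetical helper: round-trip a host buffer through device memory
    // using the new CudaDevice wrappers.
    use candle_core::cuda_backend::CudaDevice;
    use candle_core::Result;

    fn roundtrip(dev: &CudaDevice) -> Result<Vec<f32>> {
        let host = vec![1f32, 2.0, 3.0];
        // Host-to-device copy that also allocates the destination buffer;
        // errors come back as candle's Result, no .w() needed.
        let dbuf = dev.memcpy_stod(&host)?;
        // Device-to-host copy back into a Vec<f32>.
        dev.memcpy_dtov(&dbuf)
    }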