From 367170da4527101c7e3aae8fbd7f0551fcddf5d0 Mon Sep 17 00:00:00 2001
From: laurent
Date: Thu, 29 Jun 2023 09:45:27 +0100
Subject: [PATCH] Also use Map1 for embedding.

---
Reviewer notes (not part of the commit message):

* Summary: the per-dtype `match` in `CudaStorage::embedding` is replaced by a
  new `Embedding` tuple struct (ids storage, ids layout) that implements
  `Map1::f` once, generically over `T`. `CudaStorage::embedding` now only
  clones the device and calls `Embedding(self, layout).map(...)`.
* The literal kernel names `emb_u32` / `emb_bf16` / `emb_f16` / `emb_f32` /
  `emb_f64` are replaced by `kernel_name::<T>("emb")`. Presumably this yields
  the same names for each dtype -- TODO confirm against `kernel_name`.
* The removed comment "The kernels below assume that rhs is contiguous." is
  not carried over to the new code. The same kernels are launched, so the
  assumption presumably still applies -- TODO confirm / re-add the comment.
* `LaunchConfig::for_num_elems(el as u32)` casts `el` to `u32` with `as`,
  in both the old and the new code; a checked conversion would avoid silent
  truncation for very large inputs.
* NOTE(review): the generic parameter lists in this copy of the patch were
  lost in extraction and have been reconstructed. `CudaSlice<T>`,
  `alloc::<T>`, `Result<Self>` and the per-arm `alloc::<u32>` etc. follow
  from the surrounding code; the bound list on `f`
  (`T: DeviceRepr + WithDType + ValidAsZeroBits`) and `<U: UnaryOp>` in the
  first hunk heading are assumptions -- verify against the `Map1` trait
  declaration in the target tree.

 candle-core/src/cuda_backend.rs | 113 +++++++++++---------------------
 1 file changed, 40 insertions(+), 73 deletions(-)

diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index dc0d51bf..7d06dd72 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -373,6 +373,44 @@ impl<U: UnaryOp> Map1 for U {
     }
 }
 
+struct Embedding<'a>(&'a CudaStorage, &'a Layout);
+impl<'a> Map1 for Embedding<'a> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        rhs: &CudaSlice<T>,
+        dev: &CudaDevice,
+        rhs_l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let ids_l = &self.1;
+        let ids = match &self.0.slice {
+            CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "embedding ids should be u32",
+                expected: DType::U32,
+                got: self.0.dtype(),
+            })?,
+        };
+        let ids = &ids;
+        let shape = ids_l.shape();
+        let (v_size, h_size) = rhs_l
+            .shape()
+            .r2()
+            .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
+        let dims = shape.dims();
+        let el = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el as u32);
+        let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
+        let rhs = &rhs.slice(rhs_l.start_offset()..);
+        let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
+        // SAFETY: Set later by running the kernel.
+        let out = unsafe { dev.alloc::<T>(el * h_size) }?;
+        let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
+        // SAFETY: ffi.
+        unsafe { func.launch(cfg, params) }?;
+        Ok(out)
+    }
+}
+
 fn slice_src_and_dst<'a, T>(
     src: &'a CudaSlice<T>,
     src_l: &Layout,
@@ -760,79 +798,8 @@ impl CudaStorage {
     }
 
     pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
-        let ids = match &self.slice {
-            CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
-            _ => Err(CudaError::UnexpectedDType {
-                msg: "embedding ids should be u32",
-                expected: DType::U32,
-                got: self.dtype(),
-            })?,
-        };
-        let ids = &ids;
-        let shape = layout.shape();
-        let (v_size, h_size) = rhs_l
-            .shape()
-            .r2()
-            .map_err(|e| CudaError::WrappedError(Box::new(e)))?;
-        let dims = shape.dims();
-        let el = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(el as u32);
-        let dev = self.device();
-        let ds = dev.htod_copy([dims, layout.stride()].concat())?;
-        let slice = match &rhs.slice {
-            // The kernels below assume that rhs is contiguous.
-            CudaStorageSlice::U32(arg) => {
-                let arg = &arg.slice(rhs_l.start_offset()..);
-                let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
-                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::U32(out)
-            }
-            CudaStorageSlice::BF16(arg) => {
-                let arg = &arg.slice(rhs_l.start_offset()..);
-                let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
-                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::BF16(out)
-            }
-            CudaStorageSlice::F16(arg) => {
-                let arg = &arg.slice(rhs_l.start_offset()..);
-                let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
-                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F16(out)
-            }
-            CudaStorageSlice::F32(arg) => {
-                let arg = &arg.slice(rhs_l.start_offset()..);
-                let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
-                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F32(out)
-            }
-            CudaStorageSlice::F64(arg) => {
-                let arg = &arg.slice(rhs_l.start_offset()..);
-                let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
-                // SAFETY: Set later by running the kernel.
-                let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
-                let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
-                // SAFETY: ffi.
-                unsafe { func.launch(cfg, params) }?;
-                CudaStorageSlice::F64(out)
-            }
-        };
-        let device = dev.clone();
+        let device = self.device().clone();
+        let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
         Ok(Self { slice, device })
     }
 