mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Also use Map1 for embedding.
This commit is contained in:
@ -373,6 +373,44 @@ impl<U: crate::op::UnaryOp> Map1 for U {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Borrowed inputs for an embedding lookup: field `0` is the storage that
/// holds the ids and field `1` is the layout describing those ids.
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
|
||||||
|
impl<'a> Map1 for Embedding<'a> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
rhs: &CudaSlice<T>,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
rhs_l: &Layout,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let ids = match &self.0.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => slice.slice(ids_l.start_offset()..),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "embedding ids should be u32",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: self.0.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let ids = &ids;
|
||||||
|
let shape = ids_l.shape();
|
||||||
|
let (v_size, h_size) = rhs_l
|
||||||
|
.shape()
|
||||||
|
.r2()
|
||||||
|
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||||
|
let ds = dev.htod_copy([dims, ids_l.stride()].concat())?;
|
||||||
|
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("emb"), kernels::EMBEDDINGS)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let out = unsafe { dev.alloc::<T>(el * h_size) }?;
|
||||||
|
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }?;
|
||||||
|
Ok(out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn slice_src_and_dst<'a, T>(
|
fn slice_src_and_dst<'a, T>(
|
||||||
src: &'a CudaSlice<T>,
|
src: &'a CudaSlice<T>,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
@ -760,79 +798,8 @@ impl CudaStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||||
let ids = match &self.slice {
|
let device = self.device().clone();
|
||||||
CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..),
|
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
|
||||||
msg: "embedding ids should be u32",
|
|
||||||
expected: DType::U32,
|
|
||||||
got: self.dtype(),
|
|
||||||
})?,
|
|
||||||
};
|
|
||||||
let ids = &ids;
|
|
||||||
let shape = layout.shape();
|
|
||||||
let (v_size, h_size) = rhs_l
|
|
||||||
.shape()
|
|
||||||
.r2()
|
|
||||||
.map_err(|e| CudaError::WrappedError(Box::new(e)))?;
|
|
||||||
let dims = shape.dims();
|
|
||||||
let el = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
|
||||||
let dev = self.device();
|
|
||||||
let ds = dev.htod_copy([dims, layout.stride()].concat())?;
|
|
||||||
let slice = match &rhs.slice {
|
|
||||||
// The kernels below assume that rhs is contiguous.
|
|
||||||
CudaStorageSlice::U32(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<u32>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::U32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::BF16(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<bf16>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::BF16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F16(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f16>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F16(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F32(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f32>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F32(out)
|
|
||||||
}
|
|
||||||
CudaStorageSlice::F64(arg) => {
|
|
||||||
let arg = &arg.slice(rhs_l.start_offset()..);
|
|
||||||
let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?;
|
|
||||||
// SAFETY: Set later by running the kernel.
|
|
||||||
let out = unsafe { dev.alloc::<f64>(el * h_size) }?;
|
|
||||||
let params = (el, dims.len(), &ds, ids, arg, &out, h_size, v_size);
|
|
||||||
// SAFETY: ffi.
|
|
||||||
unsafe { func.launch(cfg, params) }?;
|
|
||||||
CudaStorageSlice::F64(out)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let device = dev.clone();
|
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user