mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select. * Also remove the cuda kernels.
This commit is contained in:
@ -690,46 +690,6 @@ impl<U: UnaryOpT> Map1 for U {
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a>(&'a CudaStorage, &'a Layout);
|
||||
impl<'a> Map1 for Embedding<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
&self,
|
||||
rhs: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let ids_l = &self.1;
|
||||
let (name, ids) = match &self.0.slice {
|
||||
CudaStorageSlice::U32(slice) => {
|
||||
("emb_u32", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
CudaStorageSlice::U8(slice) => {
|
||||
("emb_u8", *slice.slice(ids_l.start_offset()..).device_ptr())
|
||||
}
|
||||
_ => Err(CudaError::UnexpectedDType {
|
||||
msg: "embedding ids should be u8 or u32",
|
||||
expected: DType::U32,
|
||||
got: self.0.dtype(),
|
||||
})
|
||||
.w()?,
|
||||
};
|
||||
let shape = ids_l.shape();
|
||||
let (v_size, h_size) = rhs_l.shape().dims2()?;
|
||||
let dims = shape.dims();
|
||||
let el = shape.elem_count();
|
||||
let cfg = LaunchConfig::for_num_elems(el as u32);
|
||||
let ds = dev.htod_copy([dims, ids_l.stride()].concat()).w()?;
|
||||
let rhs = &rhs.slice(rhs_l.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::INDEXING)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let out = unsafe { dev.alloc::<T>(el * h_size) }.w()?;
|
||||
let params = (el, dims.len(), &ds, ids, rhs, &out, h_size, v_size);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
struct IndexSelect<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||
impl<'a> Map1 for IndexSelect<'a> {
|
||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||
@ -1421,12 +1381,6 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
let slice = IndexSelect(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||
|
Reference in New Issue
Block a user