mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Remove the embedding ops in favor of index-select. (#299)
* Remove the embedding ops in favor of index-select. * Also remove the cuda kernels.
This commit is contained in:
@ -861,58 +861,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
|
||||
}
|
||||
}
|
||||
|
||||
struct Embedding<'a, I: IntDType> {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
ids: &'a [I],
|
||||
ids_l: &'a Layout,
|
||||
}
|
||||
|
||||
impl<'a, I: IntDType> Map1 for Embedding<'a, I> {
|
||||
fn f<T: WithDType>(&self, vs: &[T], layout: &Layout) -> Result<Vec<T>> {
|
||||
if !layout.is_contiguous() {
|
||||
Err(Error::RequiresContiguous { op: "embedding" })?
|
||||
}
|
||||
let vs = &vs[layout.start_offset()..];
|
||||
let mut values = Vec::with_capacity(self.ids_l.shape().elem_count() * self.hidden_size);
|
||||
match self.ids_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => {
|
||||
for index in self.ids[o1..o2].iter() {
|
||||
let index = index.as_usize();
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
} else {
|
||||
let hidden_size = self.hidden_size;
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
for index in self.ids_l.strided_index() {
|
||||
let index = self.ids[index].as_usize();
|
||||
if index >= self.vocab_size {
|
||||
Err(Error::InvalidIndex {
|
||||
index,
|
||||
size: self.vocab_size,
|
||||
op: "take",
|
||||
}
|
||||
.bt())?
|
||||
} else {
|
||||
let hidden_size = self.hidden_size;
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(values)
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
|
||||
match src_l.strided_blocks() {
|
||||
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||
@ -1664,27 +1612,6 @@ impl BackendStorage for CpuStorage {
|
||||
Conv1D(params).map(self, l, kernel, kernel_l)
|
||||
}
|
||||
|
||||
fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result<Self> {
|
||||
let (vocab_size, hidden_size) = rhs_l.shape().dims2()?;
|
||||
match self {
|
||||
Self::U8(ids) => Embedding {
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
ids,
|
||||
ids_l,
|
||||
}
|
||||
.map(rhs, rhs_l),
|
||||
Self::U32(ids) => Embedding {
|
||||
vocab_size,
|
||||
hidden_size,
|
||||
ids,
|
||||
ids_l,
|
||||
}
|
||||
.map(rhs, rhs_l),
|
||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "embedding")),
|
||||
}
|
||||
}
|
||||
|
||||
fn index_select(&self, ids: &Self, l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
match ids {
|
||||
Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
|
||||
|
Reference in New Issue
Block a user