mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Rework the embeddings so that it works on non-contiguous weights + factor out some code.
This commit is contained in:
@ -88,6 +88,30 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
|
||||
}
|
||||
}
|
||||
|
||||
fn take<T: Copy>(
|
||||
ids: &[u32],
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vs: &[T],
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
) -> Result<Vec<T>> {
|
||||
let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
|
||||
for index in StridedIndex::new(shape.dims(), stride) {
|
||||
let index = ids[index].try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "take",
|
||||
});
|
||||
} else {
|
||||
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(values)
|
||||
}
|
||||
|
||||
fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
||||
src: &[T],
|
||||
dst: &mut [T],
|
||||
@ -380,52 +404,30 @@ impl CpuStorage {
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
vs: &Self,
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
CpuStorage::U32(lhs) => match rhs {
|
||||
CpuStorage::F32(rhs) => {
|
||||
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
|
||||
for &index in lhs {
|
||||
let index: usize = index.try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "embedding",
|
||||
});
|
||||
} else {
|
||||
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(CpuStorage::F32(weights))
|
||||
CpuStorage::U32(ids) => match vs {
|
||||
CpuStorage::F32(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::F32(storage))
|
||||
}
|
||||
CpuStorage::F64(rhs) => {
|
||||
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
|
||||
for &index in lhs {
|
||||
let index: usize = index.try_into()?;
|
||||
if index >= vocab_size {
|
||||
return Err(Error::InvalidIndex {
|
||||
index,
|
||||
vocab_size,
|
||||
op: "embedding",
|
||||
});
|
||||
} else {
|
||||
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
|
||||
}
|
||||
}
|
||||
Ok(CpuStorage::F64(weights))
|
||||
CpuStorage::F64(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::F64(storage))
|
||||
}
|
||||
CpuStorage::U32(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::U32(storage))
|
||||
}
|
||||
rhs => Err(Error::UnexpectedDType {
|
||||
expected: DType::F32,
|
||||
got: rhs.dtype(),
|
||||
}),
|
||||
},
|
||||
lhs => Err(Error::UnexpectedDType {
|
||||
ids => Err(Error::UnexpectedDType {
|
||||
expected: DType::U32,
|
||||
got: lhs.dtype(),
|
||||
got: ids.dtype(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user