Rework the embeddings so that it works on non-contiguous weights + factor out some code.

This commit is contained in:
laurent
2023-06-25 17:37:47 +01:00
parent 334524e2c4
commit 817e4b5005
6 changed files with 66 additions and 48 deletions

View File

@ -88,6 +88,30 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
}
}
fn take<T: Copy>(
ids: &[u32],
shape: &Shape,
stride: &[usize],
vs: &[T],
vocab_size: usize,
hidden_size: usize,
) -> Result<Vec<T>> {
let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
for index in StridedIndex::new(shape.dims(), stride) {
let index = ids[index].try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "take",
});
} else {
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(values)
}
fn copy_strided_src_<T: Copy + std::fmt::Display>(
src: &[T],
dst: &mut [T],
@ -380,52 +404,30 @@ impl CpuStorage {
pub(crate) fn embedding_impl(
&self,
rhs: &Self,
shape: &Shape,
stride: &[usize],
vs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
match self {
CpuStorage::U32(lhs) => match rhs {
CpuStorage::F32(rhs) => {
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
for &index in lhs {
let index: usize = index.try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "embedding",
});
} else {
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(CpuStorage::F32(weights))
CpuStorage::U32(ids) => match vs {
CpuStorage::F32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F32(storage))
}
CpuStorage::F64(rhs) => {
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
for &index in lhs {
let index: usize = index.try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "embedding",
});
} else {
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(CpuStorage::F64(weights))
CpuStorage::F64(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F64(storage))
}
CpuStorage::U32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::U32(storage))
}
rhs => Err(Error::UnexpectedDType {
expected: DType::F32,
got: rhs.dtype(),
}),
},
lhs => Err(Error::UnexpectedDType {
ids => Err(Error::UnexpectedDType {
expected: DType::U32,
got: lhs.dtype(),
got: ids.dtype(),
}),
}
}