Rework the embeddings so that it works on non-contiguous weights + factor out some code.

This commit is contained in:
laurent
2023-06-25 17:37:47 +01:00
parent 334524e2c4
commit 817e4b5005
6 changed files with 66 additions and 48 deletions

View File

@ -350,7 +350,8 @@ impl Llama {
}
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
let (_, t) = x.shape().r2()?;
// TODO: Support for mini-batches? (i.e. r2)
let t = x.shape().r1()?;
let mut x = self.wte.forward(x)?;
for block in self.blocks.iter() {
x = block.forward(&x, freqs_cis)?;
@ -427,7 +428,7 @@ fn main() -> Result<()> {
let mut rng = thread_rng();
for index in 0..args.sample_len {
let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
let input = Tensor::new(ctxt, &Device::Cpu)?;
let logits = llama.forward(&input, &freqs_cis)?;
let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;

View File

@ -88,6 +88,30 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
}
}
fn take<T: Copy>(
ids: &[u32],
shape: &Shape,
stride: &[usize],
vs: &[T],
vocab_size: usize,
hidden_size: usize,
) -> Result<Vec<T>> {
let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
for index in StridedIndex::new(shape.dims(), stride) {
let index = ids[index].try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "take",
});
} else {
values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(values)
}
fn copy_strided_src_<T: Copy + std::fmt::Display>(
src: &[T],
dst: &mut [T],
@ -380,52 +404,30 @@ impl CpuStorage {
pub(crate) fn embedding_impl(
&self,
rhs: &Self,
shape: &Shape,
stride: &[usize],
vs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
match self {
CpuStorage::U32(lhs) => match rhs {
CpuStorage::F32(rhs) => {
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
for &index in lhs {
let index: usize = index.try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "embedding",
});
} else {
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(CpuStorage::F32(weights))
CpuStorage::U32(ids) => match vs {
CpuStorage::F32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F32(storage))
}
CpuStorage::F64(rhs) => {
let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
for &index in lhs {
let index: usize = index.try_into()?;
if index >= vocab_size {
return Err(Error::InvalidIndex {
index,
vocab_size,
op: "embedding",
});
} else {
weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
}
}
Ok(CpuStorage::F64(weights))
CpuStorage::F64(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F64(storage))
}
CpuStorage::U32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::U32(storage))
}
rhs => Err(Error::UnexpectedDType {
expected: DType::F32,
got: rhs.dtype(),
}),
},
lhs => Err(Error::UnexpectedDType {
ids => Err(Error::UnexpectedDType {
expected: DType::U32,
got: lhs.dtype(),
got: ids.dtype(),
}),
}
}

View File

@ -413,6 +413,8 @@ impl CudaStorage {
pub(crate) fn embedding_impl(
&self,
_shape: &Shape,
_stride: &[usize],
_rhs: &Self,
_hidden_size: usize,
_vocab_size: usize,

View File

@ -90,7 +90,14 @@ impl CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result<Self> {
pub(crate) fn embedding_impl(
&self,
_: &Shape,
_: &[usize],
_: &Self,
_: usize,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}

View File

@ -156,19 +156,20 @@ impl Storage {
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
stride: &[usize],
rhs: &Self,
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
self.same_device(rhs, "embedding")?;
self.same_dtype(rhs, "embedding")?;
match (self, rhs) {
(Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(lhs), Self::Cuda(rhs)) => {
let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
Ok(Self::Cuda(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {

View File

@ -435,11 +435,16 @@ impl Tensor {
op: "embedding",
});
}
let seq_len = ids.shape().r1()?;
let ids_shape = ids.shape();
let seq_len = ids_shape.r1()?;
let (vocab_size, hidden_size) = rhs.shape().r2()?;
let storage = ids
.storage
.embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
let storage = ids.storage.embedding_impl(
ids_shape,
&ids.stride,
&rhs.storage,
hidden_size,
vocab_size,
)?;
let shape: Shape = (seq_len, hidden_size).into();
let op = if ids.track_op() || rhs.track_op() {
Some(Op::Embedding(ids.clone(), rhs.clone()))