Rework the embeddings so that they work on non-contiguous weights + factor out some code.

laurent
2023-06-25 17:37:47 +01:00
parent 334524e2c4
commit 817e4b5005
6 changed files with 66 additions and 48 deletions
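
The substance of the change is in the CPU backend below: the embedding lookup no longer iterates the raw id buffer directly but walks the id tensor through a StridedIndex built from its shape and stride, so the lookup no longer assumes a contiguous layout. A minimal, self-contained sketch of that idea follows; it uses nothing from candle, and strided_offsets / gather_rows are made-up illustrative names, not the library's API.

// Illustrative sketch only: enumerate flat offsets for a possibly
// non-contiguous layout, then gather one `hidden`-sized row per id.
fn strided_offsets(dims: &[usize], stride: &[usize]) -> Vec<usize> {
    let elem_count: usize = dims.iter().product();
    let mut offsets = Vec::with_capacity(elem_count);
    for i in 0..elem_count {
        // Decompose the logical (row-major) index i into per-dimension
        // coordinates and map them through the strides to a flat offset.
        let mut rem = i;
        let mut offset = 0;
        for (d, s) in dims.iter().zip(stride).rev() {
            offset += (rem % d) * s;
            rem /= d;
        }
        offsets.push(offset);
    }
    offsets
}

fn gather_rows(ids: &[u32], dims: &[usize], stride: &[usize], table: &[f32], hidden: usize) -> Vec<f32> {
    let mut out = Vec::with_capacity(dims.iter().product::<usize>() * hidden);
    for offset in strided_offsets(dims, stride) {
        let row = ids[offset] as usize;
        out.extend_from_slice(&table[row * hidden..(row + 1) * hidden]);
    }
    out
}

fn main() {
    // A 2x2 embedding table with hidden size 2.
    let table = [10.0, 11.0, 20.0, 21.0];
    // The ids buffer holds [0, 1], but the logical view has dims [3] and
    // stride [0]: a broadcast-like, non-contiguous view that repeats id 0.
    let ids = [0u32, 1];
    let rows = gather_rows(&ids, &[3], &[0], &table, 2);
    assert_eq!(rows, vec![10.0, 11.0, 10.0, 11.0, 10.0, 11.0]);
}
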

View File

@@ -350,7 +350,8 @@ impl Llama {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (_, t) = x.shape().r2()?;
+        // TODO: Support for mini-batches? (i.e. r2)
+        let t = x.shape().r1()?;
         let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
@@ -427,7 +428,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;

View File

@@ -88,6 +88,30 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
+fn take<T: Copy>(
+    ids: &[u32],
+    shape: &Shape,
+    stride: &[usize],
+    vs: &[T],
+    vocab_size: usize,
+    hidden_size: usize,
+) -> Result<Vec<T>> {
+    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
+    for index in StridedIndex::new(shape.dims(), stride) {
+        let index = ids[index].try_into()?;
+        if index >= vocab_size {
+            return Err(Error::InvalidIndex {
+                index,
+                vocab_size,
+                op: "take",
+            });
+        } else {
+            values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+        }
+    }
+    Ok(values)
+}
+
 fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
     dst: &mut [T],
@@ -380,52 +404,30 @@ impl CpuStorage {
     pub(crate) fn embedding_impl(
         &self,
-        rhs: &Self,
+        shape: &Shape,
+        stride: &[usize],
+        vs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         match self {
-            CpuStorage::U32(lhs) => match rhs {
-                CpuStorage::F32(rhs) => {
-                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
-                    for &index in lhs {
-                        let index: usize = index.try_into()?;
-                        if index >= vocab_size {
-                            return Err(Error::InvalidIndex {
-                                index,
-                                vocab_size,
-                                op: "embedding",
-                            });
-                        } else {
-                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
-                        }
-                    }
-                    Ok(CpuStorage::F32(weights))
+            CpuStorage::U32(ids) => match vs {
+                CpuStorage::F32(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::F32(storage))
                 }
-                CpuStorage::F64(rhs) => {
-                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
-                    for &index in lhs {
-                        let index: usize = index.try_into()?;
-                        if index >= vocab_size {
-                            return Err(Error::InvalidIndex {
-                                index,
-                                vocab_size,
-                                op: "embedding",
-                            });
-                        } else {
-                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
-                        }
-                    }
-                    Ok(CpuStorage::F64(weights))
+                CpuStorage::F64(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::F64(storage))
+                }
+                CpuStorage::U32(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::U32(storage))
                 }
-                rhs => Err(Error::UnexpectedDType {
-                    expected: DType::F32,
-                    got: rhs.dtype(),
-                }),
             },
-            lhs => Err(Error::UnexpectedDType {
+            ids => Err(Error::UnexpectedDType {
                 expected: DType::U32,
-                got: lhs.dtype(),
+                got: ids.dtype(),
             }),
         }
     }
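
The per-dtype duplication that used to live in embedding_impl is what the commit message calls factoring out: the lookup loop now exists once as the generic take<T: Copy> above, and each dtype arm only re-wraps the result in its own variant. A toy sketch of that dispatch pattern follows; ToyStorage, gather, and embedding are made-up stand-ins, not candle types.

// Illustrative only: a two-variant storage enum standing in for CpuStorage.
#[allow(dead_code)]
enum ToyStorage {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

// The row-gathering loop is written once, generically over the element type...
fn gather<T: Copy>(ids: &[usize], table: &[T], hidden: usize) -> Vec<T> {
    let mut out = Vec::with_capacity(ids.len() * hidden);
    for &id in ids {
        out.extend_from_slice(&table[id * hidden..(id + 1) * hidden]);
    }
    out
}

// ...and the dtype match only decides which variant to rebuild.
fn embedding(ids: &[usize], table: &ToyStorage, hidden: usize) -> ToyStorage {
    match table {
        ToyStorage::F32(t) => ToyStorage::F32(gather(ids, t, hidden)),
        ToyStorage::F64(t) => ToyStorage::F64(gather(ids, t, hidden)),
    }
}

fn main() {
    let table = ToyStorage::F32(vec![1.0, 2.0, 3.0, 4.0]);
    if let ToyStorage::F32(rows) = embedding(&[1, 0], &table, 2) {
        assert_eq!(rows, vec![3.0, 4.0, 1.0, 2.0]);
    }
}
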

View File

@@ -413,6 +413,8 @@ impl CudaStorage {
     pub(crate) fn embedding_impl(
         &self,
+        _shape: &Shape,
+        _stride: &[usize],
         _rhs: &Self,
         _hidden_size: usize,
         _vocab_size: usize,

View File

@@ -90,7 +90,14 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result<Self> {
+    pub(crate) fn embedding_impl(
+        &self,
+        _: &Shape,
+        _: &[usize],
+        _: &Self,
+        _: usize,
+        _: usize,
+    ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }

View File

@@ -156,19 +156,20 @@ impl Storage {
     pub(crate) fn embedding_impl(
         &self,
+        shape: &Shape,
+        stride: &[usize],
         rhs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         self.same_device(rhs, "embedding")?;
-        self.same_dtype(rhs, "embedding")?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {

View File

@@ -435,11 +435,16 @@ impl Tensor {
                 op: "embedding",
             });
         }
-        let seq_len = ids.shape().r1()?;
+        let ids_shape = ids.shape();
+        let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids
-            .storage
-            .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
+        let storage = ids.storage.embedding_impl(
+            ids_shape,
+            &ids.stride,
+            &rhs.storage,
+            hidden_size,
+            vocab_size,
+        )?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))