diff --git a/examples/llama/main.rs b/examples/llama/main.rs index 32ad71bc..54a02079 100644 --- a/examples/llama/main.rs +++ b/examples/llama/main.rs @@ -350,7 +350,8 @@ impl Llama { } fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result { - let (_, t) = x.shape().r2()?; + // TODO: Support for mini-batches? (i.e. r2) + let t = x.shape().r1()?; let mut x = self.wte.forward(x)?; for block in self.blocks.iter() { x = block.forward(&x, freqs_cis)?; @@ -427,7 +428,7 @@ fn main() -> Result<()> { let mut rng = thread_rng(); for index in 0..args.sample_len { let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..]; - let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?; + let input = Tensor::new(ctxt, &Device::Cpu)?; let logits = llama.forward(&input, &freqs_cis)?; let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?; let logits_v: Vec = prs.to_vec1()?; diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index ebd55453..a2112a30 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -88,6 +88,30 @@ fn binary_map T>( } } +fn take( + ids: &[u32], + shape: &Shape, + stride: &[usize], + vs: &[T], + vocab_size: usize, + hidden_size: usize, +) -> Result> { + let mut values = Vec::with_capacity(shape.elem_count() * hidden_size); + for index in StridedIndex::new(shape.dims(), stride) { + let index = ids[index].try_into()?; + if index >= vocab_size { + return Err(Error::InvalidIndex { + index, + vocab_size, + op: "take", + }); + } else { + values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]); + } + } + Ok(values) +} + fn copy_strided_src_( src: &[T], dst: &mut [T], @@ -380,52 +404,30 @@ impl CpuStorage { pub(crate) fn embedding_impl( &self, - rhs: &Self, + shape: &Shape, + stride: &[usize], + vs: &Self, hidden_size: usize, vocab_size: usize, ) -> Result { match self { - CpuStorage::U32(lhs) => match rhs { - CpuStorage::F32(rhs) => { - let mut weights = Vec::with_capacity(lhs.len() * hidden_size); - for &index in lhs { - let index: usize = index.try_into()?; - if index >= vocab_size { - return Err(Error::InvalidIndex { - index, - vocab_size, - op: "embedding", - }); - } else { - weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]); - } - } - Ok(CpuStorage::F32(weights)) + CpuStorage::U32(ids) => match vs { + CpuStorage::F32(vs) => { + let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; + Ok(CpuStorage::F32(storage)) } - CpuStorage::F64(rhs) => { - let mut weights = Vec::with_capacity(lhs.len() * hidden_size); - for &index in lhs { - let index: usize = index.try_into()?; - if index >= vocab_size { - return Err(Error::InvalidIndex { - index, - vocab_size, - op: "embedding", - }); - } else { - weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]); - } - } - Ok(CpuStorage::F64(weights)) + CpuStorage::F64(vs) => { + let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; + Ok(CpuStorage::F64(storage)) + } + CpuStorage::U32(vs) => { + let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?; + Ok(CpuStorage::U32(storage)) } - rhs => Err(Error::UnexpectedDType { - expected: DType::F32, - got: rhs.dtype(), - }), }, - lhs => Err(Error::UnexpectedDType { + ids => Err(Error::UnexpectedDType { expected: DType::U32, - got: lhs.dtype(), + got: ids.dtype(), }), } } diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index fdfca801..61125f93 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -413,6 +413,8 @@ impl CudaStorage { pub(crate) fn embedding_impl( &self, + _shape: &Shape, + _stride: &[usize], _rhs: &Self, _hidden_size: usize, _vocab_size: usize, diff --git a/src/dummy_cuda_backend.rs b/src/dummy_cuda_backend.rs index 98762277..da7221e4 100644 --- a/src/dummy_cuda_backend.rs +++ b/src/dummy_cuda_backend.rs @@ -90,7 +90,14 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result { + pub(crate) fn embedding_impl( + &self, + _: &Shape, + _: &[usize], + _: &Self, + _: usize, + _: usize, + ) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/src/storage.rs b/src/storage.rs index c13a01a6..16f74995 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -156,19 +156,20 @@ impl Storage { pub(crate) fn embedding_impl( &self, + shape: &Shape, + stride: &[usize], rhs: &Self, hidden_size: usize, vocab_size: usize, ) -> Result { self.same_device(rhs, "embedding")?; - self.same_dtype(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/src/tensor.rs b/src/tensor.rs index bfd9964f..32347328 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -435,11 +435,16 @@ impl Tensor { op: "embedding", }); } - let seq_len = ids.shape().r1()?; + let ids_shape = ids.shape(); + let seq_len = ids_shape.r1()?; let (vocab_size, hidden_size) = rhs.shape().r2()?; - let storage = ids - .storage - .embedding_impl(&rhs.storage, hidden_size, vocab_size)?; + let storage = ids.storage.embedding_impl( + ids_shape, + &ids.stride, + &rhs.storage, + hidden_size, + vocab_size, + )?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone()))