Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Rework the embeddings so that they work on non-contiguous weights + factor out some code.
@@ -350,7 +350,8 @@ impl Llama {
     }
 
     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (_, t) = x.shape().r2()?;
+        // TODO: Support for mini-batches? (i.e. r2)
+        let t = x.shape().r1()?;
         let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
@@ -427,7 +428,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
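With forward now taking a rank-1 sequence of token ids (r1 instead of r2), the sampling loop in the example feeds the context slice to the model directly; the old reshape to a (1, ctxt.len()) batch is gone.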
@@ -88,6 +88,30 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }
 
+fn take<T: Copy>(
+    ids: &[u32],
+    shape: &Shape,
+    stride: &[usize],
+    vs: &[T],
+    vocab_size: usize,
+    hidden_size: usize,
+) -> Result<Vec<T>> {
+    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
+    for index in StridedIndex::new(shape.dims(), stride) {
+        let index = ids[index].try_into()?;
+        if index >= vocab_size {
+            return Err(Error::InvalidIndex {
+                index,
+                vocab_size,
+                op: "take",
+            });
+        } else {
+            values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+        }
+    }
+    Ok(values)
+}
+
 fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
     dst: &mut [T],
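This take helper is what makes non-contiguous ids work: rather than assuming the id buffer is one flat contiguous slice, it walks the tensor's logical element order through StridedIndex::new(shape.dims(), stride), mapping each logical position to its offset in the backing buffer before the bounds-checked row gather. Below is a self-contained miniature of the same idea, with plain slices standing in for CpuStorage and a hand-rolled offset walk standing in for StridedIndex; the names are illustrative stand-ins, not candle's actual types.

// Miniature of `take`: gather embedding rows for a possibly non-contiguous
// id tensor described by (dims, stride) over a contiguous weight matrix.
fn take_sketch(
    ids: &[u32],      // backing buffer of the id tensor
    dims: &[usize],   // logical shape of the id tensor
    stride: &[usize], // per-dimension strides into `ids`
    vs: &[f32],       // contiguous (vocab_size, hidden_size) weight matrix
    vocab_size: usize,
    hidden_size: usize,
) -> Result<Vec<f32>, String> {
    let elem_count: usize = dims.iter().product();
    let mut values = Vec::with_capacity(elem_count * hidden_size);
    let mut index = vec![0usize; dims.len()];
    for _ in 0..elem_count {
        // Logical multi-dimensional index -> offset into the backing buffer.
        let offset: usize = index.iter().zip(stride).map(|(i, s)| i * s).sum();
        let id = ids[offset] as usize;
        if id >= vocab_size {
            return Err(format!("index {id} out of range for vocab size {vocab_size}"));
        }
        values.extend_from_slice(&vs[hidden_size * id..hidden_size * (id + 1)]);
        // Advance the logical index, rightmost dimension fastest.
        for d in (0..dims.len()).rev() {
            index[d] += 1;
            if index[d] < dims[d] {
                break;
            }
            index[d] = 0;
        }
    }
    Ok(values)
}

fn main() {
    // vocab_size = 4, hidden_size = 2: row r of the weight matrix is [r, r].
    let vs: Vec<f32> = (0..4u32).flat_map(|r| [r as f32, r as f32]).collect();
    // A transposed 2x2 view over the buffer [0, 1, 2, 3]: dims (2, 2) with
    // strides (1, 2) reads the ids in the order 0, 2, 1, 3.
    let out = take_sketch(&[0, 1, 2, 3], &[2, 2], &[1, 2], &vs, 4, 2).unwrap();
    assert_eq!(out, vec![0., 0., 2., 2., 1., 1., 3., 3.]);
}

For a contiguous row-major tensor the computed offsets are just 0, 1, 2, ..., which is all the old contiguous-only loop could handle.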
@@ -380,52 +404,30 @@ impl CpuStorage {
 
     pub(crate) fn embedding_impl(
         &self,
-        rhs: &Self,
+        shape: &Shape,
+        stride: &[usize],
+        vs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         match self {
-            CpuStorage::U32(lhs) => match rhs {
-                CpuStorage::F32(rhs) => {
-                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
-                    for &index in lhs {
-                        let index: usize = index.try_into()?;
-                        if index >= vocab_size {
-                            return Err(Error::InvalidIndex {
-                                index,
-                                vocab_size,
-                                op: "embedding",
-                            });
-                        } else {
-                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
-                        }
-                    }
-                    Ok(CpuStorage::F32(weights))
-                }
-                CpuStorage::F64(rhs) => {
-                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
-                    for &index in lhs {
-                        let index: usize = index.try_into()?;
-                        if index >= vocab_size {
-                            return Err(Error::InvalidIndex {
-                                index,
-                                vocab_size,
-                                op: "embedding",
-                            });
-                        } else {
-                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
-                        }
-                    }
-                    Ok(CpuStorage::F64(weights))
-                }
-                rhs => Err(Error::UnexpectedDType {
-                    expected: DType::F32,
-                    got: rhs.dtype(),
-                }),
+            CpuStorage::U32(ids) => match vs {
+                CpuStorage::F32(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::F32(storage))
+                }
+                CpuStorage::F64(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::F64(storage))
+                }
+                CpuStorage::U32(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::U32(storage))
+                }
             },
-            lhs => Err(Error::UnexpectedDType {
+            ids => Err(Error::UnexpectedDType {
                 expected: DType::U32,
-                got: lhs.dtype(),
+                got: ids.dtype(),
             }),
         }
     }
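On the CpuStorage side, the dtype dispatch is all that remains: each arm merely unwraps the concrete buffer, delegates to the generic take, and rewraps the result, where the old version repeated the whole bounds-checked gather loop once per supported dtype. A minimal sketch of that factoring pattern, with illustrative names rather than candle's API:

// One generic helper serves every dtype arm of the enum dispatch.
#[derive(Debug, PartialEq)]
enum Cpu {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

fn gather_rows<T: Copy>(ids: &[u32], rows: &[T], hidden: usize) -> Vec<T> {
    let mut out = Vec::with_capacity(ids.len() * hidden);
    for &id in ids {
        let id = id as usize;
        out.extend_from_slice(&rows[hidden * id..hidden * (id + 1)]);
    }
    out
}

fn embedding(ids: &[u32], vs: &Cpu, hidden: usize) -> Cpu {
    // Each arm only unwraps and rewraps; the logic lives in gather_rows.
    match vs {
        Cpu::F32(vs) => Cpu::F32(gather_rows(ids, vs, hidden)),
        Cpu::F64(vs) => Cpu::F64(gather_rows(ids, vs, hidden)),
    }
}

fn main() {
    // Rows of the weight matrix: [0, 0], [1, 1], [2, 2].
    let vs = Cpu::F32(vec![0.0, 0.0, 1.0, 1.0, 2.0, 2.0]);
    assert_eq!(embedding(&[2, 0], &vs, 2), Cpu::F32(vec![2.0, 2.0, 0.0, 0.0]));
}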
@@ -413,6 +413,8 @@ impl CudaStorage {
 
     pub(crate) fn embedding_impl(
         &self,
+        _shape: &Shape,
+        _stride: &[usize],
         _rhs: &Self,
         _hidden_size: usize,
         _vocab_size: usize,
@@ -90,7 +90,14 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result<Self> {
+    pub(crate) fn embedding_impl(
+        &self,
+        _: &Shape,
+        _: &[usize],
+        _: &Self,
+        _: usize,
+        _: usize,
+    ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
@@ -156,19 +156,20 @@ impl Storage {
 
     pub(crate) fn embedding_impl(
         &self,
+        shape: &Shape,
+        stride: &[usize],
         rhs: &Self,
        hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         self.same_device(rhs, "embedding")?;
-        self.same_dtype(rhs, "embedding")?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@@ -435,11 +435,16 @@ impl Tensor {
                 op: "embedding",
             });
         }
-        let seq_len = ids.shape().r1()?;
+        let ids_shape = ids.shape();
+        let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids
-            .storage
-            .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
+        let storage = ids.storage.embedding_impl(
+            ids_shape,
+            &ids.stride,
+            &rhs.storage,
+            hidden_size,
+            vocab_size,
+        )?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
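At the Tensor level the change is plumbing: Tensor::embedding forwards the ids' shape and stride down to the storage call instead of assuming a contiguous rank-1 buffer, and the output shape is still (seq_len, hidden_size), with seq_len taken from the rank-1 ids and (vocab_size, hidden_size) from the rank-2 weight matrix. A sketch of that shape bookkeeping as a standalone function (a hypothetical helper; candle expresses the same checks through its r1/r2 accessors):

// Compute the embedding output shape from the ids and weight shapes.
fn embedding_out_shape(
    ids_dims: &[usize],
    weight_dims: &[usize],
) -> Result<(usize, usize), String> {
    match (ids_dims, weight_dims) {
        // ids must be rank 1, weights rank 2.
        (&[seq_len], &[_vocab_size, hidden_size]) => Ok((seq_len, hidden_size)),
        _ => Err(format!(
            "unexpected ranks: ids {ids_dims:?}, weights {weight_dims:?}"
        )),
    }
}

fn main() {
    assert_eq!(embedding_out_shape(&[5], &[32000, 4096]).unwrap(), (5, 4096));
}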