Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Rework the embeddings so that they work on non-contiguous weights + factor out some code.
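The heart of the change: the CPU embedding lookup no longer walks the raw id buffer front to back but follows a StridedIndex built from the ids tensor's shape and stride, so views that merely adjust strides are indexed in their logical order. As a rough standalone sketch of what strided enumeration does (illustrative only -- strided_offsets is a hypothetical helper, not candle's actual StridedIndex):

// Minimal standalone sketch of strided enumeration (not candle's actual
// StridedIndex): list the flat buffer offsets of a tensor view with the
// given logical dims and per-dimension strides, last dimension fastest.
fn strided_offsets(dims: &[usize], stride: &[usize]) -> Vec<usize> {
    let mut offsets = vec![0usize];
    for (&d, &s) in dims.iter().zip(stride) {
        let mut next = Vec::with_capacity(offsets.len() * d);
        for &off in &offsets {
            for i in 0..d {
                next.push(off + i * s);
            }
        }
        offsets = next;
    }
    offsets
}

fn main() {
    // A contiguous row-major 2x3 tensor has strides [3, 1]: offsets in order.
    assert_eq!(strided_offsets(&[2, 3], &[3, 1]), vec![0, 1, 2, 3, 4, 5]);
    // Its transpose views the same buffer as 3x2 with strides [1, 3]:
    // same offsets, visited in a different (logical) order.
    assert_eq!(strided_offsets(&[3, 2], &[1, 3]), vec![0, 3, 1, 4, 2, 5]);
}

A contiguous tensor visits its buffer in order; a view of the same buffer visits the same offsets interleaved, which is exactly the case the old contiguous-only loop got wrong.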
@@ -350,7 +350,8 @@ impl Llama {
     }

     fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
-        let (_, t) = x.shape().r2()?;
+        // TODO: Support for mini-batches? (i.e. r2)
+        let t = x.shape().r1()?;
         let mut x = self.wte.forward(x)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, freqs_cis)?;
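r1 and r2 above are candle's rank accessors: r2()? destructures a rank-2 shape into its two dims, r1()? a rank-1 shape into its single dim. A minimal stand-in, under the assumption that they fail unless the rank matches exactly (hypothetical Shape, not the real type):

// Hypothetical stand-ins for the r1/r2 rank accessors used in the hunk
// above; the assumption is that they error unless the rank matches.
struct Shape(Vec<usize>);

impl Shape {
    fn r1(&self) -> Result<usize, String> {
        match self.0.as_slice() {
            &[n] => Ok(n),
            dims => Err(format!("expected rank 1, got {dims:?}")),
        }
    }

    fn r2(&self) -> Result<(usize, usize), String> {
        match self.0.as_slice() {
            &[a, b] => Ok((a, b)),
            dims => Err(format!("expected rank 2, got {dims:?}")),
        }
    }
}

fn main() {
    assert_eq!(Shape(vec![5]).r1().unwrap(), 5);
    assert!(Shape(vec![1, 5]).r1().is_err()); // a rank-2 shape fails r1
}

So after this hunk forward() takes a rank-1 (seq_len,) id tensor, which is why the next hunk drops the reshape((1, ctxt.len())) in main().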
@@ -427,7 +428,7 @@ fn main() -> Result<()> {
     let mut rng = thread_rng();
     for index in 0..args.sample_len {
         let ctxt = &tokens[tokens.len().saturating_sub(CONTEXT_SIZE)..];
-        let input = Tensor::new(ctxt, &Device::Cpu)?.reshape((1, ctxt.len()))?;
+        let input = Tensor::new(ctxt, &Device::Cpu)?;
         let logits = llama.forward(&input, &freqs_cis)?;
         let prs = (&logits / args.temperature)?.softmax(logits.rank() - 1)?;
         let logits_v: Vec<f32> = prs.to_vec1()?;
@@ -88,6 +88,30 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     }
 }

+fn take<T: Copy>(
+    ids: &[u32],
+    shape: &Shape,
+    stride: &[usize],
+    vs: &[T],
+    vocab_size: usize,
+    hidden_size: usize,
+) -> Result<Vec<T>> {
+    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
+    for index in StridedIndex::new(shape.dims(), stride) {
+        let index = ids[index].try_into()?;
+        if index >= vocab_size {
+            return Err(Error::InvalidIndex {
+                index,
+                vocab_size,
+                op: "take",
+            });
+        } else {
+            values.extend(&vs[hidden_size * index..hidden_size * (index + 1)]);
+        }
+    }
+    Ok(values)
+}
+
 fn copy_strided_src_<T: Copy + std::fmt::Display>(
     src: &[T],
     dst: &mut [T],
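The new take above is the body that the per-dtype match arms in the next hunk share: for each id, visited in strided order, it bounds-checks against vocab_size and copies one hidden_size-long row out of the flat weight table. A contiguous-only analogue with a worked example (gather_rows is a hypothetical helper, not the real code):

// Hypothetical, contiguous-only analogue of `take`: gather one
// hidden_size-long row of the flat [vocab_size, hidden_size] table per
// id, bounds-checking each id first.
fn gather_rows(
    ids: &[u32],
    table: &[f32],
    vocab_size: usize,
    hidden_size: usize,
) -> Result<Vec<f32>, String> {
    let mut out = Vec::with_capacity(ids.len() * hidden_size);
    for &id in ids {
        let id = id as usize;
        if id >= vocab_size {
            return Err(format!("id {id} is out of range for vocab size {vocab_size}"));
        }
        out.extend_from_slice(&table[hidden_size * id..hidden_size * (id + 1)]);
    }
    Ok(out)
}

fn main() {
    // Vocab of 3 tokens, hidden size 2: row i of the table is token i's vector.
    let table = [0.0f32, 0.1, 1.0, 1.1, 2.0, 2.1];
    let out = gather_rows(&[2, 0, 2], &table, 3, 2).unwrap();
    assert_eq!(out, vec![2.0, 2.1, 0.0, 0.1, 2.0, 2.1]);
}

The real take differs only in the loop driver: StridedIndex::new(shape.dims(), stride) supplies the buffer positions of the ids, so non-contiguous id views are gathered in their logical order.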
@@ -380,52 +404,30 @@ impl CpuStorage {

     pub(crate) fn embedding_impl(
         &self,
-        rhs: &Self,
+        shape: &Shape,
+        stride: &[usize],
+        vs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         match self {
-            CpuStorage::U32(lhs) => match rhs {
-                CpuStorage::F32(rhs) => {
-                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
-                    for &index in lhs {
-                        let index: usize = index.try_into()?;
-                        if index >= vocab_size {
-                            return Err(Error::InvalidIndex {
-                                index,
-                                vocab_size,
-                                op: "embedding",
-                            });
-                        } else {
-                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
-                        }
-                    }
-                    Ok(CpuStorage::F32(weights))
+            CpuStorage::U32(ids) => match vs {
+                CpuStorage::F32(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::F32(storage))
                 }
-                CpuStorage::F64(rhs) => {
-                    let mut weights = Vec::with_capacity(lhs.len() * hidden_size);
-                    for &index in lhs {
-                        let index: usize = index.try_into()?;
-                        if index >= vocab_size {
-                            return Err(Error::InvalidIndex {
-                                index,
-                                vocab_size,
-                                op: "embedding",
-                            });
-                        } else {
-                            weights.extend(&rhs[hidden_size * index..hidden_size * (index + 1)]);
-                        }
-                    }
-                    Ok(CpuStorage::F64(weights))
+                CpuStorage::F64(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::F64(storage))
                 }
-                rhs => Err(Error::UnexpectedDType {
-                    expected: DType::F32,
-                    got: rhs.dtype(),
-                }),
+                CpuStorage::U32(vs) => {
+                    let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
+                    Ok(CpuStorage::U32(storage))
+                }
             },
-            lhs => Err(Error::UnexpectedDType {
+            ids => Err(Error::UnexpectedDType {
                 expected: DType::U32,
-                got: lhs.dtype(),
+                got: ids.dtype(),
             }),
         }
     }
@@ -413,6 +413,8 @@ impl CudaStorage {

     pub(crate) fn embedding_impl(
         &self,
+        _shape: &Shape,
+        _stride: &[usize],
         _rhs: &Self,
         _hidden_size: usize,
         _vocab_size: usize,
@@ -90,7 +90,14 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    pub(crate) fn embedding_impl(&self, _: &Self, _: usize, _: usize) -> Result<Self> {
+    pub(crate) fn embedding_impl(
+        &self,
+        _: &Shape,
+        _: &[usize],
+        _: &Self,
+        _: usize,
+        _: usize,
+    ) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -156,19 +156,20 @@ impl Storage {

     pub(crate) fn embedding_impl(
         &self,
+        shape: &Shape,
+        stride: &[usize],
         rhs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         self.same_device(rhs, "embedding")?;
-        self.same_dtype(rhs, "embedding")?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding_impl(rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
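Note the dropped same_dtype check: unlike a binary op, an embedding lookup pairs U32 ids with a weight table of a different dtype (F32, F64, or U32), so requiring equal dtypes would reject every valid float lookup; the valid pairings are enforced by the nested match in CpuStorage::embedding_impl instead. A tiny sketch of the constraint (hypothetical enum and embedding_dtypes_ok helper, mirroring that match):

#[derive(Clone, Copy, PartialEq, Debug)]
enum DType { U32, F32, F64 }

// Only the ids' dtype is constrained to U32; the table's dtype merely
// selects which CpuStorage arm performs the gather (F32, F64, or U32).
fn embedding_dtypes_ok(ids: DType, _table: DType) -> bool {
    ids == DType::U32
}

fn main() {
    assert!(embedding_dtypes_ok(DType::U32, DType::F32)); // the normal lookup
    assert!(!embedding_dtypes_ok(DType::F32, DType::F32)); // ids must be U32
}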
@@ -435,11 +435,16 @@ impl Tensor {
                 op: "embedding",
             });
         }
-        let seq_len = ids.shape().r1()?;
+        let ids_shape = ids.shape();
+        let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids
-            .storage
-            .embedding_impl(&rhs.storage, hidden_size, vocab_size)?;
+        let storage = ids.storage.embedding_impl(
+            ids_shape,
+            &ids.stride,
+            &rhs.storage,
+            hidden_size,
+            vocab_size,
+        )?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
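With the ids' shape and stride threaded all the way down, the tensor-level lookup works on id tensors that are views over a buffer rather than owners of a contiguous one. A self-contained illustration of the failure mode this fixes (plain Rust, no candle types):

fn main() {
    // Logical 2x3 id matrix [[0, 1, 2], [3, 4, 5]] stored row-major.
    let buf = [0u32, 1, 2, 3, 4, 5];
    // Its transpose is a 3x2 view over the *same* buffer with strides
    // [1, 3]: element (i, j) lives at flat offset i * 1 + j * 3.
    let (dims, stride) = ([3usize, 2], [1usize, 3]);
    let mut logical = Vec::new();
    for i in 0..dims[0] {
        for j in 0..dims[1] {
            logical.push(buf[i * stride[0] + j * stride[1]]);
        }
    }
    // Strided traversal yields the view's ids in logical order...
    assert_eq!(logical, vec![0, 3, 1, 4, 2, 5]);
    // ...while reading the buffer front to back would hand the embedding
    // lookup [0, 1, 2, 3, 4, 5] -- the wrong tokens in the wrong order.
    assert_ne!(logical, buf.to_vec());
}

Before this commit the CPU kernel iterated the backing buffer directly, so a non-contiguous ids view (for instance one produced by an op that only adjusts strides) would have been gathered in storage order rather than logical order.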