mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add where_cond and properly apply the causal mask.
This commit is contained in:
@ -13,6 +13,36 @@ pub enum CpuStorage {
|
||||
F64(Vec<f64>),
|
||||
}
|
||||
|
||||
fn wcond<T: Copy>(
|
||||
pred: &[u32],
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
t: &[T],
|
||||
stride_t: &[usize],
|
||||
f: &[T],
|
||||
stride_f: &[usize],
|
||||
) -> Vec<T> {
|
||||
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
|
||||
{
|
||||
let elem_count = shape.elem_count();
|
||||
let pred = &pred[..elem_count];
|
||||
let t = &t[..elem_count];
|
||||
let f = &f[..elem_count];
|
||||
pred.iter()
|
||||
.zip(t.iter().zip(f.iter()))
|
||||
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
let dims = shape.dims();
|
||||
let it_p = StridedIndex::new(dims, stride);
|
||||
let it_t = StridedIndex::new(dims, stride_t);
|
||||
let it_f = StridedIndex::new(dims, stride_f);
|
||||
it_p.zip(it_t.zip(it_f))
|
||||
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
}
|
||||
|
||||
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
@ -402,6 +432,38 @@ impl CpuStorage {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
) -> Result<Self> {
|
||||
// TODO: Support types that could be casted to a boolean.
|
||||
let pred = self.as_slice::<u32>()?;
|
||||
match (t, f) {
|
||||
(Self::F32(t), Self::F32(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::F32(data))
|
||||
}
|
||||
(Self::F64(t), Self::F64(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::F64(data))
|
||||
}
|
||||
(Self::U32(t), Self::U32(f)) => {
|
||||
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
|
||||
Ok(Self::U32(data))
|
||||
}
|
||||
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: t.dtype(),
|
||||
rhs: f.dtype(),
|
||||
op: "where_cond",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
@ -410,25 +472,20 @@ impl CpuStorage {
|
||||
hidden_size: usize,
|
||||
vocab_size: usize,
|
||||
) -> Result<Self> {
|
||||
match self {
|
||||
CpuStorage::U32(ids) => match vs {
|
||||
CpuStorage::F32(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::F32(storage))
|
||||
}
|
||||
CpuStorage::F64(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::F64(storage))
|
||||
}
|
||||
CpuStorage::U32(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::U32(storage))
|
||||
}
|
||||
},
|
||||
ids => Err(Error::UnexpectedDType {
|
||||
expected: DType::U32,
|
||||
got: ids.dtype(),
|
||||
}),
|
||||
let ids = self.as_slice::<u32>()?;
|
||||
match vs {
|
||||
CpuStorage::F32(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::F32(storage))
|
||||
}
|
||||
CpuStorage::F64(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::F64(storage))
|
||||
}
|
||||
CpuStorage::U32(vs) => {
|
||||
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
|
||||
Ok(CpuStorage::U32(storage))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user