Add where_cond and properly apply the causal mask.

This commit is contained in:
laurent
2023-06-25 21:08:03 +01:00
parent 25bcad290e
commit 117f014b55
8 changed files with 168 additions and 24 deletions

View File

@ -13,6 +13,36 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
fn wcond<T: Copy>(
pred: &[u32],
shape: &Shape,
stride: &[usize],
t: &[T],
stride_t: &[usize],
f: &[T],
stride_f: &[usize],
) -> Vec<T> {
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
{
let elem_count = shape.elem_count();
let pred = &pred[..elem_count];
let t = &t[..elem_count];
let f = &f[..elem_count];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
.collect::<Vec<_>>()
} else {
let dims = shape.dims();
let it_p = StridedIndex::new(dims, stride);
let it_t = StridedIndex::new(dims, stride_t);
let it_f = StridedIndex::new(dims, stride_f);
it_p.zip(it_t.zip(it_f))
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
.collect::<Vec<_>>()
}
}
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
shape: &Shape,
stride: &[usize],
@ -402,6 +432,38 @@ impl CpuStorage {
Ok(())
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
// TODO: Support types that could be casted to a boolean.
let pred = self.as_slice::<u32>()?;
match (t, f) {
(Self::F32(t), Self::F32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F32(data))
}
(Self::F64(t), Self::F64(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F64(data))
}
(Self::U32(t), Self::U32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::U32(data))
}
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: t.dtype(),
rhs: f.dtype(),
op: "where_cond",
}),
}
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
@ -410,25 +472,20 @@ impl CpuStorage {
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
match self {
CpuStorage::U32(ids) => match vs {
CpuStorage::F32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F32(storage))
}
CpuStorage::F64(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F64(storage))
}
CpuStorage::U32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::U32(storage))
}
},
ids => Err(Error::UnexpectedDType {
expected: DType::U32,
got: ids.dtype(),
}),
let ids = self.as_slice::<u32>()?;
match vs {
CpuStorage::F32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F32(storage))
}
CpuStorage::F64(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F64(storage))
}
CpuStorage::U32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::U32(storage))
}
}
}