Add where_cond and properly apply the causal mask.

laurent
2023-06-25 21:08:03 +01:00
parent 25bcad290e
commit 117f014b55
8 changed files with 168 additions and 24 deletions
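In short: this commit adds an elementwise select op, `pred.where_cond(&on_true, &on_false)`, which picks `on_true` wherever the `u32` predicate is nonzero and `on_false` elsewhere, then uses it to actually apply the causal attention mask (previously a no-op pending this op). A minimal usage sketch, assuming the API introduced below (CPU only, since the CUDA kernel is still a TODO):

    let cond = Tensor::new(&[1u32, 0, 1, 0], &Device::Cpu)?;
    let on_true = Tensor::new(&[1f32, 1., 1., 1.], &Device::Cpu)?;
    let on_false = Tensor::new(&[0f32, 0., 0., 0.], &Device::Cpu)?;
    let picked = cond.where_cond(&on_true, &on_false)?; // -> [1., 0., 1., 0.]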


@@ -220,10 +220,8 @@ impl Mlp {
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let _on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
// TODO: add an equivalent to where (or xla's select) so that we can use the following:
// let m = mask.where_cond(&on_true, on_false)?;
let m = on_false.clone();
let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
@@ -297,7 +295,7 @@ impl CausalSelfAttention {
//let mask = Tensor::new(1u32, &device)?
// .broadcast_as(&[t, t])?
// .lower_triangle()?
let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = att.softmax(att.rank() - 1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
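For context, the `mask` slice passed to `Tensor::from_slice` above is a 0/1 upper-triangular pattern: 1 for future positions (column > row), which `masked_fill` then sends to `f32::NEG_INFINITY` before the softmax. One plausible construction (a sketch, not shown in this hunk):

    // 1 above the diagonal (future positions), 0 on and below it.
    let mask: Vec<u32> = (0..t)
        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
        .collect();

The switch from `reshape(&[1, 1, t, t])` to `broadcast_as(att.shape())` expands the (t, t) mask to the full attention shape, since `where_cond` requires its three operands to share a shape.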


@@ -24,6 +24,15 @@ impl Tensor {
nodes
} else if let Some(op) = node.op() {
match op {
Op::WhereCond(t1, t2, t3) => {
let (tg, nodes) = walk(t1, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t2, nodes, already_seen);
track_grad |= tg;
let (tg, nodes) = walk(t3, nodes, already_seen);
track_grad |= tg;
nodes
}
Op::Add(lhs, rhs)
| Op::Mul(lhs, rhs)
| Op::Sub(lhs, rhs)
@@ -161,6 +170,9 @@ impl Tensor {
let rhs_sum_grad = grads.or_insert(rhs)?;
*rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
}
Op::WhereCond(_pred, _t, _f) => {
return Err(Error::BackwardNotSupported { op: "where_cond" })
}
Op::Embedding(_lhs, _rhs) => {
return Err(Error::BackwardNotSupported { op: "embedding" })
}
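The backward pass for the new op is deliberately left as `BackwardNotSupported` for now. A natural rule, once implemented, is to route the incoming gradient to whichever branch each element selected, with no gradient for the integer predicate. A hypothetical sketch of that arm (not part of this commit; `zeros_like` is an assumed helper):

    Op::WhereCond(pred, t, f) => {
        // Gradient flows to `t` where pred != 0 and to `f` elsewhere.
        let zeros = grad.zeros_like()?;
        let t_grad = pred.where_cond(&grad, &zeros)?;
        let t_sum_grad = grads.or_insert(t)?;
        *t_sum_grad = t_sum_grad.add(&t_grad)?;
        let f_grad = pred.where_cond(&zeros, &grad)?;
        let f_sum_grad = grads.or_insert(f)?;
        *f_sum_grad = f_sum_grad.add(&f_grad)?;
    }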


@@ -13,6 +13,36 @@ pub enum CpuStorage {
F64(Vec<f64>),
}
fn wcond<T: Copy>(
pred: &[u32],
shape: &Shape,
stride: &[usize],
t: &[T],
stride_t: &[usize],
f: &[T],
stride_f: &[usize],
) -> Vec<T> {
if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
{
let elem_count = shape.elem_count();
let pred = &pred[..elem_count];
let t = &t[..elem_count];
let f = &f[..elem_count];
pred.iter()
.zip(t.iter().zip(f.iter()))
.map(|(&p, (&t, &f))| if p > 0 { t } else { f })
.collect::<Vec<_>>()
} else {
let dims = shape.dims();
let it_p = StridedIndex::new(dims, stride);
let it_t = StridedIndex::new(dims, stride_t);
let it_f = StridedIndex::new(dims, stride_f);
it_p.zip(it_t.zip(it_f))
.map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
.collect::<Vec<_>>()
}
}
fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
shape: &Shape,
stride: &[usize],
@@ -402,6 +432,38 @@ impl CpuStorage {
Ok(())
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
// TODO: Support types that could be cast to a boolean.
let pred = self.as_slice::<u32>()?;
match (t, f) {
(Self::F32(t), Self::F32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F32(data))
}
(Self::F64(t), Self::F64(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::F64(data))
}
(Self::U32(t), Self::U32(f)) => {
let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
Ok(Self::U32(data))
}
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: t.dtype(),
rhs: f.dtype(),
op: "where_cond",
}),
}
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
@@ -410,8 +472,8 @@ impl CpuStorage {
hidden_size: usize,
vocab_size: usize,
) -> Result<Self> {
match self {
CpuStorage::U32(ids) => match vs {
let ids = self.as_slice::<u32>()?;
match vs {
CpuStorage::F32(vs) => {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::F32(storage))
@@ -424,11 +486,6 @@ impl CpuStorage {
let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
Ok(CpuStorage::U32(storage))
}
},
ids => Err(Error::UnexpectedDType {
expected: DType::U32,
got: ids.dtype(),
}),
}
}
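The CPU kernel `wcond` has two paths: a fast path that zips the flat buffers when all three layouts are contiguous, and a strided path that walks the same logic through `StridedIndex`. Its semantics in plain, self-contained Rust:

    // out[i] = if pred[i] > 0 { t[i] } else { f[i] }
    let pred = [0u32, 1, 0, 1];
    let t = [1.0f32, 2.0, 3.0, 4.0];
    let f = [9.0f32, 9.0, 9.0, 9.0];
    let out: Vec<f32> = pred
        .iter()
        .zip(t.iter().zip(f.iter()))
        .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
        .collect();
    assert_eq!(out, vec![9.0, 2.0, 9.0, 4.0]);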


@@ -411,6 +411,18 @@ impl CudaStorage {
}
}
pub(crate) fn where_cond(
&self,
_shape: &Shape,
_stride: &[usize],
_t: &Self,
_stride_t: &[usize],
_f: &Self,
_stride_f: &[usize],
) -> Result<Self> {
Err(CudaError::InternalError("TODO: implement where_cond"))
}
pub(crate) fn embedding_impl(
&self,
_shape: &Shape,


@@ -90,6 +90,18 @@ impl CudaStorage {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn where_cond(
&self,
_: &Shape,
_: &[usize],
_: &Self,
_: &[usize],
_: &Self,
_: &[usize],
) -> Result<Self> {
Err(Error::NotCompiledWithCudaSupport)
}
pub(crate) fn embedding_impl(
&self,
_: &Shape,


@@ -12,6 +12,7 @@ pub(crate) enum Op {
BroadcastDiv(Tensor, Tensor),
Matmul(Tensor, Tensor),
Embedding(Tensor, Tensor),
WhereCond(Tensor, Tensor, Tensor),
Cat(Vec<Tensor>, usize),


@@ -154,6 +154,35 @@ impl Storage {
}
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
self.same_device(t, "where")?;
self.same_device(f, "where")?;
t.same_dtype(f, "where")?;
match (self, t, f) {
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
Ok(Self::Cuda(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "embedding",
}),
}
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,
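The guards above reject mixed devices and mixed branch dtypes before dispatching to a backend. For instance (a hedged sketch at the public API level):

    let cond = Tensor::new(&[1u32, 0], &Device::Cpu)?;
    let a = Tensor::new(&[1f32, 2.], &Device::Cpu)?;
    let b = Tensor::new(&[3f64, 4.], &Device::Cpu)?;
    assert!(cond.where_cond(&a, &b).is_err()); // f32 vs f64 branches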


@@ -425,6 +425,29 @@ impl Tensor {
Ok(from_storage(storage, c_shape, op, false))
}
pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
let _shape = self.same_shape_binary_op(on_true, "where_cond")?;
let shape = self.same_shape_binary_op(on_false, "where_cond")?;
let storage = self.storage.where_cond(
shape,
self.stride(),
&on_true.storage,
on_true.stride(),
&on_false.storage,
on_false.stride(),
)?;
let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
Some(Op::WhereCond(
self.clone(),
on_true.clone(),
on_false.clone(),
))
} else {
None
};
Ok(from_storage(storage, shape, op, false))
}
pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
if !rhs.is_contiguous() {
return Err(Error::RequiresContiguous { op: "embedding" });