Add where_cond and properly apply the causal mask.

This commit is contained in:
laurent
2023-06-25 21:08:03 +01:00
parent 25bcad290e
commit 117f014b55
8 changed files with 168 additions and 24 deletions

View File

@ -154,6 +154,35 @@ impl Storage {
}
}
pub(crate) fn where_cond(
&self,
shape: &Shape,
stride: &[usize],
t: &Self,
stride_t: &[usize],
f: &Self,
stride_f: &[usize],
) -> Result<Self> {
self.same_device(t, "where")?;
self.same_device(f, "where")?;
t.same_dtype(f, "where")?;
match (self, t, f) {
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
Ok(Self::Cpu(storage))
}
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
Ok(Self::Cuda(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
op: "embedding",
}),
}
}
pub(crate) fn embedding_impl(
&self,
shape: &Shape,