mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add where_cond and properly apply the causal mask.
This commit is contained in:
@ -154,6 +154,35 @@ impl Storage {
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn where_cond(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
stride: &[usize],
|
||||
t: &Self,
|
||||
stride_t: &[usize],
|
||||
f: &Self,
|
||||
stride_f: &[usize],
|
||||
) -> Result<Self> {
|
||||
self.same_device(t, "where")?;
|
||||
self.same_device(f, "where")?;
|
||||
t.same_dtype(f, "where")?;
|
||||
match (self, t, f) {
|
||||
(Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
|
||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
||||
Ok(Self::Cpu(storage))
|
||||
}
|
||||
(Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
|
||||
let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
|
||||
Ok(Self::Cuda(storage))
|
||||
}
|
||||
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "embedding",
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embedding_impl(
|
||||
&self,
|
||||
shape: &Shape,
|
||||
|
Reference in New Issue
Block a user