Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Add where_cond and properly apply the causal mask.
@@ -220,10 +220,8 @@ impl Mlp {
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let _on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
-    // TODO: add an equivalent to where (or xla's select) so that we can use the following:
-    // let m = mask.where_cond(&on_true, on_false)?;
-    let m = on_false.clone();
+    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
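The updated masked_fill helper now relies on where_cond directly: positions where the mask is non-zero are replaced by the on_true scalar (negative infinity for the causal mask), and all other positions keep the original attention score. A minimal standalone sketch of that selection rule on plain slices (names are illustrative, not taken from the repository):

// Sketch only: mimics masked_fill on flat slices rather than tensors.
// `mask` follows the same convention as the diff: a non-zero entry selects `fill`.
fn masked_fill_slice(on_false: &[f32], mask: &[u32], fill: f32) -> Vec<f32> {
    on_false
        .iter()
        .zip(mask.iter())
        .map(|(&v, &m)| if m > 0 { fill } else { v })
        .collect()
}

fn main() {
    let scores = [0.5_f32, 1.0, -0.3, 2.0];
    let mask = [0_u32, 1, 0, 1]; // 1 marks positions to blank out
    let filled = masked_fill_slice(&scores, &mask, f32::NEG_INFINITY);
    assert_eq!(filled, [0.5, f32::NEG_INFINITY, -0.3, f32::NEG_INFINITY]);
}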
@@ -297,7 +295,7 @@ impl CausalSelfAttention {
         //let mask = Tensor::new(1u32, &device)?
         //    .broadcast_as(&[t, t])?
         //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
+        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
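For context, the mask being broadcast here is the causal attention mask: a (t, t) grid of u32 values that is non-zero for future positions, so that masked_fill pushes those scores to negative infinity before the softmax. A rough sketch of that pattern, assuming the crate is imported as candle and exposes the Tensor, Device, and Result types used in the diff (the function name and setup are illustrative, not copied from the repository):

use candle::{Device, Result, Tensor};

fn apply_causal_mask(att: &Tensor, t: usize, device: &Device) -> Result<Tensor> {
    // 1 for "future" positions (column > row), 0 elsewhere.
    let mask: Vec<u32> = (0..t)
        .flat_map(|row| (0..t).map(move |col| u32::from(col > row)))
        .collect();
    let mask = Tensor::from_slice(&mask, (t, t), device)?.broadcast_as(att.shape())?;
    // Future positions become -inf, so they vanish after the softmax.
    let on_true = Tensor::new(f32::NEG_INFINITY, device)?.broadcast_as(att.shape())?;
    mask.where_cond(&on_true, att)
}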
@@ -24,6 +24,15 @@ impl Tensor {
                 nodes
             } else if let Some(op) = node.op() {
                 match op {
+                    Op::WhereCond(t1, t2, t3) => {
+                        let (tg, nodes) = walk(t1, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(t2, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(t3, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
                     Op::Add(lhs, rhs)
                     | Op::Mul(lhs, rhs)
                     | Op::Sub(lhs, rhs)
@@ -161,6 +170,9 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
                     }
+                    Op::WhereCond(_pred, _t, _f) => {
+                        return Err(Error::BackwardNotSupported { op: "where_cond" })
+                    }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
                     }
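Backpropagation through where_cond is deliberately left unsupported in this commit. For reference, the usual gradient of a select is simple: the upstream gradient flows to whichever branch was chosen, and the predicate receives no gradient. A small illustrative sketch on flat slices (not code from the repository):

// Sketch: vjp of out = where(pred, t, f) for flat f32 buffers.
// grad_t gets the upstream gradient where pred is non-zero, grad_f gets the rest;
// the predicate itself is non-differentiable.
fn where_cond_backward(pred: &[u32], grad_out: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let grad_t = pred
        .iter()
        .zip(grad_out)
        .map(|(&p, &g)| if p > 0 { g } else { 0.0 })
        .collect();
    let grad_f = pred
        .iter()
        .zip(grad_out)
        .map(|(&p, &g)| if p > 0 { 0.0 } else { g })
        .collect();
    (grad_t, grad_f)
}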
@@ -13,6 +13,36 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+fn wcond<T: Copy>(
+    pred: &[u32],
+    shape: &Shape,
+    stride: &[usize],
+    t: &[T],
+    stride_t: &[usize],
+    f: &[T],
+    stride_f: &[usize],
+) -> Vec<T> {
+    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
+    {
+        let elem_count = shape.elem_count();
+        let pred = &pred[..elem_count];
+        let t = &t[..elem_count];
+        let f = &f[..elem_count];
+        pred.iter()
+            .zip(t.iter().zip(f.iter()))
+            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+            .collect::<Vec<_>>()
+    } else {
+        let dims = shape.dims();
+        let it_p = StridedIndex::new(dims, stride);
+        let it_t = StridedIndex::new(dims, stride_t);
+        let it_f = StridedIndex::new(dims, stride_f);
+        it_p.zip(it_t.zip(it_f))
+            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
+            .collect::<Vec<_>>()
+    }
+}
+
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     shape: &Shape,
     stride: &[usize],
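When any input is not contiguous, the kernel walks it through StridedIndex, which turns a logical element position into an offset into the underlying buffer using the tensor's strides. A minimal sketch of that index computation (helper name and layout conventions are illustrative, not the crate's actual implementation):

// Sketch: map a flat logical index to a storage offset for a row-major view
// described by `dims` and per-dimension `stride` values.
fn storage_offset(mut logical: usize, dims: &[usize], stride: &[usize]) -> usize {
    let mut offset = 0;
    // Walk dimensions from the innermost (last) to the outermost.
    for (&d, &s) in dims.iter().zip(stride.iter()).rev() {
        offset += (logical % d) * s;
        logical /= d;
    }
    offset
}

fn main() {
    // A transposed 2x3 view: dims (2, 3) with strides (1, 2).
    let dims = [2, 3];
    let stride = [1, 2];
    let offsets: Vec<usize> = (0..6).map(|i| storage_offset(i, &dims, &stride)).collect();
    assert_eq!(offsets, [0, 2, 4, 1, 3, 5]);
}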
@@ -402,6 +432,38 @@ impl CpuStorage {
         Ok(())
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
+    ) -> Result<Self> {
+        // TODO: Support types that could be casted to a boolean.
+        let pred = self.as_slice::<u32>()?;
+        match (t, f) {
+            (Self::F32(t), Self::F32(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F32(data))
+            }
+            (Self::F64(t), Self::F64(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F64(data))
+            }
+            (Self::U32(t), Self::U32(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::U32(data))
+            }
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: t.dtype(),
+                rhs: f.dtype(),
+                op: "where_cond",
+            }),
+        }
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         shape: &Shape,
@@ -410,8 +472,8 @@ impl CpuStorage {
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
-        match self {
-            CpuStorage::U32(ids) => match vs {
+        let ids = self.as_slice::<u32>()?;
+        match vs {
             CpuStorage::F32(vs) => {
                 let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
                 Ok(CpuStorage::F32(storage))
@@ -424,11 +486,6 @@ impl CpuStorage {
                 let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
                 Ok(CpuStorage::U32(storage))
             }
-            },
-            ids => Err(Error::UnexpectedDType {
-                expected: DType::U32,
-                got: ids.dtype(),
-            }),
         }
     }
@@ -411,6 +411,18 @@ impl CudaStorage {
         }
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _t: &Self,
+        _stride_t: &[usize],
+        _f: &Self,
+        _stride_f: &[usize],
+    ) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement where_cond"))
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         _shape: &Shape,
@@ -90,6 +90,18 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        _: &Shape,
+        _: &[usize],
+        _: &Self,
+        _: &[usize],
+        _: &Self,
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         _: &Shape,
@@ -12,6 +12,7 @@ pub(crate) enum Op {
     BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
+    WhereCond(Tensor, Tensor, Tensor),
 
     Cat(Vec<Tensor>, usize),
@@ -154,6 +154,35 @@ impl Storage {
         }
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
+    ) -> Result<Self> {
+        self.same_device(t, "where")?;
+        self.same_device(f, "where")?;
+        t.same_dtype(f, "where")?;
+        match (self, t, f) {
+            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
+                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
+                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                Ok(Self::Cuda(storage))
+            }
+            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "embedding",
+            }),
+        }
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         shape: &Shape,
@@ -425,6 +425,29 @@ impl Tensor {
         Ok(from_storage(storage, c_shape, op, false))
     }
 
+    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
+        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
+        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
+        let storage = self.storage.where_cond(
+            shape,
+            self.stride(),
+            &on_true.storage,
+            on_true.stride(),
+            &on_false.storage,
+            on_false.stride(),
+        )?;
+        let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
+            Some(Op::WhereCond(
+                self.clone(),
+                on_true.clone(),
+                on_false.clone(),
+            ))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
             return Err(Error::RequiresContiguous { op: "embedding" });
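The new public Tensor::where_cond therefore takes a u32 predicate tensor as self plus two value tensors of matching shape, returning on_true wherever the predicate is non-zero and on_false elsewhere. A hypothetical usage sketch, again assuming the crate is imported as candle (not an excerpt from the repository's tests):

use candle::{Device, Result, Tensor};

fn where_cond_demo() -> Result<Tensor> {
    let device = Device::Cpu;
    // Predicate is a u32 tensor: non-zero selects from `on_true`.
    let pred_data: Vec<u32> = vec![1, 0, 0, 1];
    let on_true_data: Vec<f32> = vec![1., 2., 3., 4.];
    let on_false_data: Vec<f32> = vec![10., 20., 30., 40.];
    let pred = Tensor::from_slice(&pred_data, (2, 2), &device)?;
    let on_true = Tensor::from_slice(&on_true_data, (2, 2), &device)?;
    let on_false = Tensor::from_slice(&on_false_data, (2, 2), &device)?;
    // Expected contents: [[1., 20.], [30., 4.]]
    pred.where_cond(&on_true, &on_false)
}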