Mirror of https://github.com/huggingface/candle.git
Add where_cond and properly apply the causal mask.
@@ -220,10 +220,8 @@ impl Mlp {
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
     let shape = mask.shape();
-    let _on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
-    // TODO: add an equivalent to where (or xla's select) so that we can use the following:
-    // let m = mask.where_cond(&on_true, on_false)?;
-    let m = on_false.clone();
+    let on_true = Tensor::new(on_true, &on_false.device())?.broadcast_as(shape.dims())?;
+    let m = mask.where_cond(&on_true, on_false)?;
     Ok(m)
 }
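With where_cond in place, masked_fill is a true elementwise select: wherever the mask is non-zero, the broadcast on_true scalar replaces the attention score; elsewhere the score passes through. A minimal scalar sketch of that behaviour (hypothetical helper, not part of the diff):

// Scalar model of masked_fill: masked positions become f32::NEG_INFINITY
// so that a subsequent softmax assigns them zero probability.
fn masked_fill_scalar(scores: &[f32], mask: &[u32], on_true: f32) -> Vec<f32> {
    scores
        .iter()
        .zip(mask)
        .map(|(&s, &m)| if m > 0 { on_true } else { s })
        .collect()
}

fn main() {
    let filled = masked_fill_scalar(&[0.5, 1.2, -0.3], &[0, 1, 1], f32::NEG_INFINITY);
    assert_eq!(filled, vec![0.5, f32::NEG_INFINITY, f32::NEG_INFINITY]);
}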
@@ -297,7 +295,7 @@ impl CausalSelfAttention {
         //let mask = Tensor::new(1u32, &device)?
         //    .broadcast_as(&[t, t])?
         //    .lower_triangle()?
-        let mask = Tensor::from_slice(&mask, (t, t), &device)?.reshape(&[1, 1, t, t])?;
+        let mask = Tensor::from_slice(&mask, (t, t), &device)?.broadcast_as(att.shape())?;
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = att.softmax(att.rank() - 1)?;
         // Convert to contiguous as matmul doesn't support strided vs for now.
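The diff only shows the mask being consumed; the `mask` vector itself is built elsewhere. One plausible construction (an assumption, mirroring the intent of the commented-out lower_triangle call) marks entry (i, j) with 1 whenever j > i, so each query position cannot attend to future positions:

fn causal_mask(t: usize) -> Vec<u32> {
    // Row i: zeros up to and including the diagonal, ones strictly above it.
    (0..t)
        .flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
        .collect()
}

fn main() {
    assert_eq!(causal_mask(3), vec![0, 1, 1, 0, 0, 1, 0, 0, 0]);
}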
@@ -24,6 +24,15 @@ impl Tensor {
                 nodes
             } else if let Some(op) = node.op() {
                 match op {
+                    Op::WhereCond(t1, t2, t3) => {
+                        let (tg, nodes) = walk(t1, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(t2, nodes, already_seen);
+                        track_grad |= tg;
+                        let (tg, nodes) = walk(t3, nodes, already_seen);
+                        track_grad |= tg;
+                        nodes
+                    }
                     Op::Add(lhs, rhs)
                     | Op::Mul(lhs, rhs)
                     | Op::Sub(lhs, rhs)
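WhereCond is the first ternary op in the gradient-graph walk: all three operands are visited in turn, and the node needs gradient tracking if any operand does. A self-contained sketch of that accumulation pattern (simplified stand-in types, not the crate's actual walk):

enum Node {
    Leaf { track_grad: bool },
    WhereCond(Box<Node>, Box<Node>, Box<Node>),
}

// Depth-first visit: OR together the tracking flags of every child.
fn walk(node: &Node) -> bool {
    match node {
        Node::Leaf { track_grad } => *track_grad,
        Node::WhereCond(pred, t, f) => {
            let mut track_grad = false;
            track_grad |= walk(pred);
            track_grad |= walk(t);
            track_grad |= walk(f);
            track_grad
        }
    }
}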
@@ -161,6 +170,9 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
                     }
+                    Op::WhereCond(_pred, _t, _f) => {
+                        return Err(Error::BackwardNotSupported { op: "where_cond" })
+                    }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
                     }
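Backward for where_cond is stubbed out with BackwardNotSupported for now. If it were implemented, the rule is simple (a hedged sketch, not part of this commit): the upstream gradient flows to whichever branch the predicate selected, and the predicate itself receives no gradient.

// Hypothetical scalar backward rule for where_cond.
fn where_cond_backward(pred: &[u32], grad: &[f32]) -> (Vec<f32>, Vec<f32>) {
    let grad_t = pred.iter().zip(grad).map(|(&p, &g)| if p > 0 { g } else { 0.0 }).collect();
    let grad_f = pred.iter().zip(grad).map(|(&p, &g)| if p > 0 { 0.0 } else { g }).collect();
    (grad_t, grad_f)
}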
@@ -13,6 +13,36 @@ pub enum CpuStorage {
     F64(Vec<f64>),
 }
 
+fn wcond<T: Copy>(
+    pred: &[u32],
+    shape: &Shape,
+    stride: &[usize],
+    t: &[T],
+    stride_t: &[usize],
+    f: &[T],
+    stride_f: &[usize],
+) -> Vec<T> {
+    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
+    {
+        let elem_count = shape.elem_count();
+        let pred = &pred[..elem_count];
+        let t = &t[..elem_count];
+        let f = &f[..elem_count];
+        pred.iter()
+            .zip(t.iter().zip(f.iter()))
+            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
+            .collect::<Vec<_>>()
+    } else {
+        let dims = shape.dims();
+        let it_p = StridedIndex::new(dims, stride);
+        let it_t = StridedIndex::new(dims, stride_t);
+        let it_f = StridedIndex::new(dims, stride_f);
+        it_p.zip(it_t.zip(it_f))
+            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
+            .collect::<Vec<_>>()
+    }
+}
+
 fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     shape: &Shape,
     stride: &[usize],
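wcond takes a fast path when the predicate and both branches are contiguous, and otherwise walks each operand through its own StridedIndex so that broadcast or transposed inputs are indexed correctly. A simplified stand-in for that strided traversal (an assumption: StridedIndex yields storage offsets in row-major logical order):

// Enumerate storage offsets for a logical shape under a given stride.
fn offsets(dims: &[usize], stride: &[usize]) -> Vec<usize> {
    let n: usize = dims.iter().product();
    (0..n)
        .map(|mut idx| {
            let mut off = 0;
            for (d, s) in dims.iter().zip(stride).rev() {
                off += (idx % d) * s;
                idx /= d;
            }
            off
        })
        .collect()
}

fn main() {
    // A contiguous (2, 2) tensor vs a transposed view of the same storage.
    assert_eq!(offsets(&[2, 2], &[2, 1]), vec![0, 1, 2, 3]);
    assert_eq!(offsets(&[2, 2], &[1, 2]), vec![0, 2, 1, 3]);
}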
@@ -402,6 +432,38 @@ impl CpuStorage {
         Ok(())
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
+    ) -> Result<Self> {
+        // TODO: Support types that could be casted to a boolean.
+        let pred = self.as_slice::<u32>()?;
+        match (t, f) {
+            (Self::F32(t), Self::F32(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F32(data))
+            }
+            (Self::F64(t), Self::F64(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::F64(data))
+            }
+            (Self::U32(t), Self::U32(f)) => {
+                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                Ok(Self::U32(data))
+            }
+            _ => Err(Error::DTypeMismatchBinaryOp {
+                lhs: t.dtype(),
+                rhs: f.dtype(),
+                op: "where_cond",
+            }),
+        }
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         shape: &Shape,
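The CPU implementation requires a u32 predicate (as_slice::<u32> fails otherwise) and dispatches on the branch dtypes, which must match. The pattern in miniature (stand-in enum, not the crate's CpuStorage):

#[derive(Debug, PartialEq)]
enum Cpu {
    F32(Vec<f32>),
    F64(Vec<f64>),
}

fn where_cond(pred: &[u32], t: &Cpu, f: &Cpu) -> Result<Cpu, String> {
    // Shared elementwise select; the match below only picks the variant.
    fn pick<T: Copy>(pred: &[u32], t: &[T], f: &[T]) -> Vec<T> {
        pred.iter()
            .zip(t.iter().zip(f))
            .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
            .collect()
    }
    match (t, f) {
        (Cpu::F32(t), Cpu::F32(f)) => Ok(Cpu::F32(pick(pred, t, f))),
        (Cpu::F64(t), Cpu::F64(f)) => Ok(Cpu::F64(pick(pred, t, f))),
        _ => Err("where_cond: dtype mismatch".to_string()),
    }
}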
@@ -410,8 +472,8 @@ impl CpuStorage {
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
-        match self {
-            CpuStorage::U32(ids) => match vs {
+        let ids = self.as_slice::<u32>()?;
+        match vs {
             CpuStorage::F32(vs) => {
                 let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
                 Ok(CpuStorage::F32(storage))
@@ -424,11 +486,6 @@ impl CpuStorage {
                 let storage = take(ids, shape, stride, vs, vocab_size, hidden_size)?;
                 Ok(CpuStorage::U32(storage))
             }
-            },
-            ids => Err(Error::UnexpectedDType {
-                expected: DType::U32,
-                got: ids.dtype(),
-            }),
         }
     }
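The embedding_impl refactor drops the inline UnexpectedDType arm in favour of the same as_slice::<u32> helper used by where_cond, so the ids extraction and its dtype error live in one place. A sketch of what such a helper presumably does (an assumption; the helper's body is not shown in this diff):

enum Cpu {
    U32(Vec<u32>),
    F32(Vec<f32>),
}

impl Cpu {
    // Return the underlying u32 buffer, or a dtype error.
    fn as_u32_slice(&self) -> Result<&[u32], String> {
        match self {
            Cpu::U32(v) => Ok(v),
            _ => Err("unexpected dtype, expected u32".to_string()),
        }
    }
}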
@@ -411,6 +411,18 @@ impl CudaStorage {
         }
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        _shape: &Shape,
+        _stride: &[usize],
+        _t: &Self,
+        _stride_t: &[usize],
+        _f: &Self,
+        _stride_f: &[usize],
+    ) -> Result<Self> {
+        Err(CudaError::InternalError("TODO: implement where_cond"))
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         _shape: &Shape,
@@ -90,6 +90,18 @@ impl CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        _: &Shape,
+        _: &[usize],
+        _: &Self,
+        _: &[usize],
+        _: &Self,
+        _: &[usize],
+    ) -> Result<Self> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         _: &Shape,
@@ -12,6 +12,7 @@ pub(crate) enum Op {
     BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
+    WhereCond(Tensor, Tensor, Tensor),
 
     Cat(Vec<Tensor>, usize),
@@ -154,6 +154,35 @@ impl Storage {
         }
     }
 
+    pub(crate) fn where_cond(
+        &self,
+        shape: &Shape,
+        stride: &[usize],
+        t: &Self,
+        stride_t: &[usize],
+        f: &Self,
+        stride_f: &[usize],
+    ) -> Result<Self> {
+        self.same_device(t, "where")?;
+        self.same_device(f, "where")?;
+        t.same_dtype(f, "where")?;
+        match (self, t, f) {
+            (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
+                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                Ok(Self::Cpu(storage))
+            }
+            (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
+                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                Ok(Self::Cuda(storage))
+            }
+            (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
+                lhs: lhs.device().location(),
+                rhs: rhs.device().location(),
+                op: "where_cond",
+            }),
+        }
+    }
+
     pub(crate) fn embedding_impl(
         &self,
         shape: &Shape,
@@ -425,6 +425,29 @@ impl Tensor {
         Ok(from_storage(storage, c_shape, op, false))
     }
 
+    pub fn where_cond(&self, on_true: &Self, on_false: &Self) -> Result<Self> {
+        let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
+        let shape = self.same_shape_binary_op(on_false, "where_cond")?;
+        let storage = self.storage.where_cond(
+            shape,
+            self.stride(),
+            &on_true.storage,
+            on_true.stride(),
+            &on_false.storage,
+            on_false.stride(),
+        )?;
+        let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
+            Some(Op::WhereCond(
+                self.clone(),
+                on_true.clone(),
+                on_false.clone(),
+            ))
+        } else {
+            None
+        };
+        Ok(from_storage(storage, shape, op, false))
+    }
+
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
             return Err(Error::RequiresContiguous { op: "embedding" });
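Putting it together, a hedged usage sketch of the new public API (CPU device; constructor and accessor names such as Tensor::new and to_vec1 are assumptions about the crate at this vintage):

use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let dev = Device::Cpu;
    let cond = Tensor::new(&[1u32, 0, 1, 0], &dev)?;
    let on_true = Tensor::new(&[1f32, 2., 3., 4.], &dev)?;
    let on_false = Tensor::new(&[9f32, 9., 9., 9.], &dev)?;
    // Non-zero predicate entries select from on_true, zeros from on_false.
    let picked = cond.where_cond(&on_true, &on_false)?;
    assert_eq!(picked.to_vec1::<f32>()?, vec![1., 9., 3., 9.]);
    Ok(())
}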