Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Fuse the rel-pos additions via a custom-op. (#786)
* Fuse the rel-pos additions via a custom-op.
* Run with rayon.
* Add more tracing.
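For readers new to custom ops in candle, here is a minimal sketch of the pattern this commit relies on: implement `candle::CustomOp3` (a `name` plus a `cpu_fwd` that receives three storages and layouts and returns an owned storage and output shape), then invoke it with `apply_op3_no_bwd`. Only the trait and method names are taken from the diff below; the `FusedAdd` op, the toy shapes, and the `main` wrapper are hypothetical, and the sketch skips the layout/contiguity checks that the real `Add3` performs.

// Illustrative sketch only: a toy fused a + b + c op mirroring the CustomOp3
// surface used by Add3 in this commit. FusedAdd and the shapes are made up.
use candle::{CpuStorage, CustomOp3, DType, Device, Layout, Result, Shape, Tensor};

struct FusedAdd;

impl CustomOp3 for FusedAdd {
    fn name(&self) -> &'static str {
        "fused-add"
    }

    // Same signature as Add3::cpu_fwd below: three storages + layouts in,
    // one owned storage + output shape out.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        _l2: &Layout,
        s3: &CpuStorage,
        _l3: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        // Assumes contiguous, same-shape f32 inputs; the real Add3 checks this
        // via Layout::contiguous_offsets before slicing.
        let a = s1.as_slice::<f32>()?;
        let b = s2.as_slice::<f32>()?;
        let c = s3.as_slice::<f32>()?;
        let dst: Vec<f32> = a
            .iter()
            .zip(b.iter())
            .zip(c.iter())
            .map(|((x, y), z)| x + y + z)
            .collect();
        let dst = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((dst, l1.shape().clone()))
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let a = Tensor::ones((2, 3), DType::F32, &dev)?;
    let b = Tensor::ones((2, 3), DType::F32, &dev)?;
    let c = Tensor::ones((2, 3), DType::F32, &dev)?;
    // apply_op3_no_bwd runs the op without registering a backward pass, which is
    // how Add3 is invoked on the CPU path in the attention code below.
    let out = a.apply_op3_no_bwd(&b, &c, &FusedAdd)?;
    println!("{out}");
    Ok(())
}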
@@ -34,6 +34,70 @@ impl Module for PatchEmbed {
     }
 }
 
+// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
+// addition in the case where b = 12, q_h = q_w = 4096, k_h = k_w = 4096
+// (attn.reshape((b, q_h, q_w, k_h, k_w))?
+//     + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+// .reshape((b, q_h * q_w, k_h * k_w))
+// Ideally we would perform this operation in place but this is not supported in candle at the
+// moment. We should also investigate using f16 rather than f32.
+struct Add3(usize, usize, usize, usize, usize);
+impl candle::CustomOp3 for Add3 {
+    fn name(&self) -> &'static str {
+        "add3"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &candle::CpuStorage,
+        l1: &candle::Layout,
+        s2: &candle::CpuStorage,
+        l2: &candle::Layout,
+        s3: &candle::CpuStorage,
+        l3: &candle::Layout,
+    ) -> Result<(candle::CpuStorage, candle::Shape)> {
+        use rayon::prelude::*;
+
+        let Add3(b, q_h, q_w, k_h, k_w) = *self;
+        let s1 = s1.as_slice::<f32>()?;
+        let s1 = match l1.contiguous_offsets() {
+            None => candle::bail!("input1 has to be contiguous"),
+            Some((o1, o2)) => &s1[o1..o2],
+        };
+        let s2 = s2.as_slice::<f32>()?;
+        let s2 = match l2.contiguous_offsets() {
+            None => candle::bail!("input2 has to be contiguous"),
+            Some((o1, o2)) => &s2[o1..o2],
+        };
+        let s3 = s3.as_slice::<f32>()?;
+        let s3 = match l3.contiguous_offsets() {
+            None => candle::bail!("input3 has to be contiguous"),
+            Some((o1, o2)) => &s3[o1..o2],
+        };
+        let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
+        dst.par_chunks_exact_mut(k_h * k_w)
+            .enumerate()
+            .for_each(|(b_idx, dst)| {
+                let s1_idx = b_idx * k_h * k_w;
+                let s2_idx = b_idx * k_h;
+                let s3_idx = b_idx * k_w;
+                for h_idx in 0..k_h {
+                    let s1_idx = s1_idx + h_idx * k_w;
+                    let s2_idx = s2_idx + h_idx;
+                    let dst_idx = h_idx * k_w;
+                    for w_idx in 0..k_w {
+                        let s1_idx = s1_idx + w_idx;
+                        let s3_idx = s3_idx + w_idx;
+                        let dst_idx = dst_idx + w_idx;
+                        dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
+                    }
+                }
+            });
+        let dst = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
+    }
+}
+
 #[derive(Debug)]
 struct Attention {
     qkv: crate::Linear,
@@ -42,6 +106,7 @@ struct Attention {
     scale: f64,
     rel_pos_hw: Option<(Tensor, Tensor)>,
     span: tracing::Span,
+    span_matmul: tracing::Span,
     span_rel_pos: tracing::Span,
     span_softmax: tracing::Span,
 }
@@ -56,6 +121,7 @@ impl Attention {
         vb: VarBuilder,
     ) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "attention");
+        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
         let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
         let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
         let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -76,6 +142,7 @@ impl Attention {
             scale,
             rel_pos_hw,
             span,
+            span_matmul,
             span_rel_pos,
             span_softmax,
         })
@@ -101,10 +168,16 @@ impl Attention {
                     .transpose(1, 2)? // -> bwhc
                     .contiguous()?
                     .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
-                    .transpose(1, 2)?;
-                (attn.reshape((b, q_h, q_w, k_h, k_w))?
-                    + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
-                .reshape((b, q_h * q_w, k_h * k_w))
+                    .transpose(1, 2)?
+                    .contiguous()?;
+                if attn.device().is_cpu() {
+                    let op = Add3(b, q_h, q_w, k_h, k_w);
+                    attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
+                } else {
+                    (attn.reshape((b, q_h, q_w, k_h, k_w))?
+                        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
+                    .reshape((b, q_h * q_w, k_h * k_w))
+                }
             }
             None => Ok(attn),
         }
@@ -149,7 +222,10 @@ impl Module for Attention {
         let q = qkv.i(0)?;
         let k = qkv.i(1)?;
         let v = qkv.i(2)?;
-        let attn = (&q * self.scale)?.matmul(&k.t()?)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            (&q * self.scale)?.matmul(&k.t()?)?
+        };
         let attn = {
             let _enter = self.span_rel_pos.enter();
             self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
@@ -158,7 +234,10 @@ impl Module for Attention {
             let _enter = self.span_softmax.enter();
             candle_nn::ops::softmax_last_dim(&attn)?
         };
-        let attn = attn.matmul(&v)?;
+        let attn = {
+            let _enter = self.span_matmul.enter();
+            attn.matmul(&v)?
+        };
         let attn = attn
             .reshape((b, self.num_heads, h, w, c / self.num_heads))?
             .permute((0, 2, 3, 1, 4))?