Fuse the rel-pos additions via a custom-op. (#786)

* Fuse the rel-pos additions via a custom-op.

* Run with rayon.

* Add more tracing.
Laurent Mazare authored on 2023-09-09 10:46:09 +01:00, committed by GitHub
parent 722c50bb0c, commit 3cd7e7b51d
2 changed files with 86 additions and 6 deletions


@@ -24,6 +24,7 @@ intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
rayon = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }

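Note: `rayon = { workspace = true }` assumes a version pinned at the workspace root. A minimal sketch of that root-level entry (the version number here is an assumption, not part of this diff):

# Root Cargo.toml, hypothetical excerpt
[workspace.dependencies]
rayon = "1.7"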

@@ -34,6 +34,70 @@ impl Module for PatchEmbed {
    }
}
// A custom op to make add_decomposed_rel_pos faster. Most of the time is spent on the final
// addition in the case where b = 12 and q_h * q_w = k_h * k_w = 4096:
//   (attn.reshape((b, q_h, q_w, k_h, k_w))?
//       + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
//   .reshape((b, q_h * q_w, k_h * k_w))
// Ideally we would perform this operation in place but this is not supported in candle at the
// moment. We should also investigate using f16 rather than f32.
struct Add3(usize, usize, usize, usize, usize);
impl candle::CustomOp3 for Add3 {
    fn name(&self) -> &'static str {
        "add3"
    }
    fn cpu_fwd(
        &self,
        s1: &candle::CpuStorage,
        l1: &candle::Layout,
        s2: &candle::CpuStorage,
        l2: &candle::Layout,
        s3: &candle::CpuStorage,
        l3: &candle::Layout,
    ) -> Result<(candle::CpuStorage, candle::Shape)> {
        use rayon::prelude::*;
        let Add3(b, q_h, q_w, k_h, k_w) = *self;
        let s1 = s1.as_slice::<f32>()?;
        let s1 = match l1.contiguous_offsets() {
            None => candle::bail!("input1 has to be contiguous"),
            Some((o1, o2)) => &s1[o1..o2],
        };
        let s2 = s2.as_slice::<f32>()?;
        let s2 = match l2.contiguous_offsets() {
            None => candle::bail!("input2 has to be contiguous"),
            Some((o1, o2)) => &s2[o1..o2],
        };
        let s3 = s3.as_slice::<f32>()?;
        let s3 = match l3.contiguous_offsets() {
            None => candle::bail!("input3 has to be contiguous"),
            Some((o1, o2)) => &s3[o1..o2],
        };
        let mut dst = vec![0f32; b * q_h * q_w * k_h * k_w];
        dst.par_chunks_exact_mut(k_h * k_w)
            .enumerate()
            .for_each(|(b_idx, dst)| {
                let s1_idx = b_idx * k_h * k_w;
                let s2_idx = b_idx * k_h;
                let s3_idx = b_idx * k_w;
                for h_idx in 0..k_h {
                    let s1_idx = s1_idx + h_idx * k_w;
                    let s2_idx = s2_idx + h_idx;
                    let dst_idx = h_idx * k_w;
                    for w_idx in 0..k_w {
                        let s1_idx = s1_idx + w_idx;
                        let s3_idx = s3_idx + w_idx;
                        let dst_idx = dst_idx + w_idx;
                        dst[dst_idx] = s1[s1_idx] + s2[s2_idx] + s3[s3_idx]
                    }
                }
            });
        let dst = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((dst, (b, q_h * q_w, k_h * k_w).into()))
    }
}
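As a quick sanity check, the fused op can be compared against the decomposed expression it replaces. The sketch below is illustrative only (small made-up dimensions, `Add3` assumed to be in scope) and relies on the shapes the op expects: attn of shape (b, q_h * q_w, k_h * k_w), rel_h of shape (b, q_h, q_w, k_h) and rel_w of shape (b, q_h, q_w, k_w), all contiguous f32 on the CPU.

#[cfg(test)]
mod add3_sanity_check {
    use super::Add3;
    use candle::{Device, Result, Tensor};

    #[test]
    fn add3_matches_broadcast_add() -> Result<()> {
        // Small made-up dimensions; the real workload is much larger.
        let (b, q_h, q_w, k_h, k_w) = (2, 3, 4, 3, 4);
        let dev = Device::Cpu;
        let attn = Tensor::randn(0f32, 1f32, (b, q_h * q_w, k_h * k_w), &dev)?;
        let rel_h = Tensor::randn(0f32, 1f32, (b, q_h, q_w, k_h), &dev)?;
        let rel_w = Tensor::randn(0f32, 1f32, (b, q_h, q_w, k_w), &dev)?;
        // Fused custom op.
        let fused = attn.apply_op3_no_bwd(&rel_h, &rel_w, &Add3(b, q_h, q_w, k_h, k_w))?;
        // Reference: the original reshape + broadcast_add formulation.
        let reference = (attn.reshape((b, q_h, q_w, k_h, k_w))?
            + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
            .reshape((b, q_h * q_w, k_h * k_w))?;
        let diff = (&fused - &reference)?.abs()?.sum_all()?.to_scalar::<f32>()?;
        assert!(diff < 1e-4);
        Ok(())
    }
}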
#[derive(Debug)]
struct Attention {
    qkv: crate::Linear,
@@ -42,6 +106,7 @@ struct Attention {
    scale: f64,
    rel_pos_hw: Option<(Tensor, Tensor)>,
    span: tracing::Span,
    span_matmul: tracing::Span,
    span_rel_pos: tracing::Span,
    span_softmax: tracing::Span,
}
@@ -56,6 +121,7 @@ impl Attention {
        vb: VarBuilder,
    ) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "attention");
        let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
        let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
        let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
        let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
@@ -76,6 +142,7 @@ impl Attention {
            scale,
            rel_pos_hw,
            span,
            span_matmul,
            span_rel_pos,
            span_softmax,
        })
@@ -101,11 +168,17 @@ impl Attention {
                    .transpose(1, 2)? // -> bwhc
                    .contiguous()?
                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
                    .transpose(1, 2)?;
                    .transpose(1, 2)?
                    .contiguous()?;
                if attn.device().is_cpu() {
                    let op = Add3(b, q_h, q_w, k_h, k_w);
                    attn.apply_op3_no_bwd(&rel_h, &rel_w, &op)
                } else {
                    (attn.reshape((b, q_h, q_w, k_h, k_w))?
                        + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
                        .reshape((b, q_h * q_w, k_h * k_w))
                }
            }
            None => Ok(attn),
        }
    }
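For a sense of scale, take the case the Add3 comment refers to: b = 12 with q_h * q_w = k_h * k_w = 4096 (a 64x64 grid is assumed here purely for illustration). The fused output alone is then roughly 768 MiB of f32, which is why producing it in one rayon-parallel pass, rather than first materializing a broadcast intermediate of the same size, is worthwhile:

fn main() {
    // Hypothetical dimensions: b = 12, 64x64 query grid, 64x64 key grid.
    let (b, q_hw, k_hw) = (12usize, 64 * 64, 64 * 64);
    let elems = b * q_hw * k_hw; // 201_326_592 f32 values
    let bytes = elems * std::mem::size_of::<f32>(); // 805_306_368 bytes
    println!("fused output: {elems} f32 elements, {} MiB", bytes / (1024 * 1024));
}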
@@ -149,7 +222,10 @@ impl Module for Attention {
        let q = qkv.i(0)?;
        let k = qkv.i(1)?;
        let v = qkv.i(2)?;
        let attn = (&q * self.scale)?.matmul(&k.t()?)?;
        let attn = {
            let _enter = self.span_matmul.enter();
            (&q * self.scale)?.matmul(&k.t()?)?
        };
        let attn = {
            let _enter = self.span_rel_pos.enter();
            self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?
@@ -158,7 +234,10 @@
            let _enter = self.span_softmax.enter();
            candle_nn::ops::softmax_last_dim(&attn)?
        };
        let attn = attn.matmul(&v)?;
        let attn = {
            let _enter = self.span_matmul.enter();
            attn.matmul(&v)?
        };
        let attn = attn
            .reshape((b, self.num_heads, h, w, c / self.num_heads))?
            .permute((0, 2, 3, 1, 4))?
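
The spans added in this change only produce output once a tracing subscriber is installed. A minimal sketch of how they are typically recorded, assuming the tracing-subscriber and tracing-chrome crates (neither is added by this diff):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Writes a chrome://tracing compatible trace file; the guard flushes it on drop.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // ... run the image encoder here; the "attn-matmul", "attn-rel-pos" and "attn-sm"
    // spans then show up as separate slices in the resulting trace.
}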