Compare commits

1 commit

SHA1         Message                Date
a910ec5993   CustomOp for einsum.   2023-09-08 20:46:30 +01:00

@@ -46,6 +46,121 @@ struct Attention {
    span_softmax: tracing::Span,
}
// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
struct EinSum1;
impl candle::CustomOp2 for EinSum1 {
    fn name(&self) -> &'static str {
        "einsum1"
    }
    fn cpu_fwd(
        &self,
        s1: &candle::CpuStorage,
        l1: &candle::Layout,
        s2: &candle::CpuStorage,
        l2: &candle::Layout,
    ) -> Result<(candle::CpuStorage, candle::Shape)> {
        use candle::cpu::kernels::VecOps;
        let (b, h, w, c) = l1.shape().dims4()?;
        let (h2, k, c2) = l2.shape().dims3()?;
        if c != c2 || h != h2 {
            candle::bail!("shape mismatch {l1:?} {l2:?}")
        }
        let s1 = s1.as_slice::<f32>()?;
        let s1 = match l1.contiguous_offsets() {
            None => candle::bail!("input1 has to be contiguous"),
            Some((o1, o2)) => &s1[o1..o2],
        };
        let s2 = s2.as_slice::<f32>()?;
        let s2 = match l2.contiguous_offsets() {
            None => candle::bail!("input2 has to be contiguous"),
            Some((o1, o2)) => &s2[o1..o2],
        };
        let mut dst = vec![0f32; b * h * w * k];
        for b_idx in 0..b {
            let lhs_idx = b_idx * h * w * c;
            let dst_idx = b_idx * h * w * k;
            for h_idx in 0..h {
                let lhs_idx = lhs_idx + h_idx * w * c;
                let rhs_idx = h_idx * k * c;
                let dst_idx = dst_idx + h_idx * w * k;
                for w_idx in 0..w {
                    let lhs_idx = lhs_idx + w_idx * c;
                    let dst_idx = dst_idx + w_idx * k;
                    let lhs = &s1[lhs_idx..lhs_idx + c];
                    for k_idx in 0..k {
                        let rhs_idx = rhs_idx + k_idx * c;
                        let rhs = &s2[rhs_idx..rhs_idx + c];
                        let mut d = 0f32;
                        unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
                        dst[dst_idx + k_idx] += d;
                    }
                }
            }
        }
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, (b, h, w, k).into()))
    }
}
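// Note (illustration, not part of the commit): element by element, the flat
// offsets above implement
//   dst[b_idx][h_idx][w_idx][k_idx] = sum over c of
//     r_q[b_idx][h_idx][w_idx][c] * r_h[h_idx][k_idx][c]
// i.e. the "bhwc,hkc->bhwk" contraction, with the inner sum over c delegated
// to candle's vectorized f32::vec_dot kernel.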
// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
struct EinSum2;
impl candle::CustomOp2 for EinSum2 {
    fn name(&self) -> &'static str {
        "einsum2"
    }
    fn cpu_fwd(
        &self,
        s1: &candle::CpuStorage,
        l1: &candle::Layout,
        s2: &candle::CpuStorage,
        l2: &candle::Layout,
    ) -> Result<(candle::CpuStorage, candle::Shape)> {
        use candle::cpu::kernels::VecOps;
        let (b, h, w, c) = l1.shape().dims4()?;
        let (w2, k, c2) = l2.shape().dims3()?;
        if c != c2 || w != w2 {
            candle::bail!("shape mismatch {l1:?} {l2:?}")
        }
        let s1 = s1.as_slice::<f32>()?;
        let s1 = match l1.contiguous_offsets() {
            None => candle::bail!("input1 has to be contiguous"),
            Some((o1, o2)) => &s1[o1..o2],
        };
        let s2 = s2.as_slice::<f32>()?;
        let s2 = match l2.contiguous_offsets() {
            None => candle::bail!("input2 has to be contiguous"),
            Some((o1, o2)) => &s2[o1..o2],
        };
        let mut dst = vec![0f32; b * h * w * k];
        for b_idx in 0..b {
            let lhs_idx = b_idx * h * w * c;
            let dst_idx = b_idx * h * w * k;
            for h_idx in 0..h {
                let lhs_idx = lhs_idx + h_idx * w * c;
                let dst_idx = dst_idx + h_idx * w * k;
                for w_idx in 0..w {
                    let lhs_idx = lhs_idx + w_idx * c;
                    let rhs_idx = w_idx * k * c;
                    let dst_idx = dst_idx + w_idx * k;
                    let lhs = &s1[lhs_idx..lhs_idx + c];
                    for k_idx in 0..k {
                        let rhs_idx = rhs_idx + k_idx * c;
                        let rhs = &s2[rhs_idx..rhs_idx + c];
                        let mut d = 0f32;
                        unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
                        dst[dst_idx + k_idx] += d;
                    }
                }
            }
        }
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, (b, h, w, k).into()))
    }
}
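// Illustrative sketch (hypothetical helper, not part of this commit): a naive
// scalar reference for the same contraction, handy for cross-checking EinSum1
// on small shapes; using w_i instead of h_i in the rhs index gives EinSum2.
fn einsum_bhwc_hkc_bhwk(
    s1: &[f32], // (b, h, w, c), contiguous
    s2: &[f32], // (h, k, c), contiguous
    (b, h, w, c): (usize, usize, usize, usize),
    k: usize,
) -> Vec<f32> {
    let mut dst = vec![0f32; b * h * w * k];
    for b_i in 0..b {
        for h_i in 0..h {
            for w_i in 0..w {
                for k_i in 0..k {
                    let mut acc = 0f32;
                    for c_i in 0..c {
                        acc += s1[((b_i * h + h_i) * w + w_i) * c + c_i]
                            * s2[(h_i * k + k_i) * c + c_i];
                    }
                    dst[((b_i * h + h_i) * w + w_i) * k + k_i] = acc;
                }
            }
        }
    }
    dst
}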
impl Attention {
    fn new(
        dim: usize,
@@ -90,18 +205,15 @@ impl Attention {
    ) -> Result<Tensor> {
        match &self.rel_pos_hw {
            Some((rel_pos_h, rel_pos_w)) => {
-               println!("{:?} {:?}", attn.layout(), q.layout());
                let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
                let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
                let (b, _, dim) = q.dims3()?;
                let r_q = q.reshape((b, q_h, q_w, dim))?;
                // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
-               let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
+               let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
                // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
-               let rel_w = r_q
-                   .transpose(1, 2)? // -> bwhc
-                   .contiguous()?
-                   .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
-                   .transpose(1, 2)?;
+               let rel_w = r_q.apply_op2_no_bwd(&r_w, &EinSum2)?;
                (attn.reshape((b, q_h, q_w, k_h, k_w))?
                    + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
                .reshape((b, q_h * q_w, k_h * k_w))
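
For context, a minimal sketch of driving one of these custom ops directly, assuming candle's Device::Cpu and Tensor::randn constructors (the shapes here are arbitrary; EinSum1 requires contiguous f32 inputs of rank 4 and 3):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, h, w, c, k) = (1, 2, 3, 4, 5);
    let r_q = Tensor::randn(0f32, 1f32, (b, h, w, c), &dev)?; // "bhwc"
    let r_h = Tensor::randn(0f32, 1f32, (h, k, c), &dev)?; // "hkc"
    // einsum("bhwc,hkc->bhwk") via the custom op; as the _no_bwd suffix
    // suggests, no backward pass is registered for it.
    let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
    assert_eq!(rel_h.dims(), &[b, h, w, k]);
    Ok(())
}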