Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Compare commits: 0.6.0...einsum-cus (1 commit)
Commit: a910ec5993
@@ -46,6 +46,121 @@ struct Attention {
     span_softmax: tracing::Span,
 }
 
+// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+struct EinSum1;
+impl candle::CustomOp2 for EinSum1 {
+    fn name(&self) -> &'static str {
+        "einsum1"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &candle::CpuStorage,
+        l1: &candle::Layout,
+        s2: &candle::CpuStorage,
+        l2: &candle::Layout,
+    ) -> Result<(candle::CpuStorage, candle::Shape)> {
+        use candle::cpu::kernels::VecOps;
+
+        let (b, h, w, c) = l1.shape().dims4()?;
+        let (h2, k, c2) = l2.shape().dims3()?;
+        if c != c2 || h != h2 {
+            candle::bail!("shape mismatch {l1:?} {l2:?}")
+        }
+        let s1 = s1.as_slice::<f32>()?;
+        let s1 = match l1.contiguous_offsets() {
+            None => candle::bail!("input1 has to be contiguous"),
+            Some((o1, o2)) => &s1[o1..o2],
+        };
+        let s2 = s2.as_slice::<f32>()?;
+        let s2 = match l2.contiguous_offsets() {
+            None => candle::bail!("input2 has to be contiguous"),
+            Some((o1, o2)) => &s2[o1..o2],
+        };
+        let mut dst = vec![0f32; b * h * w * k];
+        for b_idx in 0..b {
+            let lhs_idx = b_idx * h * w * c;
+            let dst_idx = b_idx * h * w * k;
+            for h_idx in 0..h {
+                let lhs_idx = lhs_idx + h_idx * w * c;
+                let rhs_idx = h_idx * k * c;
+                let dst_idx = dst_idx + h_idx * w * k;
+                for w_idx in 0..w {
+                    let lhs_idx = lhs_idx + w_idx * c;
+                    let dst_idx = dst_idx + w_idx * k;
+                    let lhs = &s1[lhs_idx..lhs_idx + c];
+                    for k_idx in 0..k {
+                        let rhs_idx = rhs_idx + k_idx * c;
+                        let rhs = &s2[rhs_idx..rhs_idx + c];
+                        let mut d = 0f32;
+                        unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
+                        dst[dst_idx + k_idx] += d;
+                    }
+                }
+            }
+        }
+        let storage = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((storage, (b, h, w, k).into()))
+    }
+}
+
+// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+struct EinSum2;
+impl candle::CustomOp2 for EinSum2 {
+    fn name(&self) -> &'static str {
+        "einsum2"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &candle::CpuStorage,
+        l1: &candle::Layout,
+        s2: &candle::CpuStorage,
+        l2: &candle::Layout,
+    ) -> Result<(candle::CpuStorage, candle::Shape)> {
+        use candle::cpu::kernels::VecOps;
+
+        let (b, h, w, c) = l1.shape().dims4()?;
+        let (w2, k, c2) = l2.shape().dims3()?;
+        if c != c2 || w != w2 {
+            candle::bail!("shape mismatch {l1:?} {l2:?}")
+        }
+        let s1 = s1.as_slice::<f32>()?;
+        let s1 = match l1.contiguous_offsets() {
+            None => candle::bail!("input1 has to be contiguous"),
+            Some((o1, o2)) => &s1[o1..o2],
+        };
+        let s2 = s2.as_slice::<f32>()?;
+        let s2 = match l2.contiguous_offsets() {
+            None => candle::bail!("input2 has to be contiguous"),
+            Some((o1, o2)) => &s2[o1..o2],
+        };
+        let mut dst = vec![0f32; b * h * w * k];
+        for b_idx in 0..b {
+            let lhs_idx = b_idx * h * w * c;
+            let dst_idx = b_idx * h * w * k;
+            for h_idx in 0..h {
+                let lhs_idx = lhs_idx + h_idx * w * c;
+                let dst_idx = dst_idx + h_idx * w * k;
+                for w_idx in 0..w {
+                    let lhs_idx = lhs_idx + w_idx * c;
+                    let rhs_idx = w_idx * k * c;
+                    let dst_idx = dst_idx + w_idx * k;
+                    let lhs = &s1[lhs_idx..lhs_idx + c];
+                    for k_idx in 0..k {
+                        let rhs_idx = rhs_idx + k_idx * c;
+                        let rhs = &s2[rhs_idx..rhs_idx + c];
+                        let mut d = 0f32;
+                        unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
+                        dst[dst_idx + k_idx] += d;
+                    }
+                }
+            }
+        }
+        let storage = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((storage, (b, h, w, k).into()))
+    }
+}
 impl Attention {
     fn new(
         dim: usize,
@@ -90,18 +205,15 @@ impl Attention {
     ) -> Result<Tensor> {
         match &self.rel_pos_hw {
             Some((rel_pos_h, rel_pos_w)) => {
+                println!("{:?} {:?}", attn.layout(), q.layout());
                 let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
                 let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
                 let (b, _, dim) = q.dims3()?;
                 let r_q = q.reshape((b, q_h, q_w, dim))?;
                 // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
-                let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
+                let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
                 // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
-                let rel_w = r_q
-                    .transpose(1, 2)? // -> bwhc
-                    .contiguous()?
-                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
-                    .transpose(1, 2)?;
+                let rel_w = r_q.apply_op2_no_bwd(&r_w, &EinSum2)?;
                 (attn.reshape((b, q_h, q_w, k_h, k_w))?
                     + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
                     .reshape((b, q_h * q_w, k_h * k_w))
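
The two custom ops added above implement the contractions rel_h[b, i, j, k] = sum_c r_q[b, i, j, c] * Rh[i, k, c] and rel_w[b, i, j, k] = sum_c r_q[b, i, j, c] * Rw[j, k, c]. The sketch below, not part of the commit, runs them on small random tensors and compares the results against the matmul-based formulations that the diff removes. The function name check_einsum_ops, the tensor sizes, and the use of Device::Cpu are illustrative assumptions; EinSum1 and EinSum2 are assumed to be in scope as defined in the diff.

use candle::{Device, Result, Tensor};

fn check_einsum_ops() -> Result<()> {
    // Illustrative sizes; any small values work as long as the shapes line up.
    let dev = Device::Cpu;
    let (b, h, w, c, k) = (1, 2, 3, 4, 5);
    let r_q = Tensor::randn(0f32, 1f32, (b, h, w, c), &dev)?;
    let r_h = Tensor::randn(0f32, 1f32, (h, k, c), &dev)?;
    let r_w = Tensor::randn(0f32, 1f32, (w, k, c), &dev)?;

    // Custom ops from this commit.
    let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
    let rel_w = r_q.apply_op2_no_bwd(&r_w, &EinSum2)?;

    // Matmul-based formulations that the commit replaces.
    let rel_h_ref = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
    let rel_w_ref = r_q
        .transpose(1, 2)? // -> (b, w, h, c)
        .contiguous()?
        .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // -> (b, w, h, k)
        .transpose(1, 2)?; // -> (b, h, w, k)

    // Both pairs should agree up to floating-point error.
    let diff_h = rel_h.sub(&rel_h_ref)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    let diff_w = rel_w.sub(&rel_w_ref)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    println!("abs diff rel_h: {diff_h}, rel_w: {diff_w}");
    Ok(())
}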
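
For the final combination step kept as context in the second hunk, a short shape walkthrough may help. This is a sketch under the same illustrative assumptions (small sizes, zero-filled placeholders standing in for rel_h and rel_w) and is not part of the commit.

use candle::{DType, Device, Result, Tensor};

fn combine_shapes() -> Result<()> {
    let dev = Device::Cpu;
    let (b, q_h, q_w, k_h, k_w) = (1, 2, 3, 4, 5);
    // Placeholders with the shapes produced by EinSum1 and EinSum2.
    let rel_h = Tensor::zeros((b, q_h, q_w, k_h), DType::F32, &dev)?;
    let rel_w = Tensor::zeros((b, q_h, q_w, k_w), DType::F32, &dev)?;
    // unsqueeze(4): (b, q_h, q_w, k_h) -> (b, q_h, q_w, k_h, 1)
    // unsqueeze(3): (b, q_h, q_w, k_w) -> (b, q_h, q_w, 1, k_w)
    // broadcasting the add over the last two axes -> (b, q_h, q_w, k_h, k_w)
    let bias = rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?;
    assert_eq!(bias.dims(), &[b, q_h, q_w, k_h, k_w]);
    // This matches attn reshaped to (b, q_h, q_w, k_h, k_w); after the addition
    // the result is flattened back to (b, q_h * q_w, k_h * k_w).
    Ok(())
}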