diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs
index f997170d..cfe7f1b6 100644
--- a/candle-examples/examples/segment-anything/model_image_encoder.rs
+++ b/candle-examples/examples/segment-anything/model_image_encoder.rs
@@ -46,6 +46,121 @@ struct Attention {
     span_softmax: tracing::Span,
 }
 
+// rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+struct EinSum1;
+impl candle::CustomOp2 for EinSum1 {
+    fn name(&self) -> &'static str {
+        "einsum1"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &candle::CpuStorage,
+        l1: &candle::Layout,
+        s2: &candle::CpuStorage,
+        l2: &candle::Layout,
+    ) -> Result<(candle::CpuStorage, candle::Shape)> {
+        use candle::cpu::kernels::VecOps;
+
+        let (b, h, w, c) = l1.shape().dims4()?;
+        let (h2, k, c2) = l2.shape().dims3()?;
+        if c != c2 || h != h2 {
+            candle::bail!("shape mismatch {l1:?} {l2:?}")
+        }
+        let s1 = s1.as_slice::<f32>()?;
+        let s1 = match l1.contiguous_offsets() {
+            None => candle::bail!("input1 has to be contiguous"),
+            Some((o1, o2)) => &s1[o1..o2],
+        };
+        let s2 = s2.as_slice::<f32>()?;
+        let s2 = match l2.contiguous_offsets() {
+            None => candle::bail!("input2 has to be contiguous"),
+            Some((o1, o2)) => &s2[o1..o2],
+        };
+        let mut dst = vec![0f32; b * h * w * k];
+        for b_idx in 0..b {
+            let lhs_idx = b_idx * h * w * c;
+            let dst_idx = b_idx * h * w * k;
+            for h_idx in 0..h {
+                let lhs_idx = lhs_idx + h_idx * w * c;
+                let rhs_idx = h_idx * k * c;
+                let dst_idx = dst_idx + h_idx * w * k;
+                for w_idx in 0..w {
+                    let lhs_idx = lhs_idx + w_idx * c;
+                    let dst_idx = dst_idx + w_idx * k;
+                    let lhs = &s1[lhs_idx..lhs_idx + c];
+                    for k_idx in 0..k {
+                        let rhs_idx = rhs_idx + k_idx * c;
+                        let rhs = &s2[rhs_idx..rhs_idx + c];
+                        let mut d = 0f32;
+                        unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
+                        dst[dst_idx + k_idx] += d;
+                    }
+                }
+            }
+        }
+        let storage = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((storage, (b, h, w, k).into()))
+    }
+}
+
+// rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+struct EinSum2;
+impl candle::CustomOp2 for EinSum2 {
+    fn name(&self) -> &'static str {
+        "einsum2"
+    }
+
+    fn cpu_fwd(
+        &self,
+        s1: &candle::CpuStorage,
+        l1: &candle::Layout,
+        s2: &candle::CpuStorage,
+        l2: &candle::Layout,
+    ) -> Result<(candle::CpuStorage, candle::Shape)> {
+        use candle::cpu::kernels::VecOps;
+
+        let (b, h, w, c) = l1.shape().dims4()?;
+        let (w2, k, c2) = l2.shape().dims3()?;
+        if c != c2 || w != w2 {
+            candle::bail!("shape mismatch {l1:?} {l2:?}")
+        }
+        let s1 = s1.as_slice::<f32>()?;
+        let s1 = match l1.contiguous_offsets() {
+            None => candle::bail!("input1 has to be contiguous"),
+            Some((o1, o2)) => &s1[o1..o2],
+        };
+        let s2 = s2.as_slice::<f32>()?;
+        let s2 = match l2.contiguous_offsets() {
+            None => candle::bail!("input2 has to be contiguous"),
+            Some((o1, o2)) => &s2[o1..o2],
+        };
+        let mut dst = vec![0f32; b * h * w * k];
+        for b_idx in 0..b {
+            let lhs_idx = b_idx * h * w * c;
+            let dst_idx = b_idx * h * w * k;
+            for h_idx in 0..h {
+                let lhs_idx = lhs_idx + h_idx * w * c;
+                let dst_idx = dst_idx + h_idx * w * k;
+                for w_idx in 0..w {
+                    let lhs_idx = lhs_idx + w_idx * c;
+                    let rhs_idx = w_idx * k * c;
+                    let dst_idx = dst_idx + w_idx * k;
+                    let lhs = &s1[lhs_idx..lhs_idx + c];
+                    for k_idx in 0..k {
+                        let rhs_idx = rhs_idx + k_idx * c;
+                        let rhs = &s2[rhs_idx..rhs_idx + c];
+                        let mut d = 0f32;
+                        unsafe { f32::vec_dot(lhs.as_ptr(), rhs.as_ptr(), &mut d, c) };
+                        dst[dst_idx + k_idx] += d;
+                    }
+                }
+            }
+        }
+        let storage = candle::WithDType::to_cpu_storage_owned(dst);
+        Ok((storage, (b, h, w, k).into()))
+    }
+}
 impl Attention {
     fn new(
         dim: usize,
@@ -90,18 +205,15 @@ impl Attention {
     ) -> Result<Tensor> {
         match &self.rel_pos_hw {
             Some((rel_pos_h, rel_pos_w)) => {
+                println!("{:?} {:?}", attn.layout(), q.layout());
                 let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?;
                 let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?;
                 let (b, _, dim) = q.dims3()?;
                 let r_q = q.reshape((b, q_h, q_w, dim))?;
                 // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
-                let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
+                let rel_h = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
                 // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
-                let rel_w = r_q
-                    .transpose(1, 2)? // -> bwhc
-                    .contiguous()?
-                    .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk
-                    .transpose(1, 2)?;
+                let rel_w = r_q.apply_op2_no_bwd(&r_w, &EinSum2)?;
                 (attn.reshape((b, q_h, q_w, k_h, k_w))?
                     + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)?
                 .reshape((b, q_h * q_w, k_h * k_w))
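The two custom ops implement the same contractions as the matmul-based code they replace, so they can be sanity-checked against the old path. Below is a minimal sketch for the `rel_h` case, assuming the `EinSum1` type from this patch is in scope; the shapes are arbitrary values chosen for illustration, not taken from the model.

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let (b, h, w, c, k) = (2, 4, 5, 8, 3);
        // r_q: (b, h, w, c), r_h: (h, k, c), matching "bhwc,hkc->bhwk".
        let r_q = Tensor::randn(0f32, 1f32, (b, h, w, c), &dev)?;
        let r_h = Tensor::randn(0f32, 1f32, (h, k, c), &dev)?;
        // New path: the custom op added by this diff.
        let fast = r_q.apply_op2_no_bwd(&r_h, &EinSum1)?;
        // Old path: the broadcasted matmul this diff removes,
        // bhwc @ bhck -> bhwk.
        let slow = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?;
        // The two results should agree up to float rounding.
        let max_diff = (fast - slow)?
            .abs()?
            .flatten_all()?
            .max(0)?
            .to_scalar::<f32>()?;
        println!("max abs diff: {max_diff}");
        Ok(())
    }

Note that the custom op reads `Rh` in place, whereas the old path first materializes the `broadcast_left`/`t`/`contiguous` copy of it before calling `matmul`.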