diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index d030fab1..be8f7b07 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -30,7 +30,7 @@ pub enum Error { UnsupportedDTypeForOp(DType, &'static str), // === Dimension Index Errors === - #[error("{op}: dimension index {dim} out of range for {shape:?}")] + #[error("{op}: dimension index {dim} out of range for shape {shape:?}")] DimOutOfRange { shape: Shape, dim: i32, diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs index 578e8ac9..9617d1ac 100644 --- a/candle-core/src/shape.rs +++ b/candle-core/src/shape.rs @@ -73,6 +73,14 @@ impl From<(usize, usize, usize, usize, usize)> for Shape { } } +impl From<(usize, usize, usize, usize, usize, usize)> for Shape { + fn from(d123456: (usize, usize, usize, usize, usize, usize)) -> Self { + Self(vec![ + d123456.0, d123456.1, d123456.2, d123456.3, d123456.4, d123456.5, + ]) + } +} + impl From> for Shape { fn from(dims: Vec) -> Self { Self(dims) diff --git a/candle-examples/examples/segment-anything/main.rs b/candle-examples/examples/segment-anything/main.rs index 368b5a33..a2722270 100644 --- a/candle-examples/examples/segment-anything/main.rs +++ b/candle-examples/examples/segment-anything/main.rs @@ -15,7 +15,7 @@ pub mod model_sam; pub mod model_transformer; use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{Linear, Module, VarBuilder}; use clap::Parser; pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { @@ -26,65 +26,74 @@ pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Resu } } +#[derive(Debug)] +pub struct LayerNorm2d { + weight: Tensor, + bias: Tensor, + num_channels: usize, + eps: f64, +} + +impl LayerNorm2d { + pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get(num_channels, "weight")?; + let bias = vb.get(num_channels, "bias")?; + Ok(Self { + weight, + bias, + num_channels, + eps, + }) + } +} + +impl Module for LayerNorm2d { + fn forward(&self, xs: &Tensor) -> Result { + let u = xs.mean_keepdim(1)?; + let xs = xs.broadcast_sub(&u)?; + let s = xs.sqr()?.mean_keepdim(1)?; + let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?; + xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)? + .broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?) + } +} + #[derive(Debug)] pub struct MlpBlock { lin1: Linear, lin2: Linear, + activation: candle_nn::Activation, } impl MlpBlock { - pub fn new(embedding_dim: usize, mlp_dim: usize, vb: VarBuilder) -> Result { + pub fn new( + embedding_dim: usize, + mlp_dim: usize, + activation: candle_nn::Activation, + vb: VarBuilder, + ) -> Result { let lin1 = candle_nn::linear(embedding_dim, mlp_dim, vb.pp("lin1"))?; let lin2 = candle_nn::linear(mlp_dim, embedding_dim, vb.pp("lin2"))?; - Ok(Self { lin1, lin2 }) + Ok(Self { + lin1, + lin2, + activation, + }) } } impl Module for MlpBlock { fn forward(&self, xs: &Tensor) -> Result { - xs.apply(&self.lin1)?.gelu()?.apply(&self.lin2) + xs.apply(&self.lin1)? + .apply(&self.activation)? + .apply(&self.lin2) } } -/* - fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result { - let npatch = xs.dim(1)? - 1; - let n = self.pos_embed.dim(1)? - 1; - let sqrt_n = (n as f64).sqrt(); - if npatch == n && w == h { - return Ok(xs.clone()); - } - let class_pos_embed = self.pos_embed.i((.., ..1))?; - let patch_pos_embed = self.pos_embed.i((.., 1..))?; - let dim = xs.dim(D::Minus1)?; - let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); - let patch_pos_embed = patch_pos_embed - .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? - .transpose(2, 3)? - .transpose(1, 2)?; - // This uses bicubic interpolation in the original implementation. - let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; - let el_count = patch_pos_embed.shape().elem_count(); - let patch_pos_embed = - patch_pos_embed - .transpose(1, 2)? - .transpose(2, 3)? - .reshape((1, el_count / dim, dim))?; - Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1) - } - - fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result { - let (_b, _nc, w, h) = xs.dims4()?; - let xs = self.patch_embed.forward(xs)?; - let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; - &xs + &self.interpolate_pos_encoding(&xs, w, h)? - } -*/ - #[derive(Parser)] struct Args { #[arg(long)] - model: String, + model: Option, #[arg(long)] image: String, @@ -99,13 +108,24 @@ pub fn main() -> anyhow::Result<()> { let device = candle_examples::device(args.cpu)?; - let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device); + let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?; println!("loaded image {image:?}"); - let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? }; + let model = match args.model { + Some(model) => std::path::PathBuf::from(model), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-sam".to_string()); + api.get("sam_vit_b_01ec64.safetensors")? + } + }; + let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device); - let _sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b + let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b + let (mask, iou_predictions) = sam.forward(&image, false)?; + println!("mask: {mask:?}"); + println!("iou_predictions: {iou_predictions:?}"); Ok(()) } diff --git a/candle-examples/examples/segment-anything/model_image_encoder.rs b/candle-examples/examples/segment-anything/model_image_encoder.rs index cfcdbb38..f5db2830 100644 --- a/candle-examples/examples/segment-anything/model_image_encoder.rs +++ b/candle-examples/examples/segment-anything/model_image_encoder.rs @@ -70,6 +70,60 @@ impl Attention { rel_pos_hw, }) } + + fn add_decomposed_rel_pos( + &self, + attn: Tensor, + q: &Tensor, + (q_h, q_w): (usize, usize), + (k_h, k_w): (usize, usize), + ) -> Result { + match &self.rel_pos_hw { + Some((rel_pos_h, rel_pos_w)) => { + let r_h = get_rel_pos(q_h, k_h, rel_pos_h)?; + let r_w = get_rel_pos(q_w, k_w, rel_pos_w)?; + let (b, _, dim) = q.dims3()?; + let r_q = q.reshape((b, q_h, q_w, dim))?; + // rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + let rel_h = r_q.matmul(&r_h.broadcast_left(b)?.t()?.contiguous()?)?; + // rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + let rel_w = r_q + .transpose(1, 2)? // -> bwhc + .contiguous()? + .matmul(&r_w.broadcast_left(b)?.t()?.contiguous()?)? // bwhc,bwck -> bwhk + .transpose(1, 2)?; + (attn.reshape((b, q_h, q_w, k_h, k_w))? + + rel_h.unsqueeze(4)?.broadcast_add(&rel_w.unsqueeze(3)?)?)? + .reshape((b, q_h * q_w, k_h * k_w)) + } + None => Ok(attn), + } + } +} + +fn get_rel_pos(q_size: usize, k_size: usize, rel_pos: &Tensor) -> Result { + let max_rel_dist = 2 * usize::max(q_size, k_size) - 1; + let dev = rel_pos.device(); + let rel_pos_resized = if rel_pos.dim(0)? != max_rel_dist { + todo!("interpolation") + } else { + rel_pos + }; + let q_coords = Tensor::arange(0u32, q_size as u32, dev)? + .reshape((q_size, 1))? + .to_dtype(DType::F32)?; + let k_coords = Tensor::arange(0u32, k_size as u32, dev)? + .reshape((1, k_size))? + .to_dtype(DType::F32)?; + let q_coords = (q_coords * f64::max(1f64, k_size as f64 / q_size as f64))?; + let k_coords = (k_coords * f64::max(1f64, q_size as f64 / k_size as f64))?; + let relative_coords = (q_coords.broadcast_sub(&k_coords)? + + (k_size as f64 - 1.) * f64::max(1f64, q_size as f64 / k_size as f64))?; + let (d1, d2) = relative_coords.dims2()?; + let relative_coords = relative_coords.to_dtype(DType::U32)?; + rel_pos_resized + .index_select(&relative_coords.reshape(d1 * d2)?, 0)? + .reshape((d1, d2, rel_pos_resized.dim(1)?)) } impl Module for Attention { @@ -77,24 +131,22 @@ impl Module for Attention { let (b, h, w, c) = xs.dims4()?; let qkv = self .qkv - .forward(xs)? + .forward(&xs.flatten_to(1)?)? .reshape((b, h * w, 3, self.num_heads, c / self.num_heads))? .permute((2, 0, 3, 1, 4))? .reshape((3, b * self.num_heads, h * w, c / self.num_heads))?; let q = qkv.i(0)?; let k = qkv.i(1)?; let v = qkv.i(2)?; - let attn = (q * self.scale)?.matmul(&k.t()?)?; - if self.use_rel_pos { - todo!() - } + let attn = (&q * self.scale)?.matmul(&k.t()?)?; + let attn = self.add_decomposed_rel_pos(attn, &q, (h, w), (h, w))?; let attn = candle_nn::ops::softmax_last_dim(&attn)?; + let attn = attn.matmul(&v)?; let attn = attn - .matmul(&v)? .reshape((b, self.num_heads, h, w, c / self.num_heads))? .permute((0, 2, 3, 1, 4))? - .reshape((b, h, w, c / self.num_heads))?; - self.proj.forward(&attn) + .reshape((b, h * w, c))?; + self.proj.forward(&attn)?.reshape((b, h, w, c)) } } @@ -117,8 +169,8 @@ impl Block { input_size: (usize, usize), vb: VarBuilder, ) -> Result { - let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?; - let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?; + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; let input_size_attn = if window_size == 0 { input_size } else { @@ -132,7 +184,7 @@ impl Block { input_size_attn, vb.pp("attn"), )?; - let mlp = crate::MlpBlock::new(dim, dim * 4, vb.pp("mlp"))?; + let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?; Ok(Self { norm1, attn, @@ -143,17 +195,76 @@ impl Block { } } +fn window_partition(xs: Tensor, window_size: usize) -> Result<(Tensor, (usize, usize))> { + let (b, h, w, c) = xs.dims4()?; + let pad_h = (window_size - h % window_size) % window_size; + let pad_w = (window_size - w % window_size) % window_size; + let xs = if pad_h > 0 { + xs.pad_with_zeros(1, 0, pad_h)? + } else { + xs + }; + let xs = if pad_w > 0 { + xs.pad_with_zeros(2, 0, pad_w)? + } else { + xs + }; + let (h_p, w_p) = (h + pad_h, w + pad_w); + let windows = xs + .reshape(( + b, + h_p / window_size, + window_size, + w_p / window_size, + window_size, + c, + ))? + .transpose(2, 3)? + .contiguous()? + .flatten_to(2)?; + Ok((windows, (h_p, w_p))) +} + +fn window_unpartition( + windows: Tensor, + window_size: usize, + (h_p, w_p): (usize, usize), + (h, w): (usize, usize), +) -> Result { + let b = windows.dim(0)? / (h_p * w_p / window_size / window_size); + let xs = windows + .reshape(( + b, + h_p / window_size, + w_p / window_size, + window_size, + window_size, + windows.elem_count() / b / h_p / w_p, + ))? + .transpose(2, 3)? + .contiguous()? + .reshape((b, h_p, w_p, windows.elem_count() / b / h_p / w_p))?; + let xs = if h_p > h { xs.narrow(1, 0, h)? } else { xs }; + let xs = if w_p > w { xs.narrow(2, 0, w)? } else { xs }; + Ok(xs) +} + impl Module for Block { fn forward(&self, xs: &Tensor) -> Result { let shortcut = xs; let xs = self.norm1.forward(xs)?; - if self.window_size > 0 { - todo!() - } + let hw = (xs.dim(1)?, xs.dim(2)?); + let (xs, pad_hw) = if self.window_size > 0 { + window_partition(xs, self.window_size)? + } else { + (xs, (0, 0)) + }; let xs = self.attn.forward(&xs)?; - if self.window_size > 0 { - todo!() - } + let xs = if self.window_size > 0 { + window_unpartition(xs, self.window_size, pad_hw, hw)? + } else { + xs + }; let xs = (xs + shortcut)?; &xs + xs.apply(&self.norm2)?.apply(&self.mlp)? } @@ -165,9 +276,9 @@ pub struct ImageEncoderViT { patch_embed: PatchEmbed, blocks: Vec, neck_conv1: candle_nn::Conv2d, - neck_ln1: LayerNorm, + neck_ln1: crate::LayerNorm2d, neck_conv2: candle_nn::Conv2d, - neck_ln2: LayerNorm, + neck_ln2: crate::LayerNorm2d, pos_embed: Option, } @@ -222,13 +333,13 @@ impl ImageEncoderViT { Default::default(), vb.pp("neck.0"), )?; - let neck_ln1 = layer_norm(out_chans, 1e-6, vb.pp("neck.1"))?; + let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?; let cfg = candle_nn::Conv2dConfig { padding: 1, ..Default::default() }; let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?; - let neck_ln2 = layer_norm(out_chans, 1e-6, vb.pp("neck.3"))?; + let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?; let pos_embed = if use_abs_pos { let p = vb.get( (1, img_size / patch_size, img_size / patch_size, embed_dim), diff --git a/candle-examples/examples/segment-anything/model_mask_decoder.rs b/candle-examples/examples/segment-anything/model_mask_decoder.rs index cf3879cd..1ef46eeb 100644 --- a/candle-examples/examples/segment-anything/model_mask_decoder.rs +++ b/candle-examples/examples/segment-anything/model_mask_decoder.rs @@ -1,5 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{Linear, Module, VarBuilder}; use crate::model_transformer::TwoWayTransformer; @@ -60,7 +60,7 @@ pub struct MaskDecoder { mask_tokens: candle_nn::Embedding, iou_prediction_head: MlpMaskDecoder, output_upscaling_conv1: candle_nn::ConvTranspose2d, - output_upscaling_ln: LayerNorm, + output_upscaling_ln: crate::LayerNorm2d, output_upscaling_conv2: candle_nn::ConvTranspose2d, num_mask_tokens: usize, output_hypernetworks_mlps: Vec, @@ -99,7 +99,7 @@ impl MaskDecoder { vb.pp("output_upscaling.0"), )?; let output_upscaling_ln = - layer_norm(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; + crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?; let output_upscaling_conv2 = candle_nn::conv_transpose2d( transformer_dim / 4, transformer_dim / 8, @@ -140,7 +140,7 @@ impl MaskDecoder { }) } - fn forward( + pub fn forward( &self, image_embeddings: &Tensor, image_pe: &Tensor, @@ -195,7 +195,7 @@ impl MaskDecoder { // Run the transformer let (hs, src) = self.transformer.forward(&src, &pos_src, &tokens)?; let iou_token_out = hs.i((.., 0))?; - let mask_tokens_out = hs.i((.., 1, 1 + self.num_mask_tokens))?; + let mask_tokens_out = hs.i((.., 1..1 + self.num_mask_tokens))?; // Upscale mask embeddings and predict masks using the masks tokens. let src = src.transpose(1, 2)?.reshape((b, c, h, w))?; @@ -213,9 +213,8 @@ impl MaskDecoder { } let hyper_in = Tensor::stack(hyper_in_list.as_slice(), 1)?; let (b, c, h, w) = upscaled_embedding.dims4()?; - let masks = hyper_in - .matmul(&upscaled_embedding.reshape((b, c, h * w))?)? - .reshape((b, 0, h, w))?; + let masks = hyper_in.matmul(&upscaled_embedding.reshape((b, c, h * w))?)?; + let masks = masks.reshape((b, masks.elem_count() / b / h / w, h, w))?; // Generate mask quality predictions. let iou_pred = self.iou_prediction_head.forward(&iou_token_out)?; @@ -224,6 +223,9 @@ impl MaskDecoder { } // Equivalent to torch.repeat_interleave -fn repeat_interleave(_img: &Tensor, _repeats: usize, _dim: usize) -> Result { - todo!() +fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result { + let img = img.unsqueeze(dim + 1)?; + let mut dims = img.dims().to_vec(); + dims[dim + 1] = repeats; + img.broadcast_as(dims)?.flatten(dim, dim + 1) } diff --git a/candle-examples/examples/segment-anything/model_prompt_encoder.rs b/candle-examples/examples/segment-anything/model_prompt_encoder.rs index 7ac4c66d..c6ffffd2 100644 --- a/candle-examples/examples/segment-anything/model_prompt_encoder.rs +++ b/candle-examples/examples/segment-anything/model_prompt_encoder.rs @@ -1,5 +1,5 @@ use candle::{DType, IndexOp, Result, Tensor, D}; -use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; +use candle_nn::{Linear, Module, VarBuilder}; #[derive(Debug)] struct PostionEmbeddingRandom { @@ -17,7 +17,7 @@ impl PostionEmbeddingRandom { fn pe_encoding(&self, coords: &Tensor) -> Result { let coords = coords.affine(2., -1.)?; - let coords = coords.matmul(&self.positional_encoding_gaussian_matrix)?; + let coords = coords.broadcast_matmul(&self.positional_encoding_gaussian_matrix)?; let coords = (coords * (2. * std::f64::consts::PI))?; Tensor::cat(&[coords.sin()?, coords.cos()?], D::Minus1) } @@ -25,12 +25,14 @@ impl PostionEmbeddingRandom { fn forward(&self, h: usize, w: usize) -> Result { let device = self.positional_encoding_gaussian_matrix.device(); let grid = Tensor::ones((h, w), DType::F32, device)?; - // TODO: cumsum - let x_embed = (&grid - 0.5)?; - // TODO: cumsum - let y_embed = (&grid - 0.5)?; - let x_embed = (x_embed / w as f64)?; - let y_embed = (y_embed / h as f64)?; + let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?; + let x_embed = (x_embed / w as f64)? + .reshape((1, w))? + .broadcast_as((h, w))?; + let y_embed = (y_embed / h as f64)? + .reshape((h, 1))? + .broadcast_as((h, w))?; let coords = Tensor::stack(&[&x_embed, &y_embed], D::Minus1)?; self.pe_encoding(&coords)?.permute((2, 0, 1)) } @@ -55,13 +57,14 @@ pub struct PromptEncoder { point_embeddings: Vec, not_a_point_embed: candle_nn::Embedding, mask_downscaling_conv1: candle_nn::Conv2d, - mask_downscaling_ln1: LayerNorm, + mask_downscaling_ln1: crate::LayerNorm2d, mask_downscaling_conv2: candle_nn::Conv2d, - mask_downscaling_ln2: LayerNorm, + mask_downscaling_ln2: crate::LayerNorm2d, mask_downscaling_conv3: candle_nn::Conv2d, no_mask_embed: candle_nn::Embedding, image_embedding_size: (usize, usize), input_image_size: (usize, usize), + embed_dim: usize, } impl PromptEncoder { @@ -97,8 +100,9 @@ impl PromptEncoder { vb.pp("mask_downscaling.6"), )?; let mask_downscaling_ln1 = - layer_norm(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; - let mask_downscaling_ln2 = layer_norm(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; + crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?; + let mask_downscaling_ln2 = + crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?; let mut point_embeddings = Vec::with_capacity(num_points_embeddings); let vb_e = vb.pp("point_embeddings"); for i in 0..num_points_embeddings { @@ -117,9 +121,16 @@ impl PromptEncoder { no_mask_embed, image_embedding_size, input_image_size, + embed_dim, }) } + pub fn get_dense_pe(&self) -> Result { + self.pe_layer + .forward(self.image_embedding_size.0, self.image_embedding_size.1)? + .unsqueeze(0) + } + fn embed_masks(&self, masks: &Tensor) -> Result { masks .apply(&self.mask_downscaling_conv1)? @@ -133,7 +144,16 @@ impl PromptEncoder { fn embed_points(&self, points: &Tensor, labels: &Tensor, pad: bool) -> Result { let points = (points + 0.5)?; - let points = if pad { todo!() } else { points }; + let dev = points.device(); + let (points, labels) = if pad { + let padding_point = Tensor::zeros((points.dim(0)?, 1, 2), DType::F32, dev)?; + let padding_label = (Tensor::ones((labels.dim(0)?, 1), DType::F32, dev)? * (-1f64))?; + let points = Tensor::cat(&[&points, &padding_point], 1)?; + let labels = Tensor::cat(&[labels, &padding_label], 1)?; + (points, labels) + } else { + (points, labels.clone()) + }; let point_embedding = self .pe_layer .forward_with_coords(&points, self.input_image_size)?; @@ -154,7 +174,7 @@ impl PromptEncoder { Tensor::cat(&[&ce1, &ce2], 1) } - fn forward( + pub fn forward( &self, points: Option<(&Tensor, &Tensor)>, boxes: Option<&Tensor>, @@ -172,7 +192,9 @@ impl PromptEncoder { (Some(se_points), Some(se_boxes)) => Tensor::cat(&[se_points, se_boxes], 1)?, (Some(se_points), None) => se_points, (None, Some(se_boxes)) => se_boxes, - (None, None) => Tensor::zeros(1, DType::F32, &candle::Device::Cpu)?, + (None, None) => { + Tensor::zeros((1, 0, self.embed_dim), DType::F32, &candle::Device::Cpu)? + } }; let dense_embeddings = match masks { @@ -180,7 +202,7 @@ impl PromptEncoder { let emb = self.no_mask_embed.embeddings(); emb.reshape((1, emb.elem_count(), 1, 1))?.expand(( 1, - 0, + emb.elem_count(), self.image_embedding_size.0, self.image_embedding_size.1, ))? diff --git a/candle-examples/examples/segment-anything/model_sam.rs b/candle-examples/examples/segment-anything/model_sam.rs index 5a0d7e8f..1c8e9a59 100644 --- a/candle-examples/examples/segment-anything/model_sam.rs +++ b/candle-examples/examples/segment-anything/model_sam.rs @@ -5,6 +5,10 @@ use crate::model_image_encoder::ImageEncoderViT; use crate::model_mask_decoder::MaskDecoder; use crate::model_prompt_encoder::PromptEncoder; +const PROMPT_EMBED_DIM: usize = 256; +const IMAGE_SIZE: usize = 1024; +const VIT_PATCH_SIZE: usize = 16; + #[derive(Debug)] pub struct Sam { image_encoder: ImageEncoderViT, @@ -22,10 +26,6 @@ impl Sam { encoder_global_attn_indexes: &[usize], vb: VarBuilder, ) -> Result { - const PROMPT_EMBED_DIM: usize = 256; - const IMAGE_SIZE: usize = 1024; - const VIT_PATCH_SIZE: usize = 16; - let image_embedding_size = IMAGE_SIZE / VIT_PATCH_SIZE; let image_encoder = ImageEncoderViT::new( @@ -69,4 +69,33 @@ impl Sam { pixel_mean, }) } + + pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> { + let img = self.preprocess(img)?.unsqueeze(0)?; + let img_embeddings = self.image_encoder.forward(&img)?; + let image_pe = self.prompt_encoder.get_dense_pe()?; + let (sparse_prompt_embeddings, dense_prompt_embeddings) = + self.prompt_encoder.forward(None, None, None)?; + let (low_res_mask, iou_predictions) = self.mask_decoder.forward( + &img_embeddings, + &image_pe, + &sparse_prompt_embeddings, + &dense_prompt_embeddings, + multimask_output, + )?; + // TODO: post-processing. + Ok((low_res_mask, iou_predictions)) + } + + fn preprocess(&self, img: &Tensor) -> Result { + let (c, h, w) = img.dims3()?; + let img = img + .broadcast_sub(&self.pixel_mean)? + .broadcast_div(&self.pixel_std)?; + if h > IMAGE_SIZE || w > IMAGE_SIZE { + candle::bail!("image is too large ({w}, {h}), maximum size {IMAGE_SIZE}") + } + let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?; + img.pad_with_zeros(2, 0, IMAGE_SIZE - w) + } } diff --git a/candle-examples/examples/segment-anything/model_transformer.rs b/candle-examples/examples/segment-anything/model_transformer.rs index a845085d..044dce9b 100644 --- a/candle-examples/examples/segment-anything/model_transformer.rs +++ b/candle-examples/examples/segment-anything/model_transformer.rs @@ -36,7 +36,8 @@ impl Attention { fn separate_heads(&self, x: &Tensor) -> Result { let (b, n, c) = x.dims3()?; x.reshape((b, n, self.num_heads, c / self.num_heads))? - .transpose(1, 2) + .transpose(1, 2)? + .contiguous() } fn recombine_heads(&self, x: &Tensor) -> Result { @@ -102,8 +103,12 @@ impl TwoWayAttentionBlock { 2, vb.pp("cross_attn_image_to_token"), )?; - // TODO: use relu in this mlp - let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?; + let mlp = crate::MlpBlock::new( + embedding_dim, + mlp_dim, + candle_nn::Activation::Relu, + vb.pp("mlp"), + )?; Ok(Self { self_attn, norm1, @@ -126,7 +131,7 @@ impl TwoWayAttentionBlock { ) -> Result<(Tensor, Tensor)> { // Self attention block let queries = if self.skip_first_layer_pe { - self.self_attn.forward(queries, keys, queries)? + self.self_attn.forward(queries, queries, queries)? } else { let q = (queries + query_pe)?; let attn_out = self.self_attn.forward(&q, &q, queries)?; diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs index 08e2f628..d2e80a82 100644 --- a/candle-nn/src/layer_norm.rs +++ b/candle-nn/src/layer_norm.rs @@ -28,7 +28,7 @@ //! ``` //! //! [`Layer Normalization`]: https://arxiv.org/abs/1607.06450 -use candle::{DType, Result, Tensor}; +use candle::{DType, Result, Tensor, D}; #[derive(Debug, Clone, Copy, PartialEq)] pub struct LayerNormConfig { @@ -104,15 +104,15 @@ impl crate::Module for LayerNorm { DType::F16 | DType::BF16 => DType::F32, d => d, }; - let (_bsize, _seq_len, hidden_size) = x.dims3()?; + let hidden_size = x.dim(D::Minus1)?; let x = x.to_dtype(internal_dtype)?; let x = if self.remove_mean { - let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; x.broadcast_sub(&mean_x)? } else { x }; - let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?; let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; let x = x_normed.to_dtype(x_dtype)?.broadcast_mul(&self.weight)?; match &self.bias { diff --git a/candle-nn/src/linear.rs b/candle-nn/src/linear.rs index 7028f68c..de335964 100644 --- a/candle-nn/src/linear.rs +++ b/candle-nn/src/linear.rs @@ -41,8 +41,9 @@ impl Linear { impl super::Module for Linear { fn forward(&self, x: &Tensor) -> candle::Result { - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, + let w = match *x.dims() { + [b1, b2, _, _] => self.weight.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, _ => self.weight.t()?, }; let x = x.matmul(&w)?;