mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Automatic mask generator + point base mask (#773)
* Add more to the automatic mask generator. * Add the target point. * Fix. * Remove the allow-unused. * Mask post-processing.
This commit is contained in:
@ -1,6 +1,5 @@
|
|||||||
//! SAM: Segment Anything Model
|
//! SAM: Segment Anything Model
|
||||||
//! https://github.com/facebookresearch/segment-anything
|
//! https://github.com/facebookresearch/segment-anything
|
||||||
#![allow(unused)]
|
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
#[cfg(feature = "mkl")]
|
||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
@ -14,7 +13,7 @@ pub mod model_prompt_encoder;
|
|||||||
pub mod model_sam;
|
pub mod model_sam;
|
||||||
pub mod model_transformer;
|
pub mod model_transformer;
|
||||||
|
|
||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
use candle::{DType, Result, Tensor};
|
||||||
use candle_nn::{Linear, Module, VarBuilder};
|
use candle_nn::{Linear, Module, VarBuilder};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
@ -101,6 +100,15 @@ struct Args {
|
|||||||
/// Run on CPU rather than on GPU.
|
/// Run on CPU rather than on GPU.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
generate_masks: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
point_x: Option<f64>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
point_y: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
@ -108,7 +116,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
let image = if args.image.ends_with(".safetensors") {
|
let (image, initial_h, initial_w) = if args.image.ends_with(".safetensors") {
|
||||||
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
let mut tensors = candle::safetensors::load(&args.image, &device)?;
|
||||||
let image = match tensors.remove("image") {
|
let image = match tensors.remove("image") {
|
||||||
Some(image) => image,
|
Some(image) => image,
|
||||||
@ -119,13 +127,16 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
tensors.into_values().next().unwrap()
|
tensors.into_values().next().unwrap()
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if image.rank() == 4 {
|
let image = if image.rank() == 4 {
|
||||||
image.get(0)?
|
image.get(0)?
|
||||||
} else {
|
} else {
|
||||||
image
|
image
|
||||||
}
|
};
|
||||||
|
let (_c, h, w) = image.dims3()?;
|
||||||
|
(image, h, w)
|
||||||
} else {
|
} else {
|
||||||
candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?.to_device(&device)?
|
let (image, h, w) = candle_examples::load_image(args.image, Some(model_sam::IMAGE_SIZE))?;
|
||||||
|
(image.to_device(&device)?, h, w)
|
||||||
};
|
};
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
|
|
||||||
@ -142,19 +153,30 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
|
||||||
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
|
let sam = model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)?; // sam_vit_b
|
||||||
|
|
||||||
let (mask, iou_predictions) = sam.forward(&image, false)?;
|
if args.generate_masks {
|
||||||
|
// Default options similar to the Python version.
|
||||||
|
sam.generate_masks(
|
||||||
|
&image,
|
||||||
|
/* points_per_side */ 32,
|
||||||
|
/* crop_n_layer */ 0,
|
||||||
|
/* crop_overlap_ratio */ 512. / 1500.,
|
||||||
|
/* crop_n_points_downscale_factor */ 1,
|
||||||
|
)?
|
||||||
|
} else {
|
||||||
|
let point = args.point_x.zip(args.point_y);
|
||||||
|
let (mask, iou_predictions) = sam.forward(&image, point, false)?;
|
||||||
println!("mask:\n{mask}");
|
println!("mask:\n{mask}");
|
||||||
println!("iou_predictions: {iou_predictions:?}");
|
println!("iou_predictions: {iou_predictions:?}");
|
||||||
|
|
||||||
// Save the mask as an image.
|
// Save the mask as an image.
|
||||||
let mask = mask.ge(&mask.zeros_like()?)?;
|
let mask = (mask.ge(&mask.zeros_like()?)? * 255.)?;
|
||||||
let mask = (mask * 255.)?.squeeze(0)?;
|
|
||||||
let (_one, h, w) = mask.dims3()?;
|
let (_one, h, w) = mask.dims3()?;
|
||||||
let mask = mask.expand((3, h, w))?;
|
let mask = mask.expand((3, h, w))?;
|
||||||
candle_examples::save_image(&mask, "sam_mask.png")?;
|
candle_examples::save_image_resize(&mask, "sam_mask.png", initial_h, initial_w)?;
|
||||||
|
|
||||||
let image = sam.preprocess(&image)?;
|
let image = sam.preprocess(&image)?;
|
||||||
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
|
let image = sam.unpreprocess(&image)?.to_dtype(DType::U8)?;
|
||||||
candle_examples::save_image(&image, "sam_input_scaled.png")?;
|
candle_examples::save_image(&image, "sam_input_scaled.png")?;
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
use candle::{DType, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -37,7 +37,6 @@ struct Attention {
|
|||||||
proj: Linear,
|
proj: Linear,
|
||||||
num_heads: usize,
|
num_heads: usize,
|
||||||
scale: f64,
|
scale: f64,
|
||||||
use_rel_pos: bool,
|
|
||||||
rel_pos_hw: Option<(Tensor, Tensor)>,
|
rel_pos_hw: Option<(Tensor, Tensor)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -66,7 +65,6 @@ impl Attention {
|
|||||||
proj,
|
proj,
|
||||||
num_heads,
|
num_heads,
|
||||||
scale,
|
scale,
|
||||||
use_rel_pos,
|
|
||||||
rel_pos_hw,
|
rel_pos_hw,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -272,7 +270,6 @@ impl Module for Block {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ImageEncoderViT {
|
pub struct ImageEncoderViT {
|
||||||
img_size: usize,
|
|
||||||
patch_embed: PatchEmbed,
|
patch_embed: PatchEmbed,
|
||||||
blocks: Vec<Block>,
|
blocks: Vec<Block>,
|
||||||
neck_conv1: candle_nn::Conv2d,
|
neck_conv1: candle_nn::Conv2d,
|
||||||
@ -350,7 +347,6 @@ impl ImageEncoderViT {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
img_size,
|
|
||||||
patch_embed,
|
patch_embed,
|
||||||
blocks,
|
blocks,
|
||||||
neck_conv1,
|
neck_conv1,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
use candle::{IndexOp, Result, Tensor};
|
||||||
use candle_nn::{Linear, Module, VarBuilder};
|
use candle_nn::{Linear, Module, VarBuilder};
|
||||||
|
|
||||||
use crate::model_transformer::TwoWayTransformer;
|
use crate::model_transformer::TwoWayTransformer;
|
||||||
@ -188,7 +188,7 @@ impl MaskDecoder {
|
|||||||
|
|
||||||
// Expand per-image data in batch direction to be per mask
|
// Expand per-image data in batch direction to be per mask
|
||||||
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
let src = repeat_interleave(image_embeddings, tokens.dim(0)?, 0)?;
|
||||||
let src = (src + dense_prompt_embeddings)?;
|
let src = src.broadcast_add(dense_prompt_embeddings)?;
|
||||||
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
let pos_src = repeat_interleave(image_pe, tokens.dim(0)?, 0)?;
|
||||||
let (b, c, h, w) = src.dims4()?;
|
let (b, c, h, w) = src.dims4()?;
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
use candle::{DType, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::{Linear, Module, VarBuilder};
|
use candle_nn::VarBuilder;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct PostionEmbeddingRandom {
|
struct PostionEmbeddingRandom {
|
||||||
@ -24,7 +24,6 @@ impl PostionEmbeddingRandom {
|
|||||||
|
|
||||||
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
fn forward(&self, h: usize, w: usize) -> Result<Tensor> {
|
||||||
let device = self.positional_encoding_gaussian_matrix.device();
|
let device = self.positional_encoding_gaussian_matrix.device();
|
||||||
let grid = Tensor::ones((h, w), DType::F32, device)?;
|
|
||||||
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
let x_embed = (Tensor::arange(0u32, w as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||||
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
let y_embed = (Tensor::arange(0u32, h as u32, device)?.to_dtype(DType::F32)? + 0.5)?;
|
||||||
let x_embed = (x_embed / w as f64)?
|
let x_embed = (x_embed / w as f64)?
|
||||||
@ -157,8 +156,9 @@ impl PromptEncoder {
|
|||||||
let point_embedding = self
|
let point_embedding = self
|
||||||
.pe_layer
|
.pe_layer
|
||||||
.forward_with_coords(&points, self.input_image_size)?;
|
.forward_with_coords(&points, self.input_image_size)?;
|
||||||
|
let labels = labels.unsqueeze(2)?.broadcast_as(point_embedding.shape())?;
|
||||||
let zeros = point_embedding.zeros_like()?;
|
let zeros = point_embedding.zeros_like()?;
|
||||||
let point_embeddings = labels.lt(&labels.zeros_like()?)?.where_cond(
|
let point_embedding = labels.lt(&labels.zeros_like()?)?.where_cond(
|
||||||
&self
|
&self
|
||||||
.not_a_point_embed
|
.not_a_point_embed
|
||||||
.embeddings()
|
.embeddings()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
use candle::{DType, IndexOp, Result, Tensor};
|
||||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
use candle_nn::{Module, VarBuilder};
|
||||||
|
|
||||||
use crate::model_image_encoder::ImageEncoderViT;
|
use crate::model_image_encoder::ImageEncoderViT;
|
||||||
use crate::model_mask_decoder::MaskDecoder;
|
use crate::model_mask_decoder::MaskDecoder;
|
||||||
@ -70,12 +70,30 @@ impl Sam {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, img: &Tensor, multimask_output: bool) -> Result<(Tensor, Tensor)> {
|
pub fn forward(
|
||||||
|
&self,
|
||||||
|
img: &Tensor,
|
||||||
|
point: Option<(f64, f64)>,
|
||||||
|
multimask_output: bool,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_c, original_h, original_w) = img.dims3()?;
|
||||||
let img = self.preprocess(img)?.unsqueeze(0)?;
|
let img = self.preprocess(img)?.unsqueeze(0)?;
|
||||||
let img_embeddings = self.image_encoder.forward(&img)?;
|
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||||
|
let points = match point {
|
||||||
|
None => None,
|
||||||
|
Some((x, y)) => {
|
||||||
|
let points = Tensor::new(
|
||||||
|
&[[[x as f32 * original_w as f32, y as f32 * original_h as f32]]],
|
||||||
|
img.device(),
|
||||||
|
)?;
|
||||||
|
let labels = Tensor::ones((1, 1), DType::F32, img.device())?;
|
||||||
|
Some((points, labels))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||||
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||||
self.prompt_encoder.forward(None, None, None)?;
|
self.prompt_encoder.forward(points, None, None)?;
|
||||||
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
let (low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||||
&img_embeddings,
|
&img_embeddings,
|
||||||
&image_pe,
|
&image_pe,
|
||||||
@ -83,8 +101,11 @@ impl Sam {
|
|||||||
&dense_prompt_embeddings,
|
&dense_prompt_embeddings,
|
||||||
multimask_output,
|
multimask_output,
|
||||||
)?;
|
)?;
|
||||||
// TODO: post-processing.
|
let mask = low_res_mask
|
||||||
Ok((low_res_mask, iou_predictions))
|
.upsample_nearest2d(IMAGE_SIZE, IMAGE_SIZE)?
|
||||||
|
.get(0)?
|
||||||
|
.i((.., ..original_h, ..original_w))?;
|
||||||
|
Ok((mask, iou_predictions))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
pub fn unpreprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||||
@ -96,7 +117,7 @@ impl Sam {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
pub fn preprocess(&self, img: &Tensor) -> Result<Tensor> {
|
||||||
let (c, h, w) = img.dims3()?;
|
let (_c, h, w) = img.dims3()?;
|
||||||
let img = img
|
let img = img
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.broadcast_sub(&self.pixel_mean)?
|
.broadcast_sub(&self.pixel_mean)?
|
||||||
@ -107,4 +128,150 @@ impl Sam {
|
|||||||
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
|
let img = img.pad_with_zeros(1, 0, IMAGE_SIZE - h)?;
|
||||||
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
img.pad_with_zeros(2, 0, IMAGE_SIZE - w)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn process_crop(&self, img: &Tensor, cb: CropBox, point_grids: &[(f64, f64)]) -> Result<()> {
|
||||||
|
// Crop the image and calculate embeddings.
|
||||||
|
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
|
||||||
|
let img = self.preprocess(&img)?.unsqueeze(0)?;
|
||||||
|
let img_embeddings = self.image_encoder.forward(&img)?;
|
||||||
|
|
||||||
|
let crop_w = cb.x1 - cb.x0;
|
||||||
|
let crop_h = cb.y1 - cb.y0;
|
||||||
|
|
||||||
|
// Generate masks for this crop.
|
||||||
|
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||||
|
let points = point_grids
|
||||||
|
.iter()
|
||||||
|
.map(|&(x, y)| vec![x as f32 * crop_w as f32, y as f32 * crop_h as f32])
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
for points in points.chunks(64) {
|
||||||
|
let points_len = points.len();
|
||||||
|
let in_points = Tensor::new(points.to_vec(), img.device())?.unsqueeze(1)?;
|
||||||
|
let in_labels = Tensor::ones((points_len, 1), DType::F32, img.device())?;
|
||||||
|
let (sparse_prompt_embeddings, dense_prompt_embeddings) =
|
||||||
|
self.prompt_encoder
|
||||||
|
.forward(Some((&in_points, &in_labels)), None, None)?;
|
||||||
|
let (_low_res_mask, iou_predictions) = self.mask_decoder.forward(
|
||||||
|
&img_embeddings,
|
||||||
|
&image_pe,
|
||||||
|
&sparse_prompt_embeddings,
|
||||||
|
&dense_prompt_embeddings,
|
||||||
|
/* multimask_output */ true,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
println!("{cb:?} {iou_predictions}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove duplicates within this crop.
|
||||||
|
|
||||||
|
// Return to the original image frame.
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_masks(
|
||||||
|
&self,
|
||||||
|
img: &Tensor,
|
||||||
|
points_per_side: usize,
|
||||||
|
crop_n_layer: usize,
|
||||||
|
crop_overlap_ratio: f64,
|
||||||
|
crop_n_points_downscale_factor: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
let (_c, h, w) = img.dims3()?;
|
||||||
|
let point_grids = build_all_layer_point_grids(
|
||||||
|
points_per_side,
|
||||||
|
crop_n_layer,
|
||||||
|
crop_n_points_downscale_factor,
|
||||||
|
);
|
||||||
|
let crop_boxes = generate_crop_boxes((h, w), crop_n_layer, crop_overlap_ratio);
|
||||||
|
for crop_box in crop_boxes.into_iter() {
|
||||||
|
let layer_idx = crop_box.layer_idx;
|
||||||
|
self.process_crop(img, crop_box, &point_grids[layer_idx])?
|
||||||
|
}
|
||||||
|
// TODO: remove duplicates
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A rectangular crop of the input image together with the crop layer it belongs to.
#[derive(Debug)]
struct CropBox {
    /// Start of the crop along the width axis.
    x0: usize,
    /// Start of the crop along the height axis.
    y0: usize,
    /// End of the crop along the width axis (used as an exclusive range bound).
    x1: usize,
    /// End of the crop along the height axis (used as an exclusive range bound).
    y1: usize,
    /// Index of the crop layer, 0 being the full image.
    layer_idx: usize,
}

impl CropBox {
    /// Creates a crop box from its corner coordinates and layer index.
    fn new(x0: usize, y0: usize, x1: usize, y1: usize, layer_idx: usize) -> Self {
        CropBox {
            x0,
            y0,
            x1,
            y1,
            layer_idx,
        }
    }
}
|
||||||
|
|
||||||
|
fn generate_crop_boxes(
|
||||||
|
(im_h, im_w): (usize, usize),
|
||||||
|
n_layers: usize,
|
||||||
|
overlap_ratio: f64,
|
||||||
|
) -> Vec<CropBox> {
|
||||||
|
fn crop_len(orig_len: usize, n_crops: usize, overlap: usize) -> usize {
|
||||||
|
f64::ceil((overlap * (n_crops - 1) + orig_len) as f64 / n_crops as f64) as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
let short_side = usize::min(im_h, im_w);
|
||||||
|
|
||||||
|
let mut crop_boxes = Vec::new();
|
||||||
|
|
||||||
|
// Original image.
|
||||||
|
crop_boxes.push(CropBox::new(0, 0, im_w, im_h, 0));
|
||||||
|
|
||||||
|
for layer_idx in 1..=n_layers {
|
||||||
|
let n_crops_per_side = 1 << layer_idx;
|
||||||
|
let overlap = (overlap_ratio * short_side as f64 * 2. / n_crops_per_side as f64) as usize;
|
||||||
|
let crop_w = crop_len(im_w, n_crops_per_side, overlap);
|
||||||
|
let crop_h = crop_len(im_w, n_crops_per_side, overlap);
|
||||||
|
|
||||||
|
for i_x in 0..n_crops_per_side {
|
||||||
|
let x0 = (crop_w - overlap) * i_x;
|
||||||
|
for i_y in 0..n_crops_per_side {
|
||||||
|
let y0 = (crop_h - overlap) * i_y;
|
||||||
|
let x1 = usize::min(im_w, x0 + crop_w);
|
||||||
|
let y1 = usize::min(im_h, y0 + crop_h);
|
||||||
|
crop_boxes.push(CropBox::new(x0, y0, x1, y1, layer_idx));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crop_boxes
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generates a 2D grid of `n_per_side * n_per_side` points evenly spaced in
/// [0,1]x[0,1], each coordinate being offset by half a cell from the cell boundary.
///
/// Points are emitted with the x index as the outer iteration and the y index as the
/// inner one. An `n_per_side` of 0 yields an empty grid.
fn build_point_grid(n_per_side: usize) -> Vec<(f64, f64)> {
    let half_cell = 1f64 / (2 * n_per_side) as f64;
    // Position of the `idx`-th point along one axis.
    let coord = |idx: usize| half_cell + idx as f64 / n_per_side as f64;
    (0..n_per_side)
        .flat_map(|col| (0..n_per_side).map(move |row| (col, row)))
        .map(|(col, row)| (coord(col), coord(row)))
        .collect()
}
|
||||||
|
|
||||||
|
fn build_all_layer_point_grids(
|
||||||
|
n_per_side: usize,
|
||||||
|
n_layers: usize,
|
||||||
|
scale_per_layer: usize,
|
||||||
|
) -> Vec<Vec<(f64, f64)>> {
|
||||||
|
let mut points_by_layer = Vec::with_capacity(n_layers + 1);
|
||||||
|
for i in 0..=n_layers {
|
||||||
|
let n_points = n_per_side / scale_per_layer.pow(i as u32);
|
||||||
|
points_by_layer.push(build_point_grid(n_points))
|
||||||
|
}
|
||||||
|
points_by_layer
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use candle::{DType, IndexOp, Result, Tensor, D};
|
use candle::{Result, Tensor};
|
||||||
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -7,7 +7,6 @@ struct Attention {
|
|||||||
k_proj: Linear,
|
k_proj: Linear,
|
||||||
v_proj: Linear,
|
v_proj: Linear,
|
||||||
out_proj: Linear,
|
out_proj: Linear,
|
||||||
internal_dim: usize,
|
|
||||||
num_heads: usize,
|
num_heads: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,7 +27,6 @@ impl Attention {
|
|||||||
k_proj,
|
k_proj,
|
||||||
v_proj,
|
v_proj,
|
||||||
out_proj,
|
out_proj,
|
||||||
internal_dim,
|
|
||||||
num_heads,
|
num_heads,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -85,7 +83,6 @@ impl TwoWayAttentionBlock {
|
|||||||
skip_first_layer_pe: bool,
|
skip_first_layer_pe: bool,
|
||||||
vb: VarBuilder,
|
vb: VarBuilder,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let self_attn = Attention::new(embedding_dim, num_heads, 1, vb.pp("self_attn"))?;
|
|
||||||
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
|
let norm1 = layer_norm(embedding_dim, 1e-5, vb.pp("norm1"))?;
|
||||||
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
|
let norm2 = layer_norm(embedding_dim, 1e-5, vb.pp("norm2"))?;
|
||||||
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
|
let norm3 = layer_norm(embedding_dim, 1e-5, vb.pp("norm3"))?;
|
||||||
@ -204,7 +201,6 @@ impl TwoWayTransformer {
|
|||||||
image_pe: &Tensor,
|
image_pe: &Tensor,
|
||||||
point_embedding: &Tensor,
|
point_embedding: &Tensor,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
let (bs, c, h, w) = image_embedding.dims4()?;
|
|
||||||
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
|
let image_embedding = image_embedding.flatten_from(2)?.permute((0, 2, 1))?;
|
||||||
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
|
let image_pe = image_pe.flatten_from(2)?.permute((0, 2, 1))?;
|
||||||
|
|
||||||
|
@ -19,10 +19,11 @@ pub fn device(cpu: bool) -> Result<Device> {
|
|||||||
pub fn load_image<P: AsRef<std::path::Path>>(
|
pub fn load_image<P: AsRef<std::path::Path>>(
|
||||||
p: P,
|
p: P,
|
||||||
resize_longest: Option<usize>,
|
resize_longest: Option<usize>,
|
||||||
) -> Result<Tensor> {
|
) -> Result<(Tensor, usize, usize)> {
|
||||||
let img = image::io::Reader::open(p)?
|
let img = image::io::Reader::open(p)?
|
||||||
.decode()
|
.decode()
|
||||||
.map_err(candle::Error::wrap)?;
|
.map_err(candle::Error::wrap)?;
|
||||||
|
let (initial_h, initial_w) = (img.height() as usize, img.width() as usize);
|
||||||
let img = match resize_longest {
|
let img = match resize_longest {
|
||||||
None => img,
|
None => img,
|
||||||
Some(resize_longest) => {
|
Some(resize_longest) => {
|
||||||
@ -41,7 +42,8 @@ pub fn load_image<P: AsRef<std::path::Path>>(
|
|||||||
let (height, width) = (img.height() as usize, img.width() as usize);
|
let (height, width) = (img.height() as usize, img.width() as usize);
|
||||||
let img = img.to_rgb8();
|
let img = img.to_rgb8();
|
||||||
let data = img.into_raw();
|
let data = img.into_raw();
|
||||||
Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))
|
let data = Tensor::from_vec(data, (height, width, 3), &Device::Cpu)?.permute((2, 0, 1))?;
|
||||||
|
Ok((data, initial_h, initial_w))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
pub fn load_image_and_resize<P: AsRef<std::path::Path>>(
|
||||||
@ -80,3 +82,27 @@ pub fn save_image<P: AsRef<std::path::Path>>(img: &Tensor, p: P) -> Result<()> {
|
|||||||
image.save(p).map_err(candle::Error::wrap)?;
|
image.save(p).map_err(candle::Error::wrap)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn save_image_resize<P: AsRef<std::path::Path>>(
|
||||||
|
img: &Tensor,
|
||||||
|
p: P,
|
||||||
|
h: usize,
|
||||||
|
w: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
let p = p.as_ref();
|
||||||
|
let (channel, height, width) = img.dims3()?;
|
||||||
|
if channel != 3 {
|
||||||
|
candle::bail!("save_image expects an input of shape (3, height, width)")
|
||||||
|
}
|
||||||
|
let img = img.permute((1, 2, 0))?.flatten_all()?;
|
||||||
|
let pixels = img.to_vec1::<u8>()?;
|
||||||
|
let image: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||||
|
match image::ImageBuffer::from_raw(width as u32, height as u32, pixels) {
|
||||||
|
Some(image) => image,
|
||||||
|
None => candle::bail!("error saving image {p:?}"),
|
||||||
|
};
|
||||||
|
let image = image::DynamicImage::from(image);
|
||||||
|
let image = image.resize_to_fill(w as u32, h as u32, image::imageops::FilterType::CatmullRom);
|
||||||
|
image.save(p).map_err(candle::Error::wrap)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user