mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add negative prompts to segment-anything. (#1000)
This commit is contained in:
@ -27,10 +27,16 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
generate_masks: bool,
|
generate_masks: bool,
|
||||||
|
|
||||||
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image).
|
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
|
||||||
|
/// should be part of the generated mask.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
point: Vec<String>,
|
point: Vec<String>,
|
||||||
|
|
||||||
|
/// List of x,y coordinates, between 0 and 1 (0.5 is at the middle of the image). These points
|
||||||
|
/// should not be part of the generated mask and should be part of the background instead.
|
||||||
|
#[arg(long)]
|
||||||
|
neg_point: Vec<String>,
|
||||||
|
|
||||||
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
/// The detection threshold for the mask, 0 is the default value, negative values mean a larger
|
||||||
/// mask, positive makes the mask more selective.
|
/// mask, positive makes the mask more selective.
|
||||||
#[arg(long, default_value_t = 0.)]
|
#[arg(long, default_value_t = 0.)]
|
||||||
@ -107,16 +113,17 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let points = args
|
let iter_points = args.point.iter().map(|p| (p, true));
|
||||||
.point
|
let iter_neg_points = args.neg_point.iter().map(|p| (p, false));
|
||||||
.iter()
|
let points = iter_points
|
||||||
.map(|point| {
|
.chain(iter_neg_points)
|
||||||
|
.map(|(point, b)| {
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
let xy = point.split(',').collect::<Vec<_>>();
|
let xy = point.split(',').collect::<Vec<_>>();
|
||||||
if xy.len() != 2 {
|
if xy.len() != 2 {
|
||||||
anyhow::bail!("expected format for points is 0.4,0.2")
|
anyhow::bail!("expected format for points is 0.4,0.2")
|
||||||
}
|
}
|
||||||
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?))
|
Ok((f64::from_str(xy[0])?, f64::from_str(xy[1])?, b))
|
||||||
})
|
})
|
||||||
.collect::<anyhow::Result<Vec<_>>>()?;
|
.collect::<anyhow::Result<Vec<_>>>()?;
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
@ -158,15 +165,15 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (x, y) in points {
|
for (x, y, b) in points {
|
||||||
let x = (x * img.width() as f64) as i32;
|
let x = (x * img.width() as f64) as i32;
|
||||||
let y = (y * img.height() as f64) as i32;
|
let y = (y * img.height() as f64) as i32;
|
||||||
imageproc::drawing::draw_filled_circle_mut(
|
let color = if b {
|
||||||
&mut img,
|
image::Rgba([255, 0, 0, 200])
|
||||||
(x, y),
|
} else {
|
||||||
3,
|
image::Rgba([0, 255, 0, 200])
|
||||||
image::Rgba([255, 0, 0, 200]),
|
};
|
||||||
);
|
imageproc::drawing::draw_filled_circle_mut(&mut img, (x, y), 3, color);
|
||||||
}
|
}
|
||||||
img.save("sam_merged.jpg")?
|
img.save("sam_merged.jpg")?
|
||||||
}
|
}
|
||||||
|
@ -130,7 +130,7 @@ impl Sam {
|
|||||||
pub fn forward(
|
pub fn forward(
|
||||||
&self,
|
&self,
|
||||||
img: &Tensor,
|
img: &Tensor,
|
||||||
points: &[(f64, f64)],
|
points: &[(f64, f64, bool)],
|
||||||
multimask_output: bool,
|
multimask_output: bool,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
let (_c, original_h, original_w) = img.dims3()?;
|
let (_c, original_h, original_w) = img.dims3()?;
|
||||||
@ -150,12 +150,17 @@ impl Sam {
|
|||||||
Ok((mask, iou))
|
Ok((mask, iou))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Generate the mask and IOU predictions from some image embeddings and prompt.
|
||||||
|
///
|
||||||
|
/// The prompt is specified as a list of points `(x, y, b)`. `x` and `y` are the point
|
||||||
|
/// coordinates (between 0 and 1) and `b` is `true` for points that should be part of the mask
|
||||||
|
/// and `false` for points that should be part of the background and so excluded from the mask.
|
||||||
pub fn forward_for_embeddings(
|
pub fn forward_for_embeddings(
|
||||||
&self,
|
&self,
|
||||||
img_embeddings: &Tensor,
|
img_embeddings: &Tensor,
|
||||||
original_h: usize,
|
original_h: usize,
|
||||||
original_w: usize,
|
original_w: usize,
|
||||||
points: &[(f64, f64)],
|
points: &[(f64, f64, bool)],
|
||||||
multimask_output: bool,
|
multimask_output: bool,
|
||||||
) -> Result<(Tensor, Tensor)> {
|
) -> Result<(Tensor, Tensor)> {
|
||||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||||
@ -165,14 +170,18 @@ impl Sam {
|
|||||||
let n_points = points.len();
|
let n_points = points.len();
|
||||||
let xys = points
|
let xys = points
|
||||||
.iter()
|
.iter()
|
||||||
.flat_map(|(x, y)| {
|
.flat_map(|(x, y, _b)| {
|
||||||
let x = (*x as f32) * (original_w as f32);
|
let x = (*x as f32) * (original_w as f32);
|
||||||
let y = (*y as f32) * (original_h as f32);
|
let y = (*y as f32) * (original_h as f32);
|
||||||
[x, y]
|
[x, y]
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
let labels = points
|
||||||
|
.iter()
|
||||||
|
.map(|(_x, _y, b)| if *b { 1f32 } else { 0f32 })
|
||||||
|
.collect::<Vec<_>>();
|
||||||
let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
|
let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
|
||||||
let labels = Tensor::ones((1, n_points), DType::F32, img_embeddings.device())?;
|
let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
|
||||||
Some((points, labels))
|
Some((points, labels))
|
||||||
};
|
};
|
||||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||||
|
@ -94,7 +94,7 @@ impl Model {
|
|||||||
&embeddings.data,
|
&embeddings.data,
|
||||||
embeddings.height as usize,
|
embeddings.height as usize,
|
||||||
embeddings.width as usize,
|
embeddings.width as usize,
|
||||||
&[(x, y)],
|
&[(x, y, true)],
|
||||||
false,
|
false,
|
||||||
)?;
|
)?;
|
||||||
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
let iou = iou_predictions.flatten(0, 1)?.to_vec1::<f32>()?[0];
|
||||||
|
Reference in New Issue
Block a user