mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Add negative prompts to segment-anything. (#1000)
This commit is contained in:
@ -130,7 +130,7 @@ impl Sam {
|
||||
pub fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
points: &[(f64, f64)],
|
||||
points: &[(f64, f64, bool)],
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_c, original_h, original_w) = img.dims3()?;
|
||||
@ -150,12 +150,17 @@ impl Sam {
|
||||
Ok((mask, iou))
|
||||
}
|
||||
|
||||
/// Generate the mask and IOU predictions from some image embeddings and prompt.
|
||||
///
|
||||
/// The prompt is specified as a list of points `(x, y, b)`. `x` and `y` are the point
|
||||
/// coordinates (between 0 and 1) and `b` is `true` for points that should be part of the mask
|
||||
/// and `false` for points that should be part of the background and so excluded from the mask.
|
||||
pub fn forward_for_embeddings(
|
||||
&self,
|
||||
img_embeddings: &Tensor,
|
||||
original_h: usize,
|
||||
original_w: usize,
|
||||
points: &[(f64, f64)],
|
||||
points: &[(f64, f64, bool)],
|
||||
multimask_output: bool,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let image_pe = self.prompt_encoder.get_dense_pe()?;
|
||||
@ -165,14 +170,18 @@ impl Sam {
|
||||
let n_points = points.len();
|
||||
let xys = points
|
||||
.iter()
|
||||
.flat_map(|(x, y)| {
|
||||
.flat_map(|(x, y, _b)| {
|
||||
let x = (*x as f32) * (original_w as f32);
|
||||
let y = (*y as f32) * (original_h as f32);
|
||||
[x, y]
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let labels = points
|
||||
.iter()
|
||||
.map(|(_x, _y, b)| if *b { 1f32 } else { 0f32 })
|
||||
.collect::<Vec<_>>();
|
||||
let points = Tensor::from_vec(xys, (1, n_points, 2), img_embeddings.device())?;
|
||||
let labels = Tensor::ones((1, n_points), DType::F32, img_embeddings.device())?;
|
||||
let labels = Tensor::from_vec(labels, (1, n_points), img_embeddings.device())?;
|
||||
Some((points, labels))
|
||||
};
|
||||
let points = points.as_ref().map(|(x, y)| (x, y));
|
||||
|
Reference in New Issue
Block a user