mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00

add dynamic position encoding to Siglip (#2770)

* add dynamic position encoding
* remove debug messages
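
SigLIP learns one position embedding per image patch on a fixed grid, so a checkpoint only ships a table sized for its training resolution. This commit resizes that table on the fly, which lets the example run at other resolutions via a new --image-size flag. The grid arithmetic that forces the resize is simple; a plain-Rust illustration (the 224/16 numbers match the siglip-base checkpoints and are used here only as an example):

// Patch-grid arithmetic behind the change; no candle required.
fn patches_per_side(image_size: usize, patch_size: usize) -> usize {
    image_size / patch_size
}

fn main() {
    // 224px images with 16px patches -> a 14x14 grid, i.e. 196 learned positions.
    let base = patches_per_side(224, 16);
    assert_eq!(base * base, 196);
    // At 448px the model needs a 28x28 grid (784 positions), so the stored
    // 14x14 table must be interpolated before it can be added to the patches.
    let larger = patches_per_side(448, 16);
    assert_eq!(larger * larger, 784);
}
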
@@ -29,6 +29,9 @@ struct Args {
     #[arg(long, use_value_delimiter = true)]
     sequences: Option<Vec<String>>,
 
+    #[arg(short, long)]
+    image_size: Option<usize>,
+
 }
 
 fn load_image<T: AsRef<std::path::Path>>(path: T, image_size: usize) -> anyhow::Result<Tensor> {
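
The clap derive attribute #[arg(short, long)] exposes the override as both -i and --image-size, and the Option<usize> means "fall back to the model config when absent". A minimal sketch of the flag in isolation (a hypothetical standalone binary, not the example's full Args struct):

use clap::Parser;

#[derive(Parser)]
struct Args {
    /// Optional image-size override, e.g. `-i 448` or `--image-size 448`.
    #[arg(short, long)]
    image_size: Option<usize>,
}

fn main() {
    let args = Args::parse();
    // None here means main() should use config.vision_config.image_size.
    println!("image_size override: {:?}", args.image_size);
}
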
@@ -81,7 +84,11 @@ pub fn main() -> anyhow::Result<()> {
             "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(),
         ],
     };
-    let images = load_images(&vec_imgs, config.vision_config.image_size)?.to_device(&device)?;
+    let images = load_images(
+        &vec_imgs,
+        args.image_size.unwrap_or(config.vision_config.image_size),
+    )?
+    .to_device(&device)?;
     let vb =
         unsafe { VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)? };
     let model = siglip::Model::new(&config, vb)?;
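
The whole override mechanism at the call site is args.image_size.unwrap_or(config.vision_config.image_size): the CLI value wins when present, otherwise the checkpoint's native size is used. For context, a rough sketch of the batching helper this call relies on, assuming the load_image from the first hunk; the real load_images in the example may differ in detail:

use candle_core::Tensor;

fn load_images<T: AsRef<std::path::Path>>(
    paths: &[T],
    image_size: usize,
) -> anyhow::Result<Tensor> {
    let mut images = vec![];
    for path in paths {
        // load_image resizes each image to image_size x image_size and
        // returns a (3, H, W) tensor (see the signature in the first hunk).
        images.push(load_image(path, image_size)?);
    }
    // Stack into a single (batch, 3, image_size, image_size) tensor.
    Ok(Tensor::stack(&images, 0)?)
}
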
@@ -434,8 +434,9 @@ impl Encoder {
 #[derive(Debug, Clone)]
 struct VisionEmbeddings {
     patch_embedding: candle_nn::Conv2d,
-    position_embedding: candle_nn::Embedding,
-    position_ids: Tensor,
+    position_embedding: Tensor,
+    patch_size: usize,
+    base_num_patches_per_side: usize,
 }
 
 impl VisionEmbeddings {
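
Two things change in the struct: the position table is now held as a raw Tensor laid out as (1, hidden_size, side, side) instead of behind a candle_nn::Embedding, since resizing needs the weight matrix itself rather than an index lookup, and the patch geometry (patch_size, base_num_patches_per_side) is kept so forward() can tell whether a resize is needed at all. A sketch of that conversion in isolation (assuming a VarBuilder vb that resolves a "position_embedding" table):

use candle_core::{Result, Tensor};
use candle_nn::{embedding, VarBuilder};

fn grid_table(vb: VarBuilder, side: usize, hidden: usize) -> Result<Tensor> {
    // Load the (side*side, hidden) table through the usual Embedding helper...
    let layer = embedding(side * side, hidden, vb.pp("position_embedding"))?;
    // ...then take the raw weights and reshape them channels-first, the same
    // layout the conv2d patch embeddings come out in.
    layer
        .embeddings()
        .reshape((1, side, side, hidden))?
        .permute((0, 3, 1, 2))
}
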
@@ -451,25 +452,52 @@ impl VisionEmbeddings {
             conv2d_cfg,
             vb.pp("patch_embedding"),
         )?;
 
-        let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
-        let position_ids = Tensor::arange(0, num_patches as i64, vb.device())?;
-        let position_embedding =
-            candle_nn::embedding(num_patches, cfg.hidden_size(), vb.pp("position_embedding"))?;
+        let num_patches_per_side = cfg.image_size / cfg.patch_size;
+        let embedder = candle_nn::embedding(
+            num_patches_per_side.pow(2),
+            cfg.hidden_size(),
+            vb.pp("position_embedding"),
+        )?;
+        let position_embedding = embedder.embeddings();
+        let position_embedding = position_embedding
+            .reshape((
+                1,
+                num_patches_per_side,
+                num_patches_per_side,
+                cfg.hidden_size(),
+            ))?
+            .permute((0, 3, 1, 2))?;
         Ok(Self {
             patch_embedding,
             position_embedding,
-            position_ids,
+            patch_size: cfg.patch_size,
+            base_num_patches_per_side: num_patches_per_side,
         })
     }
 }
 
 impl Module for VisionEmbeddings {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        //embed tokens
         let (_batch, _channels, _height, _width) = xs.dims4()?;
         let embeddings = xs.apply(&self.patch_embedding)?;
-        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
-        let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-        embeddings.broadcast_add(&position_embedding)
+        // interpolate position embeddings for the current image size (if needed)
+        let num_patches_h = _height / self.patch_size;
+        let num_patches_w = _width / self.patch_size;
+        let resized_position_embedding = if num_patches_w == self.base_num_patches_per_side
+            && num_patches_h == self.base_num_patches_per_side
+        {
+            self.position_embedding.clone()
+        } else {
+            self.position_embedding
+                .interpolate2d(num_patches_h, num_patches_w)?
+        };
+        // Add position embeddings to tokens and flatten from 2D patches to 1D sequence
+        let embeddings = embeddings
+            .broadcast_add(&resized_position_embedding)?
+            .flatten_from(2)?
+            .transpose(1, 2)?;
+        Ok(embeddings)
     }
 }
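
At forward time the grid side is recomputed from the input's actual height and width. When it matches the grid the table was trained at, the table is used unchanged; otherwise interpolate2d resizes it before the broadcast add, and only then is the 2D grid flattened into a 1D patch sequence. A self-contained shape walkthrough with dummy tensors (hidden size 8 for brevity; also worth checking that interpolate2d's resize mode, nearest vs bilinear, matches whatever reference implementation you compare outputs against):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for the learned table: base 14x14 grid, stored channels-first
    // as (1, hidden, side, side), matching the new struct field.
    let pos = Tensor::randn(0f32, 1.0, (1, 8, 14, 14), &dev)?;
    // A 448x448 input with 16px patches yields a 28x28 grid.
    let resized = pos.interpolate2d(28, 28)?;
    assert_eq!(resized.dims(), &[1, 8, 28, 28]);
    // The conv2d patch embeddings share that layout, so broadcasting adds the
    // same positional vector to every image in the batch.
    let patches = Tensor::randn(0f32, 1.0, (2, 8, 28, 28), &dev)?;
    let summed = patches.broadcast_add(&resized)?;
    // Flatten (batch, hidden, 28, 28) -> (batch, 784, hidden) for the encoder.
    let seq = summed.flatten_from(2)?.transpose(1, 2)?;
    assert_eq!(seq.dims(), &[2, 784, 8]);
    Ok(())
}
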