Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Fix position encodings for Pixtral (#2678)

* init commit: add position id in meshgrid
* pass in subsampled positions
* clippy fix
* clippy fix
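
The first two bullets are the substance of the change: every patch gets an explicit 2D-aware position id, computed as row * max_image_width + col over the patch grid, and those ids are passed down to the rotary embedding so it can subsample its cos/sin tables. Below is a minimal standalone sketch of that id computation using candle's tensor API; the helper name position_ids, the grid size, and the max_image_width value are illustrative, not taken from the Pixtral config:

use candle::{Device, Result, Tensor};

// Sketch: flattened position ids for an h x w patch grid, mirroring the
// position_ids_in_meshgrid helper added in this commit.
fn position_ids(h: usize, w: usize, max_image_width: u32, dev: &Device) -> Result<Tensor> {
    let rows = Tensor::arange(0, h as u32, dev)?;
    let cols = Tensor::arange(0, w as u32, dev)?;
    // `false` requests matrix ("ij") indexing, as in the model code below:
    // mesh[0] holds the row index of each patch, mesh[1] the column index.
    let mesh = Tensor::meshgrid(&[rows, cols], false)?;
    (&mesh[0] * (max_image_width as f64) + &mesh[1])?.flatten_all()
}

fn main() -> Result<()> {
    // A 2 x 3 patch grid in a model whose widest grid is 4 patches:
    // row 0 maps to ids 0, 1, 2 and row 1 to ids 4, 5, 6 (id 3 is skipped).
    let ids = position_ids(2, 3, 4, &Device::Cpu)?;
    assert_eq!(ids.to_vec1::<u32>()?, vec![0, 1, 2, 4, 5, 6]);
    Ok(())
}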
@@ -1,8 +1,8 @@
-use candle::{DType, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};

 fn default_act() -> candle_nn::Activation {
-    candle_nn::Activation::Gelu
+    candle_nn::Activation::Silu
 }

 fn default_hidden_size() -> usize {
@@ -58,7 +58,7 @@ impl Config {
             num_attention_heads: 16,
             head_dim: None,
             // Default
-            hidden_act: candle_nn::Activation::Gelu,
+            hidden_act: candle_nn::Activation::Silu,
         }
     }

@@ -104,6 +104,7 @@ impl Attention {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let (b, patches, _) = xs.dims3()?;
@@ -116,7 +117,8 @@ impl Attention {
         let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
         let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;

-        let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
+        let (query_states, key_states) =
+            emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
         let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;

         let attn_weights = match attention_mask {
@@ -189,12 +191,16 @@ impl AttentionLayer {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let residual = xs;
-        let xs = self
-            .attention
-            .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
+        let xs = self.attention.forward(
+            &xs.apply(&self.attention_norm)?,
+            emb,
+            subsampled_positions,
+            attention_mask,
+        )?;
         let xs = (residual + xs)?;
         let residual = &xs;
         let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
@@ -222,11 +228,12 @@ impl Transformer {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let mut xs = xs.clone();
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, emb, attention_mask)?
+            xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
         }
         Ok(xs)
     }
@@ -270,10 +277,20 @@ impl RotaryEmbedding {
         Ok(Self { cos, sin })
     }

-    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+    fn apply_rotary_emb_qkv(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        subsampled_positions: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
         let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
-        let cos = &self.cos;
-        let sin = &self.sin;
+        let (cos, sin) = match subsampled_positions {
+            None => (&self.cos, &self.sin),
+            Some(pos) => (
+                &self.cos.index_select(pos, 0)?,
+                &self.sin.index_select(pos, 0)?,
+            ),
+        };
         let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
         let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
         Ok((q_embed, k_embed))
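
Note on the RotaryEmbedding hunk above: when subsampled positions are supplied, the cos/sin tables built in RotaryEmbedding::new (not shown in this diff) are gathered row-wise with index_select(pos, 0); with None the full tables are used as before. A tiny self-contained illustration of that gather, with made-up table shapes and values:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for a precomputed (max_positions, head_dim / 2) cos table.
    let cos = Tensor::arange(0f32, 14.0, &dev)?.reshape((7, 2))?;
    // Position ids of a 2 x 3 patch grid when the maximum grid width is 4.
    let pos = Tensor::new(&[0u32, 1, 2, 4, 5, 6], &dev)?;
    // Row-wise gather: only these rows would be fed to candle_nn::rotary_emb::rope.
    let cos_sub = cos.index_select(&pos, 0)?;
    assert_eq!(cos_sub.dims(), &[6, 2]);
    Ok(())
}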
@@ -286,6 +303,7 @@ pub struct Model {
     ln_pre: RmsNorm,
     transformer: Transformer,
     patch_positional_embedding: RotaryEmbedding,
+    max_image_width: u32,
 }

 impl Model {
@@ -305,20 +323,44 @@ impl Model {
         let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
         let patch_positional_embedding =
             RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
+        let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
         Ok(Self {
             patch_conv,
             ln_pre,
             transformer,
             patch_positional_embedding,
+            max_image_width,
         })
     }

+    pub fn position_ids_in_meshgrid(
+        &self,
+        num_patches_h: usize,
+        num_patches_w: usize,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let idx = Tensor::arange(0, num_patches_h as u32, device)?;
+        let idy = Tensor::arange(0, num_patches_w as u32, device)?;
+        let mesh = Tensor::meshgrid(&[idx, idy], false)?;
+        let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
+        Ok(ids)
+    }
 }

 impl Module for Model {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let patch_embeds = xs.apply(&self.patch_conv)?;
+        let subsampled_positions = Some(self.position_ids_in_meshgrid(
+            patch_embeds.dim(2)?,
+            patch_embeds.dim(3)?,
+            patch_embeds.device(),
+        )?);
         let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
-        self.transformer
-            .forward(&patch_embeds, &self.patch_positional_embedding, None)
+        self.transformer.forward(
+            &patch_embeds,
+            &self.patch_positional_embedding,
+            subsampled_positions.as_ref(),
+            None,
+        )
     }
 }
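
Putting the pieces together: Module::forward derives the patch grid from the conv stem output (laid out as batch x hidden x patch-rows x patch-columns, hence dim(2) and dim(3)), builds the flattened position ids on the same device, and threads them through Transformer::forward, AttentionLayer::forward and Attention::forward down to apply_rotary_emb_qkv. The parameter stays an Option<&Tensor>, so a caller can still pass None to fall back to the full cos/sin tables. The practical effect is that an image whose patch grid is smaller than the model's maximum width now selects the rotary rows that match its actual row/column layout, instead of effectively relying on a contiguous prefix of the precomputed table.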