Mirror of https://github.com/huggingface/candle.git
Add Pixtral. (#2521)
* Add Pixtral.
* More pixtral vision encoder.
* Sketch a pixtral example.
* Sketch a pixtral example.
* Better image loading.
* Support loading images embedded in safetensor files.
* Clippy fixes.
* Add the llava multimodal adapter.
* Add more of the llava bits.
* Add the pixtral config.
* More pixtral inference.
* Add the text generation bits.
* Get the example to work.
* Bugfix.
* Run some bits of the model in f32.
* Blessed version :)
* Better rope frequency computations.
* README update.
@@ -279,7 +279,7 @@ impl LLaVA {
                    (),
                ))?
            } else {
-               todo!("not implemented in original python LLaVA yet")
+               bail!("not implemented in original python LLaVA yet")
            };
            let new_image_feature = if mm_patch_merge_type.contains("unpad") {
                let new_image_feature = new_image_feature
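todo! aborts the process with a panic, while candle's bail! macro returns an Err through the normal Result path, so the unsupported merge type now surfaces as a recoverable error. A minimal illustration with a hypothetical function (only the macro usage mirrors the change above):

```rust
use candle::{bail, Result};

// Hypothetical stand-in for LLaVA's image-feature branch: the unsupported case now
// returns an error the caller can handle instead of panicking.
fn merge_patches(mm_patch_merge_type: &str) -> Result<&'static str> {
    if mm_patch_merge_type.contains("unpad") {
        Ok("unpad path")
    } else {
        bail!("not implemented in original python LLaVA yet")
    }
}
```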
@@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;

fn default_num_attention_heads() -> usize {
    32
}

fn default_use_flash_attn() -> bool {
    false
}

fn default_hidden_act() -> candle_nn::Activation {
    candle_nn::Activation::Silu
}

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    pub head_dim: Option<usize>,
    pub num_key_value_heads: usize,
    #[serde(default = "default_hidden_act")]
    pub hidden_act: Activation,
    pub max_position_embeddings: usize,
    pub rms_norm_eps: f64,
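These serde defaults and the optional head_dim appear to let Pixtral's text_config, which omits some of these fields, still deserialize into the mistral Config that the adapter reuses as its language model. A self-contained sketch of the same serde pattern (the mini struct and JSON are illustrative, not the real config):

```rust
use serde::Deserialize;

fn default_num_attention_heads() -> usize {
    32
}

#[derive(Debug, Deserialize)]
struct MiniConfig {
    hidden_size: usize,
    // Falls back to 32 when the field is missing from the JSON.
    #[serde(default = "default_num_attention_heads")]
    num_attention_heads: usize,
    // Option fields default to None when missing; the model derives
    // hidden_size / num_attention_heads in that case.
    head_dim: Option<usize>,
}

fn main() -> serde_json::Result<()> {
    let cfg: MiniConfig = serde_json::from_str(r#"{ "hidden_size": 5120 }"#)?;
    println!("{cfg:?}");
    Ok(())
}
```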
@@ -107,14 +117,14 @@ impl RotaryEmbedding {
            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let inv_freq_len = inv_freq.len();
-       let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
+       let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
-           .to_dtype(dtype)?
+           .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        Ok(Self {
-           sin: freqs.sin()?,
-           cos: freqs.cos()?,
+           sin: freqs.sin()?.to_dtype(dtype)?,
+           cos: freqs.cos()?.to_dtype(dtype)?,
        })
    }
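This hunk (part of the text model's rotary embedding) now builds the position/frequency products in f32 and only casts the final sin/cos tables to the model dtype, matching the "better rope frequency computations" item in the commit message. A rough, self-contained way to see why with candle (sizes and theta are arbitrary; this assumes bf16 matmul is available on the CPU backend):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (dim, max_seq_len) = (128usize, 4096usize);
    let rope_theta = 10_000f32;
    let inv_freq: Vec<f32> = (0..dim)
        .step_by(2)
        .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
        .collect();
    let inv_freq_len = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), &dev)?;
    let t = Tensor::arange(0u32, max_seq_len as u32, &dev)?
        .to_dtype(DType::F32)?
        .reshape((max_seq_len, 1))?;
    // New path: angles computed in f32.
    let freqs_f32 = t.matmul(&inv_freq)?;
    // Old path (roughly): the same product computed in the model dtype, e.g. bf16.
    let freqs_bf16 = t
        .to_dtype(DType::BF16)?
        .matmul(&inv_freq.to_dtype(DType::BF16)?)?
        .to_dtype(DType::F32)?;
    let max_err = (freqs_f32 - freqs_bf16)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    println!("max angle error when computed in bf16: {max_err}");
    Ok(())
}
```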
@@ -404,6 +414,10 @@ impl Model {
            .to_dtype(self.dtype)
    }

+   pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+       &self.embed_tokens
+   }
+
    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (_b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
@@ -421,6 +435,22 @@ impl Model {
            .apply(&self.lm_head)
    }

+   pub fn forward_embeds(
+       &mut self,
+       xs: &Tensor,
+       attn_mask: Option<&Tensor>,
+       seqlen_offset: usize,
+   ) -> Result<Tensor> {
+       let (_b_size, seq_len, _) = xs.dims3()?;
+       let mut xs = xs.clone();
+       for layer in self.layers.iter_mut() {
+           xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+       }
+       xs.narrow(1, seq_len - 1, 1)?
+           .apply(&self.norm)?
+           .apply(&self.lm_head)
+   }
+
    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
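Together with the embed_tokens() getter above, forward_embeds lets a caller run the decoder on precomputed embeddings instead of token ids, which is what a multimodal adapter needs in order to splice image features into the prompt. A hypothetical helper sketching that flow (names, shapes, and the mask construction are illustrative; image_embeds is assumed to already be projected into the model dtype):

```rust
use candle::{IndexOp, Module, Result, Tensor};
use candle_transformers::models::mistral::Model as Mistral;

// Hypothetical prefill helper: embed the prompt ids, splice image embeddings over the
// placeholder positions, build a causal mask, and run the decoder on raw embeddings.
fn prefill_with_image(
    model: &mut Mistral,
    prompt_ids: &Tensor,   // (1, seq_len) token ids, placeholders already expanded
    image_embeds: &Tensor, // (1, n_image_tokens, hidden), same dtype as the model
    image_start: usize,    // index of the first image placeholder in the prompt
) -> Result<Tensor> {
    let text_embeds = model.embed_tokens().forward(prompt_ids)?; // (1, seq_len, hidden)
    let n_img = image_embeds.dim(1)?;
    let before = text_embeds.i((.., ..image_start, ..))?;
    let after = text_embeds.i((.., image_start + n_img.., ..))?;
    let embeds = Tensor::cat(&[&before, image_embeds, &after], 1)?;

    // forward_embeds passes attn_mask straight through to the layers, so a multi-token
    // prefill needs an explicit causal mask of shape (1, 1, seq_len, seq_len).
    let seq_len = embeds.dim(1)?;
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
        .collect();
    let mask = Tensor::from_vec(mask, (1, 1, seq_len, seq_len), prompt_ids.device())?
        .to_dtype(embeds.dtype())?;
    model.forward_embeds(&embeds, Some(&mask), 0)
}
```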
@@ -51,6 +51,7 @@ pub mod parler_tts;
pub mod persimmon;
pub mod phi;
pub mod phi3;
+pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
candle-transformers/src/models/pixtral/llava.rs (new file, 72 lines)
@@ -0,0 +1,72 @@
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};

use super::vision_model;
use crate::models::mistral;

#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    pub projector_hidden_act: candle_nn::Activation,
    pub text_config: mistral::Config,
    pub vision_config: vision_model::Config,
    pub image_token_index: usize,
    pub image_seq_length: usize,
}

#[derive(Debug, Clone)]
pub struct MultiModalProjector {
    linear_1: Linear,
    act: candle_nn::Activation,
    linear_2: Linear,
}

impl MultiModalProjector {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
        let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
        let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
        Ok(Self {
            linear_1,
            act: cfg.projector_hidden_act,
            linear_2,
        })
    }
}

impl Module for MultiModalProjector {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.linear_1)?
            .apply(&self.act)?
            .apply(&self.linear_2)
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    pub multi_modal_projector: MultiModalProjector,
    pub language_model: mistral::Model,
    pub vision_tower: vision_model::Model,
    pub patch_size: usize,
    pub dtype: candle::DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
        let vision_tower = vision_model::Model::new(
            &cfg.vision_config,
            vb.pp("vision_tower").to_dtype(candle::DType::F32),
        )?;
        let multi_modal_projector = MultiModalProjector::new(
            cfg,
            vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
        )?;
        Ok(Self {
            multi_modal_projector,
            language_model,
            vision_tower,
            patch_size: cfg.vision_config.patch_size,
            dtype: vb.dtype(),
        })
    }
}
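A sketch of how the wrapper's public pieces would be used for a single image: the vision tower and projector are deliberately loaded in f32 (see Model::new above), so the image goes in as f32 and the projected patch embeddings come back in the language model's dtype. The helper name is illustrative; the image is assumed to be (1, 3, H, W), already normalized, with H and W multiples of patch_size.

```rust
use candle::{DType, Module, Result, Tensor};
use candle_transformers::models::pixtral::Model;

// Illustrative helper: one preprocessed image -> language-model-ready embeddings.
fn image_to_embeds(model: &Model, image: &Tensor) -> Result<Tensor> {
    // The vision tower and projector run in f32 (see Model::new), so feed f32.
    let patches = model
        .vision_tower
        .forward(&image.to_dtype(DType::F32)?)?; // (1, n_patches, vision hidden_size)
    let projected = model.multi_modal_projector.forward(&patches)?; // (1, n_patches, text hidden_size)
    // Cast back to whatever dtype the language model was loaded with.
    projected.to_dtype(model.dtype)
}
```

The resulting (1, n_patches, hidden) embeddings would then stand in for the image placeholder tokens (image_token_index) in the prompt, e.g. via a splice like the forward_embeds sketch shown earlier.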
candle-transformers/src/models/pixtral/mod.rs (new file, 4 lines)
@@ -0,0 +1,4 @@
pub mod llava;
pub mod vision_model;

pub use llava::{Config, Model};
candle-transformers/src/models/pixtral/vision_model.rs (new file, 324 lines)
@@ -0,0 +1,324 @@
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};

fn default_act() -> candle_nn::Activation {
    candle_nn::Activation::Gelu
}

fn default_hidden_size() -> usize {
    1024
}

fn default_intermediate_size() -> usize {
    4096
}

fn default_num_channels() -> usize {
    3
}

fn default_num_hidden_layers() -> usize {
    24
}

fn default_num_attention_heads() -> usize {
    16
}

#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
    #[serde(default = "default_hidden_size")]
    pub hidden_size: usize,
    #[serde(default = "default_num_channels")]
    pub num_channels: usize,
    pub image_size: usize,
    pub patch_size: usize,
    pub rope_theta: f64,
    #[serde(default = "default_intermediate_size")]
    pub intermediate_size: usize,
    #[serde(default = "default_num_hidden_layers")]
    pub num_hidden_layers: usize,
    pub head_dim: Option<usize>,
    #[serde(default = "default_num_attention_heads")]
    pub num_attention_heads: usize,
    #[serde(default = "default_act")]
    pub hidden_act: candle_nn::Activation,
}

impl Config {
    pub fn pixtral_12b_2409() -> Self {
        Self {
            hidden_size: 1024,
            num_channels: 3,
            image_size: 1024,
            patch_size: 16,
            rope_theta: 10000.0,
            intermediate_size: 4096,
            num_hidden_layers: 24,
            num_attention_heads: 16,
            head_dim: None,
            // Default
            hidden_act: candle_nn::Activation::Gelu,
        }
    }

    fn head_dim(&self) -> usize {
        self.head_dim
            .unwrap_or(self.hidden_size / self.num_attention_heads)
    }
}
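From these defaults, a quick sanity check of the sizes the rest of the file relies on (a standalone sketch; head_dim() itself is private, so its fallback is recomputed here):

```rust
use candle_transformers::models::pixtral::vision_model::Config;

fn main() {
    let cfg = Config::pixtral_12b_2409();
    // A 1024x1024 image with 16x16 patches gives a 64x64 patch grid.
    let patches_per_side = cfg.image_size / cfg.patch_size; // 64
    // Same fallback as the private head_dim(): 1024 / 16 = 64.
    let head_dim = cfg
        .head_dim
        .unwrap_or(cfg.hidden_size / cfg.num_attention_heads);
    // The 2D RoPE below splits head_dim / 2 = 32 frequencies between rows and columns
    // (16 each) and precomputes a table with patches_per_side^2 = 4096 positions.
    println!("{patches_per_side} patches per side, head_dim {head_dim}");
}
```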
#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    scale: f64,
    num_heads: usize,
    head_dim: usize,
}

impl Attention {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let h = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let head_dim = cfg.head_dim();
        let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
        let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
        let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
        let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
        let scale = (head_dim as f64).powf(-0.5);
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            scale,
            num_heads,
            head_dim,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        emb: &RotaryEmbedding,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (b, patches, _) = xs.dims3()?;
        let query_states = xs.apply(&self.q_proj)?;
        let key_states = xs.apply(&self.k_proj)?;
        let value_states = xs.apply(&self.v_proj)?;

        let shape = (b, patches, self.num_heads, self.head_dim);
        let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
        let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
        let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;

        let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
        let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;

        let attn_weights = match attention_mask {
            None => attn_weights,
            Some(mask) => attn_weights.broadcast_add(mask)?,
        };

        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        attn_weights
            .matmul(&value_states)?
            .transpose(1, 2)?
            .reshape((b, patches, ()))?
            .apply(&self.o_proj)
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: candle_nn::Activation,
}

impl Mlp {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
        let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
        let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
        let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        (xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
            .apply(&self.down_proj)
    }
}

#[derive(Debug, Clone)]
struct AttentionLayer {
    attention_norm: RmsNorm,
    feed_forward: Mlp,
    attention: Attention,
    ffn_norm: RmsNorm,
}

impl AttentionLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
        let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
        let attention = Attention::new(cfg, vb.pp("attention"))?;
        let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
        Ok(Self {
            attention_norm,
            feed_forward,
            attention,
            ffn_norm,
        })
    }

    fn forward(
        &self,
        xs: &Tensor,
        emb: &RotaryEmbedding,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self
            .attention
            .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
        let xs = (residual + xs)?;
        let residual = &xs;
        let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
        xs + residual
    }
}

#[derive(Debug, Clone)]
struct Transformer {
    layers: Vec<AttentionLayer>,
}

impl Transformer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb = vb.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    fn forward(
        &self,
        xs: &Tensor,
        emb: &RotaryEmbedding,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, emb, attention_mask)?
        }
        Ok(xs)
    }
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
    cos: Tensor,
    sin: Tensor,
}

impl RotaryEmbedding {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dtype = vb.dtype();
        let dev = vb.device();
        let dim = cfg.head_dim();
        let rope_theta = cfg.rope_theta as f32;
        let max_patches_per_side = cfg.image_size / cfg.patch_size;
        let freqs: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
            .collect();
        let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
        let freqs_h = Tensor::new(freqs_h, dev)?;
        let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
        let freqs_w = Tensor::new(freqs_w, dev)?;
        let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
        let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
        let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
        let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
        let inv_freq = Tensor::cat(
            &[
                freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
                freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
            ],
            D::Minus1,
        )?
        .reshape(((), dim / 2))?;
        let cos = inv_freq.cos()?.to_dtype(dtype)?;
        let sin = inv_freq.sin()?.to_dtype(dtype)?;
        Ok(Self { cos, sin })
    }

    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
        let cos = &self.cos;
        let sin = &self.sin;
        let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
        let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
        Ok((q_embed, k_embed))
    }
}
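The cat/repeat/reshape above flattens the 2D patch grid into a (max_patches_per_side², head_dim/2) table. A plain-Rust sketch of the resulting layout, with hypothetical slice inputs standing in for the per-axis base frequencies:

```rust
// Sketch of the table layout: the patch at grid position (r, c) lands at flat index
// r * max_patches_per_side + c, and its head_dim/2 rotary angles are the concatenation
// of r * freqs_h (height frequencies first) and c * freqs_w.
fn rope_row(
    r: usize,
    c: usize,
    max_patches_per_side: usize,
    freqs_h: &[f32],
    freqs_w: &[f32],
) -> (usize, Vec<f32>) {
    let flat_index = r * max_patches_per_side + c;
    let angles = freqs_h
        .iter()
        .map(|f| r as f32 * f)
        .chain(freqs_w.iter().map(|f| c as f32 * f))
        .collect();
    (flat_index, angles)
}
```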
#[derive(Debug, Clone)]
pub struct Model {
    patch_conv: candle_nn::Conv2d,
    ln_pre: RmsNorm,
    transformer: Transformer,
    patch_positional_embedding: RotaryEmbedding,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let conv2d_cfg = candle_nn::Conv2dConfig {
            stride: cfg.patch_size,
            ..Default::default()
        };
        let patch_conv = candle_nn::conv2d_no_bias(
            cfg.num_channels,
            cfg.hidden_size,
            cfg.patch_size,
            conv2d_cfg,
            vb.pp("patch_conv"),
        )?;
        let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
        let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
        let patch_positional_embedding =
            RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
        Ok(Self {
            patch_conv,
            ln_pre,
            transformer,
            patch_positional_embedding,
        })
    }
}

impl Module for Model {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let patch_embeds = xs.apply(&self.patch_conv)?;
        let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
        self.transformer
            .forward(&patch_embeds, &self.patch_positional_embedding, None)
    }
}
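For orientation, a tiny sketch of the shapes flowing through Model::forward above, assuming the pixtral_12b_2409 config and a square 1024x1024 input (the numbers follow directly from the conv stride and hidden size):

```rust
// Shape walk-through for the vision Model::forward (sketch; plain arithmetic only).
fn main() {
    let (hidden, patch, image) = (1024usize, 16usize, 1024usize);
    let grid = image / patch; // 64: patch_conv with stride 16 yields a 64x64 patch grid
    println!("patch_conv:      (1, {hidden}, {grid}, {grid})");
    println!("flatten_from(2): (1, {hidden}, {})", grid * grid);
    println!(".t():            (1, {}, {hidden})", grid * grid); // one row per patch
    // ln_pre and the transformer keep this last shape.
}
```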