Add Pixtral. (#2521)

* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensor files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
This commit is contained in:
Laurent Mazare
2024-09-30 19:31:14 +02:00
committed by GitHub
parent 2f49e1b534
commit 683ab698de
9 changed files with 822 additions and 19 deletions

View File

@@ -0,0 +1,72 @@
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};
use super::vision_model;
use crate::models::mistral;
/// Top-level configuration for the multimodal model, deserializable via serde.
///
/// It bundles the configuration of the text model and of the vision model
/// together with the settings of the projector that connects them.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    /// Activation stored in the multimodal projector and applied between its
    /// two linear layers.
    pub projector_hidden_act: candle_nn::Activation,
    /// Configuration handed to the Mistral language model.
    pub text_config: mistral::Config,
    /// Configuration handed to the vision model.
    pub vision_config: vision_model::Config,
    /// NOTE(review): presumably the token id marking image positions in the
    /// prompt; it is not read by the code visible in this file - confirm
    /// against the callers.
    pub image_token_index: usize,
    /// NOTE(review): presumably the number of tokens an image occupies; it is
    /// not read by the code visible in this file - confirm against the callers.
    pub image_seq_length: usize,
}
/// Two linear layers with an activation in between, sized from the vision and
/// text hidden dimensions of the configuration.
#[derive(Clone, Debug)]
pub struct MultiModalProjector {
    /// First linear layer: vision hidden size -> text hidden size.
    linear_1: Linear,
    /// Activation applied to the output of `linear_1`.
    act: candle_nn::Activation,
    /// Second linear layer: text hidden size -> text hidden size.
    linear_2: Linear,
}
impl MultiModalProjector {
    /// Builds the projector from `cfg`, reading the weights from `vb`.
    ///
    /// The two layers are looked up under the `linear_1` and `linear_2`
    /// prefixes of `vb`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while creating either linear layer.
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vision_dim = cfg.vision_config.hidden_size;
        let text_dim = cfg.text_config.hidden_size;
        // Field initializers run in source order, so `linear_1` is still
        // created (and may fail) before `linear_2`.
        Ok(Self {
            linear_1: linear(vision_dim, text_dim, vb.pp("linear_1"))?,
            act: cfg.projector_hidden_act,
            linear_2: linear(text_dim, text_dim, vb.pp("linear_2"))?,
        })
    }
}
impl Module for MultiModalProjector {
    /// Runs `xs` through `linear_1`, the activation and `linear_2`, in that
    /// order, returning the first error encountered.
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let projected = self.linear_1.forward(xs)?;
        let activated = self.act.forward(&projected)?;
        self.linear_2.forward(&activated)
    }
}
/// The full multimodal model: a vision tower, a projector and a Mistral
/// language model, plus two values recorded at construction time.
#[derive(Clone, Debug)]
pub struct Model {
    /// Projector built from the `multi_modal_projector` weights.
    pub multi_modal_projector: MultiModalProjector,
    /// Mistral language model built from the `language_model` weights.
    pub language_model: mistral::Model,
    /// Vision model built from the `vision_tower` weights.
    pub vision_tower: vision_model::Model,
    /// Copy of `vision_config.patch_size` from the configuration.
    pub patch_size: usize,
    /// Dtype reported by the `VarBuilder` the model was created from.
    pub dtype: candle::DType,
}
impl Model {
    /// Builds the three sub-models from `cfg`, reading the weights from `vb`.
    ///
    /// The language model uses `vb` as given, whereas the builders for the
    /// vision tower and the projector are first switched to `F32` through
    /// `to_dtype`. The dtype originally reported by `vb` is stored in the
    /// `dtype` field.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned while building the language model,
    /// the vision tower or the projector (attempted in that order).
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let f32_dtype = candle::DType::F32;
        let dtype = vb.dtype();
        let patch_size = cfg.vision_config.patch_size;

        let text_vb = vb.pp("language_model");
        let language_model = mistral::Model::new(&cfg.text_config, text_vb)?;

        let vision_vb = vb.pp("vision_tower").to_dtype(f32_dtype);
        let vision_tower = vision_model::Model::new(&cfg.vision_config, vision_vb)?;

        let projector_vb = vb.pp("multi_modal_projector").to_dtype(f32_dtype);
        let multi_modal_projector = MultiModalProjector::new(cfg, projector_vb)?;

        Ok(Self {
            multi_modal_projector,
            language_model,
            vision_tower,
            patch_size,
            dtype,
        })
    }
}