Add Pixtral. (#2521)

* Add Pixtral.

* More pixtral vision encoder.

* Sketch a pixtral example.

* Sketch a pixtral example.

* Better image loading.

* Support loading images embedded in safetensors files.

* Clippy fixes.

* Add the llava multimodal adapter.

* Add more of the llava bits.

* Add the pixtral config.

* More pixtral inference.

* Add the text generation bits.

* Get the example to work.

* Bugfix.

* Run some bits of the model in f32.

* Blessed version :)

* Better rope frequency computations.

* README update.
Laurent Mazare
2024-09-30 19:31:14 +02:00
committed by GitHub
parent 2f49e1b534
commit 683ab698de
9 changed files with 822 additions and 19 deletions

View File

@@ -0,0 +1,28 @@
# pixtral
Pixtral-12B is a 12B-parameter text+vision model from Mistral AI.
[Blog Post](https://mistral.ai/news/pixtral-12b/) -
[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
```bash
cargo run --profile=release-with-debug --features cuda --example pixtral -- \
--image candle-examples/examples/flux/assets/flux-robot.jpg
```
```
Describe the image.
The image depicts a charming, rustic robot standing on a sandy beach at sunset.
The robot has a vintage, steampunk aesthetic with visible gears and mechanical
parts. It is holding a small lantern in one hand, which emits a warm glow, and
its other arm is extended forward as if reaching out or guiding the way. The
robot's body is adorned with the word "RUST" in bright orange letters, adding to
its rustic theme.
The background features a dramatic sky filled with clouds, illuminated by the
setting sun, casting a golden hue over the scene. Gentle waves lap against the
shore, creating a serene and picturesque atmosphere. The overall mood of the
image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
```

View File

@@ -0,0 +1,336 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::pixtral::{vision_model, Config, Model};
use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
image: Tensor,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
image: Tensor,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
image,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let get_token = |v| match self.tokenizer.get_token(v) {
Some(token) => Ok(token),
None => anyhow::bail!("cannot find the {v} token"),
};
let bos_token = get_token("<s>")?;
let eos_token = get_token("</s>")?;
let inst_token = get_token("[INST]")?;
let end_inst_token = get_token("[/INST]")?;
let img_break = get_token("[IMG_BREAK]")?;
let img_end = get_token("[IMG_END]")?;
let start_gen = std::time::Instant::now();
let mut pos = 0;
for index in 0..sample_len {
let logits = if index > 0 {
let context_size = 1; // the kv-cache already holds the context, only feed the newest token
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.language_model.forward(&input, pos)?;
pos += context_size;
logits
} else {
let (_b, _c, h, w) = self.image.dims4()?;
let h = h / self.model.patch_size;
let w = w / self.model.patch_size;
let image_embeds = self.model.vision_tower.forward(&self.image)?;
let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?;
println!("generated image embeddings {image_embeds:?}");
let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let break_embeds = {
let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let start_embeds = {
let mut in_tokens = vec![bos_token, inst_token];
in_tokens.extend_from_slice(tokens.as_slice());
let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
let end_embeds = {
let input =
Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?;
self.model.language_model.embed_tokens().forward(&input)?
};
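// Assemble the multimodal prompt: [BOS][INST] + the text tokens, then one
// row of image-patch embeddings per grid row with [IMG_BREAK] between rows,
// closed by [IMG_END][/INST].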
let mut input_embeds = vec![start_embeds];
for h_idx in 0..h {
if h_idx > 0 {
input_embeds.push(break_embeds.clone())
}
let row = image_embeds.narrow(1, h_idx * w, w)?;
input_embeds.push(row);
}
input_embeds.push(end_embeds);
let input_embeds = Tensor::cat(&input_embeds, 1)?;
let logits = self
.model
.language_model
.forward_embeds(&input_embeds, None, pos)?;
pos += input_embeds.dim(1)?;
logits
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long, default_value = "Describe the image.\n")]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
config_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long)]
image: String,
#[arg(long)]
vision_only: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "mistral-community/pixtral-12b".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let device = candle_examples::device(args.cpu)?;
let dtype = if device.supports_bf16() && !args.vision_only {
DType::BF16
} else {
DType::F32
};
let config: Config = match args.config_file {
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
None => {
let config_file = repo.get("config.json")?;
serde_json::from_slice(&std::fs::read(config_file)?)?
}
};
let image = if args.image.ends_with(".safetensors") {
match candle::safetensors::load(&args.image, &device)?.remove("img") {
None => anyhow::bail!("no img tensor in {}", args.image),
Some(v) => v,
}
} else {
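// CLIP-style normalization constants, matching the Pixtral image processor.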
candle_examples::imagenet::load_image_with_std_mean(
&args.image,
1024,
&[0.48145466, 0.4578275, 0.40821073],
&[0.26862954, 0.261_302_6, 0.275_777_1],
)?
};
let image = image.to_device(&device)?.unsqueeze(0)?;
println!("loaded image with shape {:?}", image);
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
if args.vision_only {
let start = std::time::Instant::now();
let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?;
println!("loaded the model in {:?}", start.elapsed());
let embs = model.forward(&image)?;
println!("EMBS\n{embs}");
} else {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
image,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
}
Ok(())
}

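The `--image` flag above also accepts a safetensors file, in which case the example looks up a tensor named `img` instead of decoding an image. A minimal, hypothetical sketch for producing such a file, with zeros standing in for a real pre-processed `(3, h, w)` image (the example adds the batch dimension itself after loading):

```rust
use std::collections::HashMap;

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-in for a normalized (3, h, w) image tensor.
    let img = Tensor::zeros((3, 1024, 1024), DType::F32, &dev)?;
    // The example calls `load(..)?.remove("img")`, so the name matters.
    let tensors = HashMap::from([("img".to_string(), img)]);
    candle::safetensors::save(&tensors, "img.safetensors")?;
    Ok(())
}
```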
View File

@@ -14,6 +14,7 @@ use std::sync::Arc;
pub struct VarBuilderArgs<'a, B: Backend> {
data: Arc<TensorData<B>>,
path: Vec<String>,
pub dtype: DType,
_phantom: std::marker::PhantomData<&'a B>,
}
@@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: self.path.clone(),
dtype: self.dtype,
_phantom: self._phantom,
}
}
@@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
struct TensorData<B: Backend> {
backend: B,
pub dtype: DType,
pub device: Device,
}
@@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
let data = TensorData {
backend,
dtype,
device: dev.clone(),
};
Self {
data: Arc::new(data),
path: vec![],
dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![],
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path: vec![prefix.to_string()],
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
Self {
data: self.data.clone(),
path,
dtype: self.dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
/// The dtype used by default.
pub fn dtype(&self) -> DType {
self.data.dtype
self.dtype
}
/// Clone the VarBuilder tweaking its dtype
pub fn to_dtype(&self, dtype: DType) -> Self {
Self {
data: self.data.clone(),
path: self.path.clone(),
dtype,
_phantom: std::marker::PhantomData,
}
}
fn path(&self, tensor_name: &str) -> String {
@@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
self.get_with_hints_dtype(s, name, hints, self.dtype)
}
/// Retrieve the tensor associated with the given name at the current path.
@@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> {
dtype: DType,
device: Device,
) -> Self {
let data = TensorData {
backend,
dtype,
device,
};
let data = TensorData { backend, device };
Self {
data: Arc::new(data),
path: vec![],
dtype,
_phantom: std::marker::PhantomData,
}
}
@@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> {
let path = self.path.clone();
let backend = Rename::new(self, renamer);
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
let data = TensorData {
backend,
dtype,
device,
};
let data = TensorData { backend, device };
Self {
data: Arc::new(data),
dtype,
path,
_phantom: std::marker::PhantomData,
}

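The new `to_dtype` lets a cloned `VarBuilder` pick up a different default dtype, which is what the pixtral wrapper below uses to keep the vision tower in f32 while the language model loads in bf16. A minimal sketch of the behaviour with a zeros-backed builder (the tensor names are placeholders):

```rust
use candle::{DType, Device, Result};
use candle_nn::{linear, VarBuilder};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Zeros-backed builder standing in for mmaped safetensors weights.
    let vb = VarBuilder::zeros(DType::BF16, &dev);
    // Loads at the builder's default dtype.
    let lm = linear(8, 8, vb.pp("language_model.layer"))?;
    // The cloned builder loads the same kind of weights as f32.
    let vision = linear(8, 8, vb.pp("vision_tower.layer").to_dtype(DType::F32))?;
    println!("lm: {:?}, vision: {:?}", lm.weight().dtype(), vision.weight().dtype());
    Ok(())
}
```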
View File

@@ -279,7 +279,7 @@ impl LLaVA {
(),
))?
} else {
todo!("not implemented in original python LLaVA yet")
bail!("not implemented in original python LLaVA yet")
};
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
let new_image_feature = new_image_feature

View File

@@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
fn default_num_attention_heads() -> usize {
32
}
fn default_use_flash_attn() -> bool {
false
}
fn default_hidden_act() -> candle_nn::Activation {
candle_nn::Activation::Silu
}
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
pub head_dim: Option<usize>,
pub num_key_value_heads: usize,
#[serde(default = "default_hidden_act")]
pub hidden_act: Activation,
pub max_position_embeddings: usize,
pub rms_norm_eps: f64,
@@ -107,14 +117,14 @@ impl RotaryEmbedding {
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.to_dtype(DType::F32)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
sin: freqs.sin()?.to_dtype(dtype)?,
cos: freqs.cos()?.to_dtype(dtype)?,
})
}
@@ -404,6 +414,10 @@ impl Model {
.to_dtype(self.dtype)
}
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
&self.embed_tokens
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
@@ -421,6 +435,22 @@
.apply(&self.lm_head)
}
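/// Like `forward` but starting from pre-computed embeddings, letting the
/// caller splice image-patch embeddings in between text-token embeddings.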
pub fn forward_embeds(
&mut self,
xs: &Tensor,
attn_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (_b_size, seq_len, _) = xs.dims3()?;
let mut xs = xs.clone();
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()

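The `RotaryEmbedding` change above (the "better rope frequency computations" commit) builds the sin/cos tables in f32 and only casts to the runtime dtype at the end. A small sketch with hypothetical numbers showing why: bf16 keeps only ~8 mantissa bits, so computing `position * inv_freq` directly in bf16 drifts for large positions:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::full(10_000f32, (1,), &dev)?; // a large position
    let inv_freq = Tensor::full(0.12345f32, (1,), &dev)?; // some inverse frequency
    // f32 all the way, cast at the end: what the code now does.
    let exact = (&t * &inv_freq)?;
    // The same product computed in bf16: both operands get rounded first.
    let rounded = (t.to_dtype(DType::BF16)? * inv_freq.to_dtype(DType::BF16)?)?
        .to_dtype(DType::F32)?;
    println!("f32: {exact}, via bf16: {rounded}");
    Ok(())
}
```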
View File

@@ -51,6 +51,7 @@ pub mod parler_tts;
pub mod persimmon;
pub mod phi;
pub mod phi3;
pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;

View File

@@ -0,0 +1,72 @@
use candle::{Module, Result, Tensor};
use candle_nn::{linear, Linear, VarBuilder};
use super::vision_model;
use crate::models::mistral;
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
pub projector_hidden_act: candle_nn::Activation,
pub text_config: mistral::Config,
pub vision_config: vision_model::Config,
pub image_token_index: usize,
pub image_seq_length: usize,
}
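/// Two-layer MLP projecting vision-tower embeddings into the language
/// model's embedding space.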
#[derive(Debug, Clone)]
pub struct MultiModalProjector {
linear_1: Linear,
act: candle_nn::Activation,
linear_2: Linear,
}
impl MultiModalProjector {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
Ok(Self {
linear_1,
act: cfg.projector_hidden_act,
linear_2,
})
}
}
impl Module for MultiModalProjector {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.linear_1)?
.apply(&self.act)?
.apply(&self.linear_2)
}
}
#[derive(Debug, Clone)]
pub struct Model {
pub multi_modal_projector: MultiModalProjector,
pub language_model: mistral::Model,
pub vision_tower: vision_model::Model,
pub patch_size: usize,
pub dtype: candle::DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
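// The vision tower and projector run in f32 (the "run some bits of the
// model in f32" commit) while the language model keeps the builder's
// default dtype.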
let vision_tower = vision_model::Model::new(
&cfg.vision_config,
vb.pp("vision_tower").to_dtype(candle::DType::F32),
)?;
let multi_modal_projector = MultiModalProjector::new(
cfg,
vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
)?;
Ok(Self {
multi_modal_projector,
language_model,
vision_tower,
patch_size: cfg.vision_config.patch_size,
dtype: vb.dtype(),
})
}
}

View File

@@ -0,0 +1,4 @@
pub mod llava;
pub mod vision_model;
pub use llava::{Config, Model};

View File

@@ -0,0 +1,324 @@
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
fn default_act() -> candle_nn::Activation {
candle_nn::Activation::Gelu
}
fn default_hidden_size() -> usize {
1024
}
fn default_intermediate_size() -> usize {
4096
}
fn default_num_channels() -> usize {
3
}
fn default_num_hidden_layers() -> usize {
24
}
fn default_num_attention_heads() -> usize {
16
}
#[derive(serde::Deserialize, Debug, Clone)]
pub struct Config {
#[serde(default = "default_hidden_size")]
pub hidden_size: usize,
#[serde(default = "default_num_channels")]
pub num_channels: usize,
pub image_size: usize,
pub patch_size: usize,
pub rope_theta: f64,
#[serde(default = "default_intermediate_size")]
pub intermediate_size: usize,
#[serde(default = "default_num_hidden_layers")]
pub num_hidden_layers: usize,
pub head_dim: Option<usize>,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
#[serde(default = "default_act")]
pub hidden_act: candle_nn::Activation,
}
impl Config {
pub fn pixtral_12b_2409() -> Self {
Self {
hidden_size: 1024,
num_channels: 3,
image_size: 1024,
patch_size: 16,
rope_theta: 10000.0,
intermediate_size: 4096,
num_hidden_layers: 24,
num_attention_heads: 16,
head_dim: None,
// Default hidden activation.
hidden_act: candle_nn::Activation::Gelu,
}
}
fn head_dim(&self) -> usize {
self.head_dim
.unwrap_or(self.hidden_size / self.num_attention_heads)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
scale: f64,
num_heads: usize,
head_dim: usize,
}
impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let h = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let head_dim = cfg.head_dim();
let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
let scale = (head_dim as f64).powf(-0.5);
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
scale,
num_heads,
head_dim,
})
}
fn forward(
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b, patches, _) = xs.dims3()?;
let query_states = xs.apply(&self.q_proj)?;
let key_states = xs.apply(&self.k_proj)?;
let value_states = xs.apply(&self.v_proj)?;
let shape = (b, patches, self.num_heads, self.head_dim);
let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights
.matmul(&value_states)?
.transpose(1, 2)?
.reshape((b, patches, ()))?
.apply(&self.o_proj)
}
}
#[derive(Debug, Clone)]
struct Mlp {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: candle_nn::Activation,
}
impl Mlp {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
(xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct AttentionLayer {
attention_norm: RmsNorm,
feed_forward: Mlp,
attention: Attention,
ffn_norm: RmsNorm,
}
impl AttentionLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
let attention = Attention::new(cfg, vb.pp("attention"))?;
let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
Ok(Self {
attention_norm,
feed_forward,
attention,
ffn_norm,
})
}
fn forward(
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let residual = xs;
let xs = self
.attention
.forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
let xs = (residual + xs)?;
let residual = &xs;
let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
xs + residual
}
}
#[derive(Debug, Clone)]
struct Transformer {
layers: Vec<AttentionLayer>,
}
impl Transformer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb = vb.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
layers.push(layer)
}
Ok(Self { layers })
}
fn forward(
&self,
xs: &Tensor,
emb: &RotaryEmbedding,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, emb, attention_mask)?
}
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
cos: Tensor,
sin: Tensor,
}
impl RotaryEmbedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dtype = vb.dtype();
let dev = vb.device();
let dim = cfg.head_dim();
let rope_theta = cfg.rope_theta as f32;
let max_patches_per_side = cfg.image_size / cfg.patch_size;
let freqs: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
.collect();
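// 2D RoPE over the patch grid: even-indexed frequencies encode the row
// (height) position, odd-indexed ones the column (width) position.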
let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
let freqs_h = Tensor::new(freqs_h, dev)?;
let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
let freqs_w = Tensor::new(freqs_w, dev)?;
let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
let inv_freq = Tensor::cat(
&[
freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
],
D::Minus1,
)?
.reshape(((), dim / 2))?;
let cos = inv_freq.cos()?.to_dtype(dtype)?;
let sin = inv_freq.sin()?.to_dtype(dtype)?;
Ok(Self { cos, sin })
}
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
let cos = &self.cos;
let sin = &self.sin;
let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
pub struct Model {
patch_conv: candle_nn::Conv2d,
ln_pre: RmsNorm,
transformer: Transformer,
patch_positional_embedding: RotaryEmbedding,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let conv2d_cfg = candle_nn::Conv2dConfig {
stride: cfg.patch_size,
..Default::default()
};
let patch_conv = candle_nn::conv2d_no_bias(
cfg.num_channels,
cfg.hidden_size,
cfg.patch_size,
conv2d_cfg,
vb.pp("patch_conv"),
)?;
let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
let patch_positional_embedding =
RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
Ok(Self {
patch_conv,
ln_pre,
transformer,
patch_positional_embedding,
})
}
}
impl Module for Model {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
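// Patchify: the strided conv maps (b, c, h, w) to (b, hidden, h/p, w/p);
// flattening and transposing yields (b, num_patches, hidden).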
let patch_embeds = xs.apply(&self.patch_conv)?;
let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
self.transformer
.forward(&patch_embeds, &self.patch_positional_embedding, None)
}
}