mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add Pixtral. (#2521)
* Add Pixtral. * More pixtral vision encoder. * Sketch a pixtral example. * Sketch a pixtral example. * Better image loading. * Support loading images embedded in safetensor files. * Clippy fixes. * Add the llava multimodal adapter. * Add more of the llava bits. * Add the pixtral config. * More pixtral inference. * Add the text generation bits. * Get the example to work. * Bugfix. * Run some bits of the model in f32. * Blessed version :) * Better rope frequency computations. * README update.
This commit is contained in:
28
candle-examples/examples/pixtral/README.md
Normal file
28
candle-examples/examples/pixtral/README.md
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# pixtral
|
||||||
|
|
||||||
|
Pixtral-12B is a 12B text+vision model.
|
||||||
|
|
||||||
|
[Blog Post](https://mistral.ai/news/pixtral-12b/) -
|
||||||
|
[HF Model Card](https://huggingface.co/mistralai/Pixtral-12B-2409) -
|
||||||
|
[HF Community Model Card](https://huggingface.co/mistral-community/pixtral-12b).
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --profile=release-with-debug --features cuda --example pixtral -- \
|
||||||
|
--image candle-examples/examples/flux/assets/flux-robot.jpg
|
||||||
|
```
|
||||||
|
|
||||||
|
```
|
||||||
|
Describe the image.
|
||||||
|
|
||||||
|
The image depicts a charming, rustic robot standing on a sandy beach at sunset.
|
||||||
|
The robot has a vintage, steampunk aesthetic with visible gears and mechanical
|
||||||
|
parts. It is holding a small lantern in one hand, which emits a warm glow, and
|
||||||
|
its other arm is extended forward as if reaching out or guiding the way. The
|
||||||
|
robot's body is adorned with the word "RUST" in bright orange letters, adding to
|
||||||
|
its rustic theme.
|
||||||
|
|
||||||
|
The background features a dramatic sky filled with clouds, illuminated by the
|
||||||
|
setting sun, casting a golden hue over the scene. Gentle waves lap against the
|
||||||
|
shore, creating a serene and picturesque atmosphere. The overall mood of the
|
||||||
|
image is whimsical and nostalgic, evoking a sense of adventure and tranquility.
|
||||||
|
```
|
336
candle-examples/examples/pixtral/main.rs
Normal file
336
candle-examples/examples/pixtral/main.rs
Normal file
@ -0,0 +1,336 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::pixtral::{vision_model, Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Module, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
image: Tensor,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: TokenOutputStream,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
image: Tensor,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
image,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let get_token = |v| match self.tokenizer.get_token(v) {
|
||||||
|
Some(token) => Ok(token),
|
||||||
|
None => anyhow::bail!("cannot find the {v} token"),
|
||||||
|
};
|
||||||
|
let bos_token = get_token("<s>")?;
|
||||||
|
let eos_token = get_token("</s>")?;
|
||||||
|
let inst_token = get_token("[INST]")?;
|
||||||
|
let end_inst_token = get_token("[/INST]")?;
|
||||||
|
let img_break = get_token("[IMG_BREAK]")?;
|
||||||
|
let img_end = get_token("[IMG_END]")?;
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
let mut pos = 0;
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let logits = if index > 0 {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.language_model.forward(&input, pos)?;
|
||||||
|
pos += context_size;
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let (_b, _c, h, w) = self.image.dims4()?;
|
||||||
|
let h = h / self.model.patch_size;
|
||||||
|
let w = w / self.model.patch_size;
|
||||||
|
let image_embeds = self.model.vision_tower.forward(&self.image)?;
|
||||||
|
let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?;
|
||||||
|
println!("generated image embeddings {image_embeds:?}");
|
||||||
|
let image_embeds = image_embeds.to_dtype(self.model.dtype)?;
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let break_embeds = {
|
||||||
|
let input = Tensor::new(&[img_break], &self.device)?.unsqueeze(0)?;
|
||||||
|
self.model.language_model.embed_tokens().forward(&input)?
|
||||||
|
};
|
||||||
|
let start_embeds = {
|
||||||
|
let mut in_tokens = vec![bos_token, inst_token];
|
||||||
|
in_tokens.extend_from_slice(tokens.as_slice());
|
||||||
|
let input = Tensor::new(in_tokens.as_slice(), &self.device)?.unsqueeze(0)?;
|
||||||
|
self.model.language_model.embed_tokens().forward(&input)?
|
||||||
|
};
|
||||||
|
let end_embeds = {
|
||||||
|
let input =
|
||||||
|
Tensor::new(&[img_end, end_inst_token], &self.device)?.unsqueeze(0)?;
|
||||||
|
self.model.language_model.embed_tokens().forward(&input)?
|
||||||
|
};
|
||||||
|
let mut input_embeds = vec![start_embeds];
|
||||||
|
for h_idx in 0..h {
|
||||||
|
if h_idx > 0 {
|
||||||
|
input_embeds.push(break_embeds.clone())
|
||||||
|
}
|
||||||
|
let row = image_embeds.narrow(1, h_idx * w, w)?;
|
||||||
|
input_embeds.push(row);
|
||||||
|
}
|
||||||
|
input_embeds.push(end_embeds);
|
||||||
|
|
||||||
|
let input_embeds = Tensor::cat(&input_embeds, 1)?;
|
||||||
|
let logits = self
|
||||||
|
.model
|
||||||
|
.language_model
|
||||||
|
.forward_embeds(&input_embeds, None, pos)?;
|
||||||
|
pos += input_embeds.dim(1)?;
|
||||||
|
logits
|
||||||
|
};
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "Describe the image.\n")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long)]
|
||||||
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config_file: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weight_files: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
image: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
vision_only: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature.unwrap_or(0.),
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match &args.model_id {
|
||||||
|
Some(model_id) => model_id.to_string(),
|
||||||
|
None => "mistral-community/pixtral-12b".to_string(),
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weight_files {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let dtype = if device.supports_bf16() && !args.vision_only {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let config: Config = match args.config_file {
|
||||||
|
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||||
|
None => {
|
||||||
|
let config_file = repo.get("config.json")?;
|
||||||
|
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let image = if args.image.ends_with(".safetensors") {
|
||||||
|
match candle::safetensors::load(&args.image, &device)?.remove("img") {
|
||||||
|
None => anyhow::bail!("no img tensor in {}", args.image),
|
||||||
|
Some(v) => v,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
candle_examples::imagenet::load_image_with_std_mean(
|
||||||
|
&args.image,
|
||||||
|
1024,
|
||||||
|
&[0.48145466, 0.4578275, 0.40821073],
|
||||||
|
&[0.26862954, 0.261_302_6, 0.275_777_1],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
let image = image.to_device(&device)?.unsqueeze(0)?;
|
||||||
|
println!("loaded image with shape {:?}", image);
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
|
||||||
|
if args.vision_only {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?;
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
let embs = model.forward(&image)?;
|
||||||
|
println!("EMBS\n{embs}");
|
||||||
|
} else {
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
image,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
args.temperature,
|
||||||
|
args.top_p,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -14,6 +14,7 @@ use std::sync::Arc;
|
|||||||
pub struct VarBuilderArgs<'a, B: Backend> {
|
pub struct VarBuilderArgs<'a, B: Backend> {
|
||||||
data: Arc<TensorData<B>>,
|
data: Arc<TensorData<B>>,
|
||||||
path: Vec<String>,
|
path: Vec<String>,
|
||||||
|
pub dtype: DType,
|
||||||
_phantom: std::marker::PhantomData<&'a B>,
|
_phantom: std::marker::PhantomData<&'a B>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ impl<'a, B: Backend> Clone for VarBuilderArgs<'a, B> {
|
|||||||
Self {
|
Self {
|
||||||
data: self.data.clone(),
|
data: self.data.clone(),
|
||||||
path: self.path.clone(),
|
path: self.path.clone(),
|
||||||
|
dtype: self.dtype,
|
||||||
_phantom: self._phantom,
|
_phantom: self._phantom,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -33,7 +35,6 @@ pub type VarBuilder<'a> = VarBuilderArgs<'a, Box<dyn SimpleBackend + 'a>>;
|
|||||||
|
|
||||||
struct TensorData<B: Backend> {
|
struct TensorData<B: Backend> {
|
||||||
backend: B,
|
backend: B,
|
||||||
pub dtype: DType,
|
|
||||||
pub device: Device,
|
pub device: Device,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,12 +96,12 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
|
pub fn new_with_args(backend: B, dtype: DType, dev: &Device) -> Self {
|
||||||
let data = TensorData {
|
let data = TensorData {
|
||||||
backend,
|
backend,
|
||||||
dtype,
|
|
||||||
device: dev.clone(),
|
device: dev.clone(),
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
data: Arc::new(data),
|
data: Arc::new(data),
|
||||||
path: vec![],
|
path: vec![],
|
||||||
|
dtype,
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -115,6 +116,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
Self {
|
Self {
|
||||||
data: self.data.clone(),
|
data: self.data.clone(),
|
||||||
path: vec![],
|
path: vec![],
|
||||||
|
dtype: self.dtype,
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -124,6 +126,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
Self {
|
Self {
|
||||||
data: self.data.clone(),
|
data: self.data.clone(),
|
||||||
path: vec![prefix.to_string()],
|
path: vec![prefix.to_string()],
|
||||||
|
dtype: self.dtype,
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -136,6 +139,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
Self {
|
Self {
|
||||||
data: self.data.clone(),
|
data: self.data.clone(),
|
||||||
path,
|
path,
|
||||||
|
dtype: self.dtype,
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -152,7 +156,17 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
|
|
||||||
/// The dtype used by default.
|
/// The dtype used by default.
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
self.data.dtype
|
self.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clone the VarBuilder tweaking its dtype
|
||||||
|
pub fn to_dtype(&self, dtype: DType) -> Self {
|
||||||
|
Self {
|
||||||
|
data: self.data.clone(),
|
||||||
|
path: self.path.clone(),
|
||||||
|
dtype,
|
||||||
|
_phantom: std::marker::PhantomData,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn path(&self, tensor_name: &str) -> String {
|
fn path(&self, tensor_name: &str) -> String {
|
||||||
@ -178,7 +192,7 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
|
|||||||
name: &str,
|
name: &str,
|
||||||
hints: B::Hints,
|
hints: B::Hints,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
|
self.get_with_hints_dtype(s, name, hints, self.dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Retrieve the tensor associated with the given name at the current path.
|
/// Retrieve the tensor associated with the given name at the current path.
|
||||||
@ -460,14 +474,11 @@ impl<'a> VarBuilder<'a> {
|
|||||||
dtype: DType,
|
dtype: DType,
|
||||||
device: Device,
|
device: Device,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let data = TensorData {
|
let data = TensorData { backend, device };
|
||||||
backend,
|
|
||||||
dtype,
|
|
||||||
device,
|
|
||||||
};
|
|
||||||
Self {
|
Self {
|
||||||
data: Arc::new(data),
|
data: Arc::new(data),
|
||||||
path: vec![],
|
path: vec![],
|
||||||
|
dtype,
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -567,13 +578,10 @@ impl<'a> VarBuilder<'a> {
|
|||||||
let path = self.path.clone();
|
let path = self.path.clone();
|
||||||
let backend = Rename::new(self, renamer);
|
let backend = Rename::new(self, renamer);
|
||||||
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
|
let backend: Box<dyn SimpleBackend + 'a> = Box::new(backend);
|
||||||
let data = TensorData {
|
let data = TensorData { backend, device };
|
||||||
backend,
|
|
||||||
dtype,
|
|
||||||
device,
|
|
||||||
};
|
|
||||||
Self {
|
Self {
|
||||||
data: Arc::new(data),
|
data: Arc::new(data),
|
||||||
|
dtype,
|
||||||
path,
|
path,
|
||||||
_phantom: std::marker::PhantomData,
|
_phantom: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
|
@ -279,7 +279,7 @@ impl LLaVA {
|
|||||||
(),
|
(),
|
||||||
))?
|
))?
|
||||||
} else {
|
} else {
|
||||||
todo!("not implemented in original python LLaVA yet")
|
bail!("not implemented in original python LLaVA yet")
|
||||||
};
|
};
|
||||||
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
|
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
|
||||||
let new_image_feature = new_image_feature
|
let new_image_feature = new_image_feature
|
||||||
|
@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
|
|||||||
use candle_nn::{Activation, VarBuilder};
|
use candle_nn::{Activation, VarBuilder};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn default_num_attention_heads() -> usize {
|
||||||
|
32
|
||||||
|
}
|
||||||
|
|
||||||
fn default_use_flash_attn() -> bool {
|
fn default_use_flash_attn() -> bool {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_hidden_act() -> candle_nn::Activation {
|
||||||
|
candle_nn::Activation::Silu
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
pub num_hidden_layers: usize,
|
pub num_hidden_layers: usize,
|
||||||
|
#[serde(default = "default_num_attention_heads")]
|
||||||
pub num_attention_heads: usize,
|
pub num_attention_heads: usize,
|
||||||
pub head_dim: Option<usize>,
|
pub head_dim: Option<usize>,
|
||||||
pub num_key_value_heads: usize,
|
pub num_key_value_heads: usize,
|
||||||
|
#[serde(default = "default_hidden_act")]
|
||||||
pub hidden_act: Activation,
|
pub hidden_act: Activation,
|
||||||
pub max_position_embeddings: usize,
|
pub max_position_embeddings: usize,
|
||||||
pub rms_norm_eps: f64,
|
pub rms_norm_eps: f64,
|
||||||
@ -107,14 +117,14 @@ impl RotaryEmbedding {
|
|||||||
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let inv_freq_len = inv_freq.len();
|
let inv_freq_len = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
|
||||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
.to_dtype(dtype)?
|
.to_dtype(DType::F32)?
|
||||||
.reshape((max_seq_len, 1))?;
|
.reshape((max_seq_len, 1))?;
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
sin: freqs.sin()?,
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
cos: freqs.cos()?,
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -404,6 +414,10 @@ impl Model {
|
|||||||
.to_dtype(self.dtype)
|
.to_dtype(self.dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
|
||||||
|
&self.embed_tokens
|
||||||
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
let (_b_size, seq_len) = input_ids.dims2()?;
|
let (_b_size, seq_len) = input_ids.dims2()?;
|
||||||
let attention_mask = if seq_len <= 1 {
|
let attention_mask = if seq_len <= 1 {
|
||||||
@ -421,6 +435,22 @@ impl Model {
|
|||||||
.apply(&self.lm_head)
|
.apply(&self.lm_head)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn forward_embeds(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (_b_size, seq_len, _) = xs.dims3()?;
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
|
||||||
|
}
|
||||||
|
xs.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.norm)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn clear_kv_cache(&mut self) {
|
pub fn clear_kv_cache(&mut self) {
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
layer.clear_kv_cache()
|
layer.clear_kv_cache()
|
||||||
|
@ -51,6 +51,7 @@ pub mod parler_tts;
|
|||||||
pub mod persimmon;
|
pub mod persimmon;
|
||||||
pub mod phi;
|
pub mod phi;
|
||||||
pub mod phi3;
|
pub mod phi3;
|
||||||
|
pub mod pixtral;
|
||||||
pub mod quantized_blip;
|
pub mod quantized_blip;
|
||||||
pub mod quantized_blip_text;
|
pub mod quantized_blip_text;
|
||||||
pub mod quantized_llama;
|
pub mod quantized_llama;
|
||||||
|
72
candle-transformers/src/models/pixtral/llava.rs
Normal file
72
candle-transformers/src/models/pixtral/llava.rs
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
use candle::{Module, Result, Tensor};
|
||||||
|
use candle_nn::{linear, Linear, VarBuilder};
|
||||||
|
|
||||||
|
use super::vision_model;
|
||||||
|
use crate::models::mistral;
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub projector_hidden_act: candle_nn::Activation,
|
||||||
|
pub text_config: mistral::Config,
|
||||||
|
pub vision_config: vision_model::Config,
|
||||||
|
pub image_token_index: usize,
|
||||||
|
pub image_seq_length: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MultiModalProjector {
|
||||||
|
linear_1: Linear,
|
||||||
|
act: candle_nn::Activation,
|
||||||
|
linear_2: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MultiModalProjector {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let (hidden_v, hidden_t) = (cfg.vision_config.hidden_size, cfg.text_config.hidden_size);
|
||||||
|
let linear_1 = linear(hidden_v, hidden_t, vb.pp("linear_1"))?;
|
||||||
|
let linear_2 = linear(hidden_t, hidden_t, vb.pp("linear_2"))?;
|
||||||
|
Ok(Self {
|
||||||
|
linear_1,
|
||||||
|
act: cfg.projector_hidden_act,
|
||||||
|
linear_2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MultiModalProjector {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.linear_1)?
|
||||||
|
.apply(&self.act)?
|
||||||
|
.apply(&self.linear_2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
pub multi_modal_projector: MultiModalProjector,
|
||||||
|
pub language_model: mistral::Model,
|
||||||
|
pub vision_tower: vision_model::Model,
|
||||||
|
pub patch_size: usize,
|
||||||
|
pub dtype: candle::DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let language_model = mistral::Model::new(&cfg.text_config, vb.pp("language_model"))?;
|
||||||
|
let vision_tower = vision_model::Model::new(
|
||||||
|
&cfg.vision_config,
|
||||||
|
vb.pp("vision_tower").to_dtype(candle::DType::F32),
|
||||||
|
)?;
|
||||||
|
let multi_modal_projector = MultiModalProjector::new(
|
||||||
|
cfg,
|
||||||
|
vb.pp("multi_modal_projector").to_dtype(candle::DType::F32),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
multi_modal_projector,
|
||||||
|
language_model,
|
||||||
|
vision_tower,
|
||||||
|
patch_size: cfg.vision_config.patch_size,
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
4
candle-transformers/src/models/pixtral/mod.rs
Normal file
4
candle-transformers/src/models/pixtral/mod.rs
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
pub mod llava;
|
||||||
|
pub mod vision_model;
|
||||||
|
|
||||||
|
pub use llava::{Config, Model};
|
324
candle-transformers/src/models/pixtral/vision_model.rs
Normal file
324
candle-transformers/src/models/pixtral/vision_model.rs
Normal file
@ -0,0 +1,324 @@
|
|||||||
|
use candle::{DType, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
|
||||||
|
|
||||||
|
fn default_act() -> candle_nn::Activation {
|
||||||
|
candle_nn::Activation::Gelu
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_hidden_size() -> usize {
|
||||||
|
1024
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_intermediate_size() -> usize {
|
||||||
|
4096
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_num_channels() -> usize {
|
||||||
|
3
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_num_hidden_layers() -> usize {
|
||||||
|
24
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_num_attention_heads() -> usize {
|
||||||
|
16
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
#[serde(default = "default_hidden_size")]
|
||||||
|
pub hidden_size: usize,
|
||||||
|
#[serde(default = "default_num_channels")]
|
||||||
|
pub num_channels: usize,
|
||||||
|
pub image_size: usize,
|
||||||
|
pub patch_size: usize,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
#[serde(default = "default_intermediate_size")]
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
#[serde(default = "default_num_hidden_layers")]
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub head_dim: Option<usize>,
|
||||||
|
#[serde(default = "default_num_attention_heads")]
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
#[serde(default = "default_act")]
|
||||||
|
pub hidden_act: candle_nn::Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn pixtral_12b_2409() -> Self {
|
||||||
|
Self {
|
||||||
|
hidden_size: 1024,
|
||||||
|
num_channels: 3,
|
||||||
|
image_size: 1024,
|
||||||
|
patch_size: 16,
|
||||||
|
rope_theta: 10000.0,
|
||||||
|
intermediate_size: 4096,
|
||||||
|
num_hidden_layers: 24,
|
||||||
|
num_attention_heads: 16,
|
||||||
|
head_dim: None,
|
||||||
|
// Default
|
||||||
|
hidden_act: candle_nn::Activation::Gelu,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn head_dim(&self) -> usize {
|
||||||
|
self.head_dim
|
||||||
|
.unwrap_or(self.hidden_size / self.num_attention_heads)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
scale: f64,
|
||||||
|
num_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let h = cfg.hidden_size;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let head_dim = cfg.head_dim();
|
||||||
|
let q_proj = linear_b(h, h, false, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear_b(h, h, false, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_b(h, h, false, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear_b(h, h, false, vb.pp("o_proj"))?;
|
||||||
|
let scale = (head_dim as f64).powf(-0.5);
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
scale,
|
||||||
|
num_heads,
|
||||||
|
head_dim,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
emb: &RotaryEmbedding,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b, patches, _) = xs.dims3()?;
|
||||||
|
let query_states = xs.apply(&self.q_proj)?;
|
||||||
|
let key_states = xs.apply(&self.k_proj)?;
|
||||||
|
let value_states = xs.apply(&self.v_proj)?;
|
||||||
|
|
||||||
|
let shape = (b, patches, self.num_heads, self.head_dim);
|
||||||
|
let query_states = query_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
|
||||||
|
let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
|
||||||
|
let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
|
||||||
|
|
||||||
|
let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights
|
||||||
|
.matmul(&value_states)?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b, patches, ()))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Mlp {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: candle_nn::Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let (h, i) = (cfg.hidden_size, cfg.intermediate_size);
|
||||||
|
let gate_proj = linear_b(h, i, false, vb.pp("gate_proj"))?;
|
||||||
|
let up_proj = linear_b(h, i, false, vb.pp("up_proj"))?;
|
||||||
|
let down_proj = linear_b(i, h, false, vb.pp("down_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Mlp {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
(xs.apply(&self.gate_proj)?.apply(&self.act_fn)? * xs.apply(&self.up_proj))?
|
||||||
|
.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct AttentionLayer {
|
||||||
|
attention_norm: RmsNorm,
|
||||||
|
feed_forward: Mlp,
|
||||||
|
attention: Attention,
|
||||||
|
ffn_norm: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AttentionLayer {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let attention_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("attention_norm"))?;
|
||||||
|
let feed_forward = Mlp::new(cfg, vb.pp("feed_forward"))?;
|
||||||
|
let attention = Attention::new(cfg, vb.pp("attention"))?;
|
||||||
|
let ffn_norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("ffn_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
attention_norm,
|
||||||
|
feed_forward,
|
||||||
|
attention,
|
||||||
|
ffn_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
emb: &RotaryEmbedding,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self
|
||||||
|
.attention
|
||||||
|
.forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
|
||||||
|
let xs = (residual + xs)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
|
||||||
|
xs + residual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Transformer {
|
||||||
|
layers: Vec<AttentionLayer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transformer {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb = vb.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
Ok(Self { layers })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&self,
|
||||||
|
xs: &Tensor,
|
||||||
|
emb: &RotaryEmbedding,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
xs = layer.forward(&xs, emb, attention_mask)?
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
cos: Tensor,
|
||||||
|
sin: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let dtype = vb.dtype();
|
||||||
|
let dev = vb.device();
|
||||||
|
let dim = cfg.head_dim();
|
||||||
|
let rope_theta = cfg.rope_theta as f32;
|
||||||
|
let max_patches_per_side = cfg.image_size / cfg.patch_size;
|
||||||
|
let freqs: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
let freqs_h = freqs.iter().step_by(2).copied().collect::<Vec<_>>();
|
||||||
|
let freqs_h = Tensor::new(freqs_h, dev)?;
|
||||||
|
let freqs_w = freqs.iter().skip(1).step_by(2).copied().collect::<Vec<_>>();
|
||||||
|
let freqs_w = Tensor::new(freqs_w, dev)?;
|
||||||
|
let h = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
|
||||||
|
let w = Tensor::arange(0u32, max_patches_per_side as u32, dev)?.to_dtype(DType::F32)?;
|
||||||
|
let freqs_h = h.unsqueeze(1)?.matmul(&freqs_h.unsqueeze(0)?)?;
|
||||||
|
let freqs_w = w.unsqueeze(1)?.matmul(&freqs_w.unsqueeze(0)?)?;
|
||||||
|
let inv_freq = Tensor::cat(
|
||||||
|
&[
|
||||||
|
freqs_h.unsqueeze(1)?.repeat((1, max_patches_per_side, 1))?,
|
||||||
|
freqs_w.unsqueeze(0)?.repeat((max_patches_per_side, 1, 1))?,
|
||||||
|
],
|
||||||
|
D::Minus1,
|
||||||
|
)?
|
||||||
|
.reshape(((), dim / 2))?;
|
||||||
|
let cos = inv_freq.cos()?.to_dtype(dtype)?;
|
||||||
|
let sin = inv_freq.sin()?.to_dtype(dtype)?;
|
||||||
|
Ok(Self { cos, sin })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = &self.cos;
|
||||||
|
let sin = &self.sin;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
patch_conv: candle_nn::Conv2d,
|
||||||
|
ln_pre: RmsNorm,
|
||||||
|
transformer: Transformer,
|
||||||
|
patch_positional_embedding: RotaryEmbedding,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let conv2d_cfg = candle_nn::Conv2dConfig {
|
||||||
|
stride: cfg.patch_size,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let patch_conv = candle_nn::conv2d_no_bias(
|
||||||
|
cfg.num_channels,
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.patch_size,
|
||||||
|
conv2d_cfg,
|
||||||
|
vb.pp("patch_conv"),
|
||||||
|
)?;
|
||||||
|
let ln_pre = candle_nn::rms_norm(cfg.hidden_size, 1e-5, vb.pp("ln_pre"))?;
|
||||||
|
let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
|
||||||
|
let patch_positional_embedding =
|
||||||
|
RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
|
||||||
|
Ok(Self {
|
||||||
|
patch_conv,
|
||||||
|
ln_pre,
|
||||||
|
transformer,
|
||||||
|
patch_positional_embedding,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Model {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let patch_embeds = xs.apply(&self.patch_conv)?;
|
||||||
|
let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
|
||||||
|
self.transformer
|
||||||
|
.forward(&patch_embeds, &self.patch_positional_embedding, None)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user