diff --git a/candle-examples/examples/blip/main.rs b/candle-examples/examples/blip/main.rs
index 81c01482..45300feb 100644
--- a/candle-examples/examples/blip/main.rs
+++ b/candle-examples/examples/blip/main.rs
@@ -11,9 +11,25 @@ use candle::{DType, Device, Result, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::blip;
+use candle_transformers::models::quantized_blip;
 
 use tokenizers::Tokenizer;
 
+enum Model {
+    M(blip::BlipForConditionalGeneration),
+    Q(quantized_blip::BlipForConditionalGeneration),
+}
+
+impl Model {
+    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::M(m) => m.text_decoder().forward(xs, img_xs),
+            Self::Q(m) => m.text_decoder().forward(xs, img_xs),
+        }
+    }
+}
+
+// TODO: Maybe add support for the conditional prompt.
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
@@ -28,6 +44,10 @@ struct Args {
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
+
+    /// Use the quantized version of the model.
+    #[arg(long)]
+    quantized: bool,
 }
 
 const SEP_TOKEN_ID: u32 = 102;
@@ -54,20 +74,20 @@ pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
-    let device = candle_examples::device(args.cpu)?;
-
-    let image = load_image(args.image)?.to_device(&device)?;
-    println!("loaded image {image:?}");
-
     let model_file = match args.model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
-            let api = api.repo(hf_hub::Repo::with_revision(
-                "Salesforce/blip-image-captioning-large".to_string(),
-                hf_hub::RepoType::Model,
-                "refs/pr/18".to_string(),
-            ));
-            api.get("model.safetensors")?
+            if args.quantized {
+                let api = api.model("lmz/candle-blip".to_string());
+                api.get("blip-image-captioning-large-q4k.gguf")?
+            } else {
+                let api = api.repo(hf_hub::Repo::with_revision(
+                    "Salesforce/blip-image-captioning-large".to_string(),
+                    hf_hub::RepoType::Model,
+                    "refs/pr/18".to_string(),
+                ));
+                api.get("model.safetensors")?
+            }
         }
         Some(model) => model.into(),
     };
@@ -84,19 +104,35 @@ pub fn main() -> anyhow::Result<()> {
     let mut logits_processor =
         candle_transformers::generation::LogitsProcessor::new(1337, None, None);
 
-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let config = blip::Config::image_captioning_large();
-    let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
-    println!("model built");
-    // TODO: Maybe add support for the conditional prompt.
-    let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+
+    let (image_embeds, device, mut model) = if args.quantized {
+        let device = Device::Cpu;
+        let image = load_image(args.image)?.to_device(&device)?;
+        println!("loaded image {image:?}");
+
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
+        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+        (image_embeds, device, Model::Q(model))
+    } else {
+        let device = candle_examples::device(args.cpu)?;
+        let image = load_image(args.image)?.to_device(&device)?;
+        println!("loaded image {image:?}");
+
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+        let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
+        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+        (image_embeds, device, Model::M(model))
+    };
 
     let mut token_ids = vec![30522u32];
     for index in 0..1000 {
         let context_size = if index > 0 { 1 } else { token_ids.len() };
         let start_pos = token_ids.len().saturating_sub(context_size);
         let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-        let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
+        let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 6836b9c0..ce576c54 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -10,6 +10,8 @@ pub mod llama;
 pub mod mistral;
 pub mod mixformer;
 pub mod mpt;
+pub mod quantized_blip;
+pub mod quantized_blip_text;
 pub mod quantized_llama;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
diff --git a/candle-transformers/src/models/quantized_blip.rs b/candle-transformers/src/models/quantized_blip.rs
new file mode 100644
index 00000000..6c498aa0
--- /dev/null
+++ b/candle-transformers/src/models/quantized_blip.rs
@@ -0,0 +1,258 @@
+use super::quantized_blip_text as blip_text;
+use crate::quantized_nn::{layer_norm, linear, Linear};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{Module, Result, Tensor, D};
+use candle_nn::{Conv2d, Conv2dConfig, LayerNorm};
+
+pub type VisionConfig = super::blip::VisionConfig;
+pub type Config = super::blip::Config;
+
+#[derive(Debug, Clone)]
+struct VisionEmbeddings {
+    class_embedding: Tensor,
+    patch_embedding: Conv2d,
+    position_embedding: Tensor,
+}
+
+impl VisionEmbeddings {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let class_embedding = vb
+            .get((1, 1, cfg.hidden_size), "class_embedding")?
+            .dequantize(vb.device())?;
+        let conv_cfg = Conv2dConfig {
+            stride: cfg.patch_size,
+            ..Default::default()
+        };
+        let pe_vb = vb.pp("patch_embedding");
+        let pe_weight = pe_vb
+            .get(
+                (cfg.hidden_size, 3, cfg.patch_size, cfg.patch_size),
+                "weight",
+            )?
+            .dequantize(vb.device())?;
+        let pe_bias = pe_vb
+            .get(cfg.hidden_size, "bias")?
+            .dequantize(vb.device())?;
+
+        let patch_embedding = Conv2d::new(pe_weight, Some(pe_bias), conv_cfg);
+        let num_patches1 = cfg.image_size / cfg.patch_size;
+        let num_patches = num_patches1 * num_patches1;
+        let num_positions = num_patches + 1;
+        let position_embedding = vb
+            .get((1, num_positions, cfg.hidden_size), "position_embedding")?
+            .dequantize(vb.device())?;
+        Ok(Self {
+            class_embedding,
+            patch_embedding,
+            position_embedding,
+        })
+    }
+}
+
+impl Module for VisionEmbeddings {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let target_dtype = xs.dtype();
+        let b_size = xs.dim(0)?;
+        let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
+        let d = self.class_embedding.dim(D::Minus1)?;
+        let class_embeds = self
+            .class_embedding
+            .broadcast_as((b_size, 1, d))?
+            .to_dtype(target_dtype)?;
+        let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
+        let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
+        embeddings.broadcast_add(&position_embedding)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Attention {
+    qkv: Linear,
+    projection: Linear,
+    scale: f64,
+    num_heads: usize,
+}
+
+impl Attention {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let embed_dim = cfg.hidden_size;
+        let num_heads = cfg.num_attention_heads;
+        let head_dim = embed_dim / num_heads;
+        let scale = 1f64 / (head_dim as f64).sqrt();
+        let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
+        let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
+        Ok(Self {
+            qkv,
+            projection,
+            scale,
+            num_heads,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
+        let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
+        let mixed_qkv = xs
+            .apply(&self.qkv)?
+            .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
+            .permute((2, 0, 3, 1, 4))?;
+        let query = mixed_qkv.get(0)?;
+        let key = mixed_qkv.get(1)?;
+        let value = mixed_qkv.get(2)?;
+        let attention_scores = query.matmul(&key.t()?)?;
+        let attention_scores = (attention_scores * self.scale)?;
+        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+        let attention_probs = match attn_mask {
+            None => attention_probs,
+            Some(attn_mask) => (attention_probs * attn_mask)?,
+        };
+        attention_probs
+            .matmul(&value)?
+            .permute((0, 2, 1, 3))?
+            .flatten_from(D::Minus2)?
+            .apply(&self.projection)
+    }
+}
+
+#[derive(Debug, Clone)]
+#[allow(clippy::upper_case_acronyms)]
+struct MLP {
+    activation_fn: candle_nn::Activation,
+    fc1: Linear,
+    fc2: Linear,
+}
+
+impl MLP {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
+        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
+        Ok(Self {
+            activation_fn: cfg.hidden_act,
+            fc1,
+            fc2,
+        })
+    }
+}
+
+impl Module for MLP {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.fc1)?
+            .apply(&self.activation_fn)?
+            .apply(&self.fc2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct EncoderLayer {
+    self_attn: Attention,
+    layer_norm1: LayerNorm,
+    mlp: MLP,
+    layer_norm2: LayerNorm,
+}
+
+impl EncoderLayer {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let embed_dim = cfg.hidden_size;
+        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
+        let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
+        let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        Ok(Self {
+            self_attn,
+            layer_norm1,
+            mlp,
+            layer_norm2,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let residual = xs;
+        let xs = xs.apply(&self.layer_norm1)?;
+        let xs = self.self_attn.forward(&xs, attention_mask)?;
+        let xs = (xs + residual)?;
+
+        let residual = &xs;
+        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
+        xs + residual
+    }
+}
+
+#[derive(Debug, Clone)]
+struct Encoder {
+    layers: Vec<EncoderLayer>,
+}
+
+impl Encoder {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        let vb = vb.pp("layers");
+        for i in 0..cfg.num_hidden_layers {
+            let layer = EncoderLayer::new(cfg, vb.pp(i))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+
+    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter() {
+            xs = layer.forward(&xs, attention_mask)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct VisionModel {
+    embeddings: VisionEmbeddings,
+    encoder: Encoder,
+    post_layernorm: LayerNorm,
+}
+
+impl VisionModel {
+    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
+        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
+        let post_layernorm =
+            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            post_layernorm,
+        })
+    }
+}
+
+impl Module for VisionModel {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let xs = xs.apply(&self.embeddings)?;
+        let encoder_outputs = self.encoder.forward(&xs, None)?;
+        // Return the last hidden state rather than pooled outputs.
+        encoder_outputs.apply(&self.post_layernorm)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct BlipForConditionalGeneration {
+    vision_model: VisionModel,
+    text_decoder: blip_text::TextLMHeadModel,
+}
+
+impl BlipForConditionalGeneration {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
+        let text_decoder =
+            blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
+        Ok(Self {
+            vision_model,
+            text_decoder,
+        })
+    }
+
+    pub fn vision_model(&self) -> &VisionModel {
+        &self.vision_model
+    }
+
+    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
+        &mut self.text_decoder
+    }
+}
diff --git a/candle-transformers/src/models/quantized_blip_text.rs b/candle-transformers/src/models/quantized_blip_text.rs
new file mode 100644
index 00000000..652205d6
--- /dev/null
+++ b/candle-transformers/src/models/quantized_blip_text.rs
@@ -0,0 +1,476 @@
+use crate::models::with_tracing::QMatMul;
+use crate::quantized_nn::{layer_norm, linear, Embedding, Linear};
+pub use crate::quantized_var_builder::VarBuilder;
+use candle::{Module, Result, Tensor, D};
+use candle_nn::LayerNorm;
+
+pub type Config = super::blip_text::Config;
+
+#[derive(Debug, Clone)]
+struct TextEmbeddings {
+    word_embedddings: Embedding,
+    position_embeddings: Embedding,
+    layer_norm: LayerNorm,
+    position_ids: Tensor,
+}
+
+impl TextEmbeddings {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let word_embedddings =
+            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
+        let position_embeddings = Embedding::new(
+            cfg.max_position_embeddings,
+            cfg.hidden_size,
+            vb.pp("position_embeddings"),
+        )?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        let position_ids =
+            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
+        Ok(Self {
+            word_embedddings,
+            position_embeddings,
+            layer_norm,
+            position_ids,
+        })
+    }
+
+    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
+        let seq_len = xs.dim(1)?;
+        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
+        let embeddings = self.word_embedddings.forward(xs)?;
+        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
+        (embeddings + position_embeddings)?.apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextSelfAttention {
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    attention_head_size: usize,
+    num_attention_heads: usize,
+    attention_scale: f64,
+    kv_cache: Option<(Tensor, Tensor)>,
+}
+
+impl TextSelfAttention {
+    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
+        let num_attention_heads = cfg.num_attention_heads;
+        let attention_head_size = cfg.hidden_size / num_attention_heads;
+        let all_head_size = cfg.num_attention_heads * attention_head_size;
+        let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
+        let in_size = if is_cross_attention {
+            cfg.encoder_hidden_size
+        } else {
+            cfg.hidden_size
+        };
+        let key = linear(in_size, all_head_size, vb.pp("key"))?;
+        let value = linear(in_size, all_head_size, vb.pp("value"))?;
+        let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
+        Ok(Self {
+            query,
+            key,
+            value,
+            attention_head_size,
+            num_attention_heads,
+            attention_scale,
+            kv_cache: None,
+        })
+    }
+
+    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
+        let (b_size, seq_len, _) = xs.dims3()?;
+        xs.reshape((
+            b_size,
+            seq_len,
+            self.num_attention_heads,
+            self.attention_head_size,
+        ))?
+        .permute((0, 2, 1, 3))
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.kv_cache = None
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let query = self
+            .transpose_for_scores(&self.query.forward(xs)?)?
+            .contiguous()?;
+        let (key, value) = match encoder_hidden_states {
+            None => {
+                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
+                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
+                let (key, value) = match &self.kv_cache {
+                    None => (key, value),
+                    Some((prev_key, prev_value)) => {
+                        let key = Tensor::cat(&[prev_key, &key], 2)?;
+                        let value = Tensor::cat(&[prev_value, &value], 2)?;
+                        (key, value)
+                    }
+                };
+                self.kv_cache = Some((key.clone(), value.clone()));
+                (key, value)
+            }
+            Some(xs) => {
+                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
+                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
+                // no kv-cache in this case, but the results could probably be memoized.
+                (key, value)
+            }
+        };
+        let key = key.contiguous()?;
+        let value = value.contiguous()?;
+        let attention_scores = query.matmul(&key.t()?)?;
+        let attention_scores = (attention_scores * self.attention_scale)?;
+        let attention_scores = match attention_mask {
+            Some(mask) => attention_scores.broadcast_add(mask)?,
+            None => attention_scores,
+        };
+        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+        attention_probs
+            .matmul(&value)?
+            .permute((0, 2, 1, 3))?
+            .flatten_from(D::Minus2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextSelfOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+}
+
+impl TextSelfOutput {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self { dense, layer_norm })
+    }
+
+    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextAttention {
+    self_: TextSelfAttention,
+    output: TextSelfOutput,
+}
+
+impl TextAttention {
+    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
+        let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
+        let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
+        Ok(Self { self_, output })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.self_.reset_kv_cache()
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        encoder_hidden_states: Option<&Tensor>,
+        attention_mask: Option<&Tensor>,
+    ) -> Result<Tensor> {
+        let self_outputs = self
+            .self_
+            .forward(xs, encoder_hidden_states, attention_mask)?;
+        self.output.forward(&self_outputs, xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextIntermediate {
+    dense: Linear,
+    intermediate_act_fn: candle_nn::Activation,
+}
+
+impl TextIntermediate {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
+        Ok(Self {
+            dense,
+            intermediate_act_fn: cfg.hidden_act,
+        })
+    }
+}
+
+impl Module for TextIntermediate {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextOutput {
+    dense: Linear,
+    layer_norm: LayerNorm,
+}
+
+impl TextOutput {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self { dense, layer_norm })
+    }
+
+    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
+        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextLayer {
+    attention: TextAttention,
+    cross_attention: Option<TextAttention>,
+    intermediate: TextIntermediate,
+    output: TextOutput,
+}
+
+impl TextLayer {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
+        let cross_attention = if cfg.is_decoder {
+            Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
+        } else {
+            None
+        };
+        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
+        let output = TextOutput::new(cfg, vb.pp("output"))?;
+        Ok(Self {
+            attention,
+            cross_attention,
+            intermediate,
+            output,
+        })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.attention.reset_kv_cache();
+        if let Some(ca) = &mut self.cross_attention {
+            ca.reset_kv_cache()
+        }
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        encoder_hidden_states: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
+        let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
+        let attention_output = match &mut self.cross_attention {
+            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
+            None => candle::bail!("expected some cross-attn"),
+        };
+        let intermediate_output = self.intermediate.forward(&attention_output)?;
+        self.output.forward(&intermediate_output, &attention_output)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextEncoder {
+    layers: Vec<TextLayer>,
+}
+
+impl TextEncoder {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb = vb.pp("layer");
+        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
+        for i in 0..cfg.num_hidden_layers {
+            let layer = TextLayer::new(cfg, vb.pp(i))?;
+            layers.push(layer)
+        }
+        Ok(Self { layers })
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
+    }
+
+    fn forward(
+        &mut self,
+        xs: &Tensor,
+        encoder_hidden_states: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
+        let mut xs = xs.clone();
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
+        }
+        Ok(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextPooler {
+    dense: Linear,
+}
+
+impl TextPooler {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        Ok(Self { dense })
+    }
+}
+
+impl Module for TextPooler {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.narrow(D::Minus1, 0, 1)?
+            .squeeze(D::Minus1)?
+            .apply(&self.dense)?
+            .tanh()
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextPredictionHeadTransform {
+    dense: Linear,
+    transform_act_fn: candle_nn::Activation,
+    layer_norm: LayerNorm,
+}
+
+impl TextPredictionHeadTransform {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
+        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
+        Ok(Self {
+            dense,
+            transform_act_fn: cfg.hidden_act,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for TextPredictionHeadTransform {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.dense)?
+            .apply(&self.transform_act_fn)?
+            .apply(&self.layer_norm)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextLMPredictionHead {
+    transform: TextPredictionHeadTransform,
+    decoder: Linear,
+}
+
+impl TextLMPredictionHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
+        let weight = QMatMul::new(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
+        let bias = vb.get(cfg.vocab_size, "bias")?.dequantize(vb.device())?;
+        let decoder = Linear::from_weights(weight, Some(bias));
+        Ok(Self { transform, decoder })
+    }
+}
+
+impl Module for TextLMPredictionHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        xs.apply(&self.transform)?.apply(&self.decoder)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextOnlyMLMHead {
+    predictions: TextLMPredictionHead,
+}
+
+impl TextOnlyMLMHead {
+    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
+        Ok(Self { predictions })
+    }
+}
+
+impl Module for TextOnlyMLMHead {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        self.predictions.forward(xs)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct TextModel {
+    embeddings: TextEmbeddings,
+    encoder: TextEncoder,
+    past_kv_len: usize,
+    // We do not need the pooler for caption generation
+}
+
+impl TextModel {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
+        let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
+        Ok(Self {
+            embeddings,
+            encoder,
+            past_kv_len: 0,
+        })
+    }
+
+    fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        encoder_hidden_states: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
+        let (_b_sz, seq_len) = input_ids.dims2()?;
+        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
+        let sequence_output =
+            self.encoder
+                .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
+        self.past_kv_len += seq_len;
+        // We're interested in the sequence-output rather than the pooled-output.
+        Ok(sequence_output)
+    }
+
+    fn reset_kv_cache(&mut self) {
+        self.past_kv_len = 0;
+        self.encoder.reset_kv_cache();
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TextLMHeadModel {
+    bert: TextModel,
+    cls: TextOnlyMLMHead,
+}
+
+impl TextLMHeadModel {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let bert = TextModel::new(cfg, vb.pp("bert"))?;
+        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
+        Ok(Self { bert, cls })
+    }
+
+    pub fn forward(
+        &mut self,
+        input_ids: &Tensor,
+        encoder_hidden_states: &Tensor,
+    ) -> Result<Tensor> {
+        let seq_len = input_ids.dim(1)?;
+        let mask: Vec<_> = (0..seq_len)
+            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
+            .collect();
+        let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
+        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
+        let prediction_scores = self.cls.forward(&sequence_output)?;
+        // return_logits is false so we don't discard the last sequence element.
+        Ok(prediction_scores)
+    }
+
+    pub fn reset_kv_cache(&mut self) {
+        self.bert.reset_kv_cache()
+    }
+}
diff --git a/candle-transformers/src/quantized_nn.rs b/candle-transformers/src/quantized_nn.rs
index 2941c3f0..99e8d45b 100644
--- a/candle-transformers/src/quantized_nn.rs
+++ b/candle-transformers/src/quantized_nn.rs
@@ -34,6 +34,12 @@ pub struct Linear {
     bias: Option<Tensor>,
 }
 
+impl Linear {
+    pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
+        Self { weight, bias }
+    }
+}
+
 impl Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let x = x.apply(&self.weight)?;
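
// ---------------------------------------------------------------------------
// Usage sketch (not part of the patch above): a minimal end-to-end view of the
// quantized path this change introduces, mirroring the quantized branch of
// candle-examples/examples/blip/main.rs. The `caption` helper name and the
// string path argument are illustrative; the start token 30522 and the SEP
// token 102 are the values hard-coded by the example, and the gguf file is
// assumed to already be on disk (e.g. blip-image-captioning-large-q4k.gguf).
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_blip;

fn caption(image: Tensor, gguf_path: &str) -> anyhow::Result<Vec<u32>> {
    // The quantized model runs on CPU and reuses the fp32 BLIP configuration.
    let device = Device::Cpu;
    let config = quantized_blip::Config::image_captioning_large();
    let vb = quantized_blip::VarBuilder::from_gguf(gguf_path)?;
    let mut model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;

    // Encode the image once; the text decoder cross-attends to these embeddings.
    let image_embeds = image
        .to_device(&device)?
        .unsqueeze(0)?
        .apply(model.vision_model())?;

    let mut logits_processor = LogitsProcessor::new(1337, None, None);
    let mut token_ids = vec![30522u32]; // decoder start token used by the example
    for index in 0..1000 {
        // After the first step, only the newly generated token is fed back in,
        // since the text decoder keeps a kv-cache internally.
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
        let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
        let logits = logits.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?;
        let token = logits_processor.sample(&logits)?;
        if token == 102 {
            break; // SEP_TOKEN_ID: end of caption
        }
        token_ids.push(token);
    }
    Ok(token_ids)
}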