Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Add a quantized blip model. (#1155)
* Add a quantized blip model.
* Integrate the quantized blip model into the actual example.
candle-examples/examples/blip/main.rs

@@ -11,9 +11,25 @@ use candle::{DType, Device, Result, Tensor};
 use candle_examples::token_output_stream::TokenOutputStream;
 use candle_nn::VarBuilder;
 use candle_transformers::models::blip;
+use candle_transformers::models::quantized_blip;

 use tokenizers::Tokenizer;

+enum Model {
+    M(blip::BlipForConditionalGeneration),
+    Q(quantized_blip::BlipForConditionalGeneration),
+}
+
+impl Model {
+    fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::M(m) => m.text_decoder().forward(xs, img_xs),
+            Self::Q(m) => m.text_decoder().forward(xs, img_xs),
+        }
+    }
+}
+
 // TODO: Maybe add support for the conditional prompt.
 #[derive(Parser)]
 struct Args {
     #[arg(long)]
@@ -28,6 +44,10 @@ struct Args {
     /// Run on CPU rather than on GPU.
     #[arg(long)]
     cpu: bool,
+
+    /// Use the quantized version of the model.
+    #[arg(long)]
+    quantized: bool,
 }

 const SEP_TOKEN_ID: u32 = 102;
@@ -54,14 +74,13 @@ pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();

-    let device = candle_examples::device(args.cpu)?;
-
-    let image = load_image(args.image)?.to_device(&device)?;
-    println!("loaded image {image:?}");
-
     let model_file = match args.model {
         None => {
             let api = hf_hub::api::sync::Api::new()?;
+            if args.quantized {
+                let api = api.model("lmz/candle-blip".to_string());
+                api.get("blip-image-captioning-large-q4k.gguf")?
+            } else {
+                let api = api.repo(hf_hub::Repo::with_revision(
                     "Salesforce/blip-image-captioning-large".to_string(),
                     hf_hub::RepoType::Model,
@@ -69,6 +88,7 @@ pub fn main() -> anyhow::Result<()> {
                 ));
                 api.get("model.safetensors")?
+            }
         }
         Some(model) => model.into(),
     };
     let tokenizer = match args.tokenizer {
@@ -84,19 +104,35 @@ pub fn main() -> anyhow::Result<()> {
     let mut logits_processor =
         candle_transformers::generation::LogitsProcessor::new(1337, None, None);

-    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let config = blip::Config::image_captioning_large();
-    let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
-    println!("model built");
     // TODO: Maybe add support for the conditional prompt.
+
+    let (image_embeds, device, mut model) = if args.quantized {
+        let device = Device::Cpu;
+        let image = load_image(args.image)?.to_device(&device)?;
+        println!("loaded image {image:?}");
+
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
+        let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
+        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+        (image_embeds, device, Model::Q(model))
+    } else {
+        let device = candle_examples::device(args.cpu)?;
+        let image = load_image(args.image)?.to_device(&device)?;
+        println!("loaded image {image:?}");
+
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
+        let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
+        let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
+        (image_embeds, device, Model::M(model))
+    };

     let mut token_ids = vec![30522u32];
     for index in 0..1000 {
         let context_size = if index > 0 { 1 } else { token_ids.len() };
         let start_pos = token_ids.len().saturating_sub(context_size);
         let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
-        let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
+        let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
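Note (editorial): the hunk ends just before the tail of the generation loop. A sketch of the pattern that typically follows, using the TokenOutputStream imported above; the helper name and wiring are hypothetical, not part of the diff:

use candle_examples::token_output_stream::TokenOutputStream;

// Hypothetical helper for the loop tail: stop at the separator token
// (SEP_TOKEN_ID = 102), otherwise record the token and stream its text.
fn handle_token(
    tos: &mut TokenOutputStream,
    token_ids: &mut Vec<u32>,
    token: u32,
) -> anyhow::Result<bool> {
    if token == 102 {
        return Ok(true); // end of caption reached
    }
    token_ids.push(token);
    if let Some(text) = tos.next_token(token)? {
        use std::io::Write;
        print!("{text}");
        std::io::stdout().flush()?;
    }
    Ok(false)
}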
candle-transformers/src/models/mod.rs

@@ -10,6 +10,8 @@ pub mod llama;
 pub mod mistral;
 pub mod mixformer;
 pub mod mpt;
+pub mod quantized_blip;
+pub mod quantized_blip_text;
 pub mod quantized_llama;
 pub mod quantized_mistral;
 pub mod quantized_mixformer;
candle-transformers/src/models/quantized_blip.rs (new file, 258 lines)

@@ -0,0 +1,258 @@
use super::quantized_blip_text as blip_text;
use crate::quantized_nn::{layer_norm, linear, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor, D};
use candle_nn::{Conv2d, Conv2dConfig, LayerNorm};

pub type VisionConfig = super::blip::VisionConfig;
pub type Config = super::blip::Config;

#[derive(Debug, Clone)]
struct VisionEmbeddings {
    class_embedding: Tensor,
    patch_embedding: Conv2d,
    position_embedding: Tensor,
}

impl VisionEmbeddings {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let class_embedding = vb
            .get((1, 1, cfg.hidden_size), "class_embedding")?
            .dequantize(vb.device())?;
        let conv_cfg = Conv2dConfig {
            stride: cfg.patch_size,
            ..Default::default()
        };
        let pe_vb = vb.pp("patch_embedding");
        let pe_weight = pe_vb
            .get(
                (cfg.hidden_size, 3, cfg.patch_size, cfg.patch_size),
                "weight",
            )?
            .dequantize(vb.device())?;
        let pe_bias = pe_vb
            .get(cfg.hidden_size, "bias")?
            .dequantize(vb.device())?;

        let patch_embedding = Conv2d::new(pe_weight, Some(pe_bias), conv_cfg);
        let num_patches1 = cfg.image_size / cfg.patch_size;
        let num_patches = num_patches1 * num_patches1;
        let num_positions = num_patches + 1;
        let position_embedding = vb
            .get((1, num_positions, cfg.hidden_size), "position_embedding")?
            .dequantize(vb.device())?;
        Ok(Self {
            class_embedding,
            patch_embedding,
            position_embedding,
        })
    }
}

impl Module for VisionEmbeddings {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let target_dtype = xs.dtype();
        let b_size = xs.dim(0)?;
        let patch_embeds = xs.apply(&self.patch_embedding)?.flatten_from(2)?.t()?;
        let d = self.class_embedding.dim(D::Minus1)?;
        let class_embeds = self
            .class_embedding
            .broadcast_as((b_size, 1, d))?
            .to_dtype(target_dtype)?;
        let embeddings = Tensor::cat(&[&class_embeds, &patch_embeds], 1)?;
        let position_embedding = self.position_embedding.narrow(1, 0, embeddings.dim(1)?)?;
        embeddings.broadcast_add(&position_embedding)
    }
}

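Editorial note on the loading pattern above: only the matmul-heavy layers stay quantized (via linear / QMatMul); tensors consumed by conv2d, layer-norm, or broadcast ops are dequantized once at load time. A minimal sketch of that pattern, assuming the VarBuilder re-exported by this module:

use candle::{Result, Tensor};
use candle_transformers::models::quantized_blip::VarBuilder;

// Fetch a quantized tensor by shape and name, then materialize it as a
// regular f32 tensor on the builder's device so non-matmul ops can use it.
fn load_dequantized(vb: &VarBuilder, hidden_size: usize) -> Result<Tensor> {
    vb.get((1, 1, hidden_size), "class_embedding")?
        .dequantize(vb.device())
}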
#[derive(Debug, Clone)]
struct Attention {
    qkv: Linear,
    projection: Linear,
    scale: f64,
    num_heads: usize,
}

impl Attention {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let head_dim = embed_dim / num_heads;
        let scale = 1f64 / (head_dim as f64).sqrt();
        let qkv = linear(embed_dim, 3 * embed_dim, vb.pp("qkv"))?;
        let projection = linear(embed_dim, embed_dim, vb.pp("projection"))?;
        Ok(Self {
            qkv,
            projection,
            scale,
            num_heads,
        })
    }

    fn forward(&self, xs: &Tensor, attn_mask: Option<&Tensor>) -> Result<Tensor> {
        let (b_sz, tgt_len, embed_dim) = xs.dims3()?;
        let mixed_qkv = xs
            .apply(&self.qkv)?
            .reshape((b_sz, tgt_len, 3, self.num_heads, embed_dim / self.num_heads))?
            .permute((2, 0, 3, 1, 4))?;
        let query = mixed_qkv.get(0)?;
        let key = mixed_qkv.get(1)?;
        let value = mixed_qkv.get(2)?;
        let attention_scores = query.matmul(&key.t()?)?;
        let attention_scores = (attention_scores * self.scale)?;
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        let attention_probs = match attn_mask {
            None => attention_probs,
            Some(attn_mask) => (attention_probs * attn_mask)?,
        };
        attention_probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .flatten_from(D::Minus2)?
            .apply(&self.projection)
    }
}

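As an editorial aid, the tensor-shape bookkeeping for the fused qkv projection above, with head_dim = embed_dim / num_heads:

// xs:                        (b_sz, tgt_len, embed_dim)
// xs.apply(&self.qkv):       (b_sz, tgt_len, 3 * embed_dim)
// .reshape(..):              (b_sz, tgt_len, 3, num_heads, head_dim)
// .permute((2, 0, 3, 1, 4)): (3, b_sz, num_heads, tgt_len, head_dim)
// mixed_qkv.get(0|1|2):      query/key/value, each (b_sz, num_heads, tgt_len, head_dim)
// query.matmul(&key.t()?):   (b_sz, num_heads, tgt_len, tgt_len) attention scores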
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    activation_fn: candle_nn::Activation,
    fc1: Linear,
    fc2: Linear,
}

impl MLP {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?;
        let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?;
        Ok(Self {
            activation_fn: cfg.hidden_act,
            fc1,
            fc2,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.fc1)?
            .apply(&self.activation_fn)?
            .apply(&self.fc2)
    }
}

#[derive(Debug, Clone)]
struct EncoderLayer {
    self_attn: Attention,
    layer_norm1: LayerNorm,
    mlp: MLP,
    layer_norm2: LayerNorm,
}

impl EncoderLayer {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let embed_dim = cfg.hidden_size;
        let self_attn = Attention::new(cfg, vb.pp("self_attn"))?;
        let layer_norm1 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm1"))?;
        let layer_norm2 = layer_norm(embed_dim, cfg.layer_norm_eps, vb.pp("layer_norm2"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        Ok(Self {
            self_attn,
            layer_norm1,
            mlp,
            layer_norm2,
        })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let residual = xs;
        let xs = xs.apply(&self.layer_norm1)?;
        let xs = self.self_attn.forward(&xs, attention_mask)?;
        let xs = (xs + residual)?;

        let residual = &xs;
        let xs = xs.apply(&self.layer_norm2)?.apply(&self.mlp)?;
        xs + residual
    }
}

#[derive(Debug, Clone)]
struct Encoder {
    layers: Vec<EncoderLayer>,
}

impl Encoder {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb = vb.pp("layers");
        for i in 0..cfg.num_hidden_layers {
            let layer = EncoderLayer::new(cfg, vb.pp(i))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    fn forward(&self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, attention_mask)?
        }
        Ok(xs)
    }
}

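Editorial note: EncoderLayer is a pre-norm transformer block; each sub-layer normalizes its input before the residual add, i.e., as comments:

// attention sub-block:    xs = self_attn(layer_norm1(xs)) + xs
// feed-forward sub-block: xs = mlp(layer_norm2(xs)) + xs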
#[derive(Debug, Clone)]
pub struct VisionModel {
    embeddings: VisionEmbeddings,
    encoder: Encoder,
    post_layernorm: LayerNorm,
}

impl VisionModel {
    fn new(cfg: &VisionConfig, vb: VarBuilder) -> Result<Self> {
        let embeddings = VisionEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = Encoder::new(cfg, vb.pp("encoder"))?;
        let post_layernorm =
            layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("post_layernorm"))?;
        Ok(Self {
            embeddings,
            encoder,
            post_layernorm,
        })
    }
}

impl Module for VisionModel {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.apply(&self.embeddings)?;
        let encoder_outputs = self.encoder.forward(&xs, None)?;
        // Return the last hidden state rather than pooled outputs.
        encoder_outputs.apply(&self.post_layernorm)
    }
}

#[derive(Debug, Clone)]
pub struct BlipForConditionalGeneration {
    vision_model: VisionModel,
    text_decoder: blip_text::TextLMHeadModel,
}

impl BlipForConditionalGeneration {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vision_model = VisionModel::new(&cfg.vision_config, vb.pp("vision_model"))?;
        let text_decoder =
            blip_text::TextLMHeadModel::new(&cfg.text_config, vb.pp("text_decoder"))?;
        Ok(Self {
            vision_model,
            text_decoder,
        })
    }

    pub fn vision_model(&self) -> &VisionModel {
        &self.vision_model
    }

    pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
        &mut self.text_decoder
    }
}
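Putting the new module together, a minimal end-to-end sketch mirroring the example change above. Hedged: the gguf path is a placeholder (the example downloads it from the lmz/candle-blip repo on the Hugging Face hub), and the helper name is hypothetical:

use candle_transformers::models::quantized_blip;

fn image_embeddings(image: &candle::Tensor) -> anyhow::Result<candle::Tensor> {
    // Placeholder path for the quantized weights.
    let vb = quantized_blip::VarBuilder::from_gguf("blip-image-captioning-large-q4k.gguf")?;
    let config = quantized_blip::Config::image_captioning_large();
    let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
    // The vision tower runs once per image; its output then conditions every
    // decoding step through the text decoder's cross-attention.
    Ok(image.unsqueeze(0)?.apply(model.vision_model())?)
}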
candle-transformers/src/models/quantized_blip_text.rs (new file, 476 lines)

@@ -0,0 +1,476 @@
use crate::models::with_tracing::QMatMul;
use crate::quantized_nn::{layer_norm, linear, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{Module, Result, Tensor, D};
use candle_nn::LayerNorm;

pub type Config = super::blip_text::Config;

#[derive(Debug, Clone)]
struct TextEmbeddings {
    word_embeddings: Embedding,
    position_embeddings: Embedding,
    layer_norm: LayerNorm,
    position_ids: Tensor,
}

impl TextEmbeddings {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let word_embeddings =
            Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
        let position_embeddings = Embedding::new(
            cfg.max_position_embeddings,
            cfg.hidden_size,
            vb.pp("position_embeddings"),
        )?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        let position_ids =
            Tensor::arange(0, cfg.max_position_embeddings as u32, vb.device())?.unsqueeze(0)?;
        Ok(Self {
            word_embeddings,
            position_embeddings,
            layer_norm,
            position_ids,
        })
    }

    fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
        let seq_len = xs.dim(1)?;
        let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
        let embeddings = self.word_embeddings.forward(xs)?;
        let position_embeddings = self.position_embeddings.forward(&position_ids)?;
        (embeddings + position_embeddings)?.apply(&self.layer_norm)
    }
}

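Editorial note: with the kv-cache, only new tokens are embedded on each step, so the position ids are offset by the cache length. For example (illustrative values):

// prompt step:     past_kv_len = 0, seq_len = 4 -> position_ids = [[0, 1, 2, 3]]
// generation step: past_kv_len = 5, seq_len = 1 -> position_ids = [[5]]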
#[derive(Debug, Clone)]
struct TextSelfAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    attention_head_size: usize,
    num_attention_heads: usize,
    attention_scale: f64,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl TextSelfAttention {
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        let num_attention_heads = cfg.num_attention_heads;
        let attention_head_size = cfg.hidden_size / num_attention_heads;
        let all_head_size = cfg.num_attention_heads * attention_head_size;
        let query = linear(cfg.hidden_size, all_head_size, vb.pp("query"))?;
        let in_size = if is_cross_attention {
            cfg.encoder_hidden_size
        } else {
            cfg.hidden_size
        };
        let key = linear(in_size, all_head_size, vb.pp("key"))?;
        let value = linear(in_size, all_head_size, vb.pp("value"))?;
        let attention_scale = 1f64 / (attention_head_size as f64).sqrt();
        Ok(Self {
            query,
            key,
            value,
            attention_head_size,
            num_attention_heads,
            attention_scale,
            kv_cache: None,
        })
    }

    fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
        let (b_size, seq_len, _) = xs.dims3()?;
        xs.reshape((
            b_size,
            seq_len,
            self.num_attention_heads,
            self.attention_head_size,
        ))?
        .permute((0, 2, 1, 3))
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let query = self
            .transpose_for_scores(&self.query.forward(xs)?)?
            .contiguous()?;
        let (key, value) = match encoder_hidden_states {
            None => {
                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
                let (key, value) = match &self.kv_cache {
                    None => (key, value),
                    Some((prev_key, prev_value)) => {
                        let key = Tensor::cat(&[prev_key, &key], 2)?;
                        let value = Tensor::cat(&[prev_value, &value], 2)?;
                        (key, value)
                    }
                };
                self.kv_cache = Some((key.clone(), value.clone()));
                (key, value)
            }
            Some(xs) => {
                let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
                let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
                // no kv-cache in this case, but the results could probably be memoized.
                (key, value)
            }
        };
        let key = key.contiguous()?;
        let value = value.contiguous()?;
        let attention_scores = query.matmul(&key.t()?)?;
        let attention_scores = (attention_scores * self.attention_scale)?;
        let attention_scores = match attention_mask {
            Some(mask) => attention_scores.broadcast_add(mask)?,
            None => attention_scores,
        };
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
        attention_probs
            .matmul(&value)?
            .permute((0, 2, 1, 3))?
            .flatten_from(D::Minus2)
    }
}

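Editorial summary of the two cache behaviours in the match above, as comments:

// Self-attention (encoder_hidden_states == None): the freshly projected
// key/value are concatenated onto the cached ones along the sequence dim (2)
// and the extended pair is stored back, so each step only projects new tokens.
// Cross-attention (Some(..)): key/value come from the fixed image embeddings
// and are recomputed on every step; as the code notes, they could be memoized.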
#[derive(Debug, Clone)]
struct TextSelfOutput {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl TextSelfOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layer_norm })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        (xs.apply(&self.dense) + input_tensor)?.apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextAttention {
    self_: TextSelfAttention,
    output: TextSelfOutput,
}

impl TextAttention {
    fn new(cfg: &Config, is_cross_attention: bool, vb: VarBuilder) -> Result<Self> {
        let self_ = TextSelfAttention::new(cfg, is_cross_attention, vb.pp("self"))?;
        let output = TextSelfOutput::new(cfg, vb.pp("output"))?;
        Ok(Self { self_, output })
    }

    fn reset_kv_cache(&mut self) {
        self.self_.reset_kv_cache()
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: Option<&Tensor>,
        attention_mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let self_outputs = self
            .self_
            .forward(xs, encoder_hidden_states, attention_mask)?;
        self.output.forward(&self_outputs, xs)
    }
}

#[derive(Debug, Clone)]
struct TextIntermediate {
    dense: Linear,
    intermediate_act_fn: candle_nn::Activation,
}

impl TextIntermediate {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("dense"))?;
        Ok(Self {
            dense,
            intermediate_act_fn: cfg.hidden_act,
        })
    }
}

impl Module for TextIntermediate {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)?.apply(&self.intermediate_act_fn)
    }
}

#[derive(Debug, Clone)]
struct TextOutput {
    dense: Linear,
    layer_norm: LayerNorm,
}

impl TextOutput {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self { dense, layer_norm })
    }

    fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
        (xs.apply(&self.dense)? + input_tensor)?.apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextLayer {
    attention: TextAttention,
    cross_attention: Option<TextAttention>,
    intermediate: TextIntermediate,
    output: TextOutput,
}

impl TextLayer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let attention = TextAttention::new(cfg, false, vb.pp("attention"))?;
        let cross_attention = if cfg.is_decoder {
            Some(TextAttention::new(cfg, true, vb.pp("crossattention"))?)
        } else {
            None
        };
        let intermediate = TextIntermediate::new(cfg, vb.pp("intermediate"))?;
        let output = TextOutput::new(cfg, vb.pp("output"))?;
        Ok(Self {
            attention,
            cross_attention,
            intermediate,
            output,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.attention.reset_kv_cache();
        if let Some(ca) = &mut self.cross_attention {
            ca.reset_kv_cache()
        }
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
        let attention_output = match &mut self.cross_attention {
            Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
            None => candle::bail!("expected some cross-attn"),
        };
        let intermediate_output = self.intermediate.forward(&attention_output)?;
        self.output.forward(&intermediate_output, &attention_output)
    }
}

#[derive(Debug, Clone)]
struct TextEncoder {
    layers: Vec<TextLayer>,
}

impl TextEncoder {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("layer");
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        for i in 0..cfg.num_hidden_layers {
            let layer = TextLayer::new(cfg, vb.pp(i))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    fn reset_kv_cache(&mut self) {
        self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
        }
        Ok(xs)
    }
}

#[derive(Debug, Clone)]
pub struct TextPooler {
    dense: Linear,
}

impl TextPooler {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        Ok(Self { dense })
    }
}

impl Module for TextPooler {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Pool by taking the hidden state of the first token. (The original
        // narrowed on the last dim, which would select a single feature
        // rather than a token; this pooler is unused for caption generation.)
        xs.narrow(1, 0, 1)?
            .squeeze(1)?
            .apply(&self.dense)?
            .tanh()
    }
}

#[derive(Debug, Clone)]
struct TextPredictionHeadTransform {
    dense: Linear,
    transform_act_fn: candle_nn::Activation,
    layer_norm: LayerNorm,
}

impl TextPredictionHeadTransform {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
        let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
        Ok(Self {
            dense,
            transform_act_fn: cfg.hidden_act,
            layer_norm,
        })
    }
}

impl Module for TextPredictionHeadTransform {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense)?
            .apply(&self.transform_act_fn)?
            .apply(&self.layer_norm)
    }
}

#[derive(Debug, Clone)]
struct TextLMPredictionHead {
    transform: TextPredictionHeadTransform,
    decoder: Linear,
}

impl TextLMPredictionHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let transform = TextPredictionHeadTransform::new(cfg, vb.pp("transform"))?;
        let weight = QMatMul::new(cfg.hidden_size, cfg.vocab_size, vb.pp("decoder"))?;
        let bias = vb.get(cfg.vocab_size, "bias")?.dequantize(vb.device())?;
        let decoder = Linear::from_weights(weight, Some(bias));
        Ok(Self { transform, decoder })
    }
}

impl Module for TextLMPredictionHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.transform)?.apply(&self.decoder)
    }
}

#[derive(Debug, Clone)]
struct TextOnlyMLMHead {
    predictions: TextLMPredictionHead,
}

impl TextOnlyMLMHead {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let predictions = TextLMPredictionHead::new(cfg, vb.pp("predictions"))?;
        Ok(Self { predictions })
    }
}

impl Module for TextOnlyMLMHead {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.predictions.forward(xs)
    }
}

#[derive(Debug, Clone)]
struct TextModel {
    embeddings: TextEmbeddings,
    encoder: TextEncoder,
    past_kv_len: usize,
    // We do not need the pooler for caption generation.
}

impl TextModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embeddings = TextEmbeddings::new(cfg, vb.pp("embeddings"))?;
        let encoder = TextEncoder::new(cfg, vb.pp("encoder"))?;
        Ok(Self {
            embeddings,
            encoder,
            past_kv_len: 0,
        })
    }

    fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
        attention_mask: &Tensor,
    ) -> Result<Tensor> {
        let (_b_sz, seq_len) = input_ids.dims2()?;
        let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
        let sequence_output =
            self.encoder
                .forward(&embedding_output, encoder_hidden_states, attention_mask)?;
        self.past_kv_len += seq_len;
        // We're interested in the sequence-output rather than the pooled-output.
        Ok(sequence_output)
    }

    fn reset_kv_cache(&mut self) {
        self.past_kv_len = 0;
        self.encoder.reset_kv_cache();
    }
}

#[derive(Debug, Clone)]
pub struct TextLMHeadModel {
    bert: TextModel,
    cls: TextOnlyMLMHead,
}

impl TextLMHeadModel {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let bert = TextModel::new(cfg, vb.pp("bert"))?;
        let cls = TextOnlyMLMHead::new(cfg, vb.pp("cls"))?;
        Ok(Self { bert, cls })
    }

    pub fn forward(
        &mut self,
        input_ids: &Tensor,
        encoder_hidden_states: &Tensor,
    ) -> Result<Tensor> {
        let seq_len = input_ids.dim(1)?;
        let mask: Vec<_> = (0..seq_len)
            .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
            .collect();
        let mask = Tensor::from_vec(mask, (seq_len, seq_len), input_ids.device())?;
        let sequence_output = self.bert.forward(input_ids, encoder_hidden_states, &mask)?;
        let prediction_scores = self.cls.forward(&sequence_output)?;
        // return_logits is false so we don't discard the last sequence element.
        Ok(prediction_scores)
    }

    pub fn reset_kv_cache(&mut self) {
        self.bert.reset_kv_cache()
    }
}
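Editorial illustration of the causal mask built in forward above, for seq_len = 3 (entries with j > i get f32::NEG_INFINITY, so after the broadcast add and softmax each token only attends to itself and earlier positions):

//   [[0.0, -inf, -inf],
//    [0.0,  0.0, -inf],
//    [0.0,  0.0,  0.0]]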
candle-transformers/src/quantized_nn.rs

@@ -34,6 +34,12 @@ pub struct Linear {
     bias: Option<Tensor>,
 }

+impl Linear {
+    pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
+        Self { weight, bias }
+    }
+}
+
 impl Module for Linear {
     fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
         let x = x.apply(&self.weight)?;
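The page cuts off inside Linear::forward. A hedged sketch of how that forward plausibly continues (reconstructed from context, not from the diff): apply the quantized matmul, then broadcast-add the optional bias. This constructor is what lets TextLMPredictionHead above build the LM head from a raw QMatMul plus a dequantized bias.

use candle::{Module, Tensor};
use candle_transformers::models::with_tracing::QMatMul;

// Standalone re-statement of the pattern; names mirror the struct above.
fn linear_forward(weight: &QMatMul, bias: Option<&Tensor>, x: &Tensor) -> candle::Result<Tensor> {
    let x = x.apply(weight)?;
    match bias {
        None => Ok(x),
        Some(bias) => x.broadcast_add(bias),
    }
}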