mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Distibert (#1366)
* add bce with logit loss * add bce with logit loss * remove imports * fix tiny bug * add test documentation and refactor function * fix test cases and formatting * distilbet files * Apply various cleanups. * More cleanups. * More polish. --------- Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
22
candle-examples/examples/distilbert/README.md
Normal file
22
candle-examples/examples/distilbert/README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# candle-distilbert
|
||||
|
||||
DistilBert is a distiled version of the Bert model.
|
||||
|
||||
## Sentence embeddings
|
||||
|
||||
DistilBert is used to compute the sentence embeddings for a prompt. The model weights
|
||||
are downloaded from the hub on the first run.
|
||||
|
||||
```bash
|
||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
|
||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
|
||||
> ...
|
||||
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
|
||||
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
|
||||
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
|
||||
> Tensor[[1, 7, 768], f32]
|
||||
|
||||
```
|
135
candle-examples/examples/distilbert/main.rs
Normal file
135
candle-examples/examples/distilbert/main.rs
Normal file
@ -0,0 +1,135 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
/// When set, compute embeddings for this prompt.
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// Use the pytorch weights rather than the safetensors ones
|
||||
#[arg(long)]
|
||||
use_pth: bool,
|
||||
|
||||
/// The number of times to run the prompt.
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
let default_model = "distilbert-base-uncased".to_string();
|
||||
let default_revision = "main".to_string();
|
||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
|
||||
let vb = if self.use_pth {
|
||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||
};
|
||||
let model = DistilBertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
|
||||
let tokenizer = tokenizer
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mask = get_mask(tokens.len(), device);
|
||||
|
||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
||||
|
||||
let ys = model.forward(&token_ids, &mask)?;
|
||||
println!("{ys}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||
}
|
342
candle-transformers/src/models/distilbert.rs
Normal file
342
candle-transformers/src/models/distilbert.rs
Normal file
@ -0,0 +1,342 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum HiddenAct {
|
||||
Gelu,
|
||||
Relu,
|
||||
}
|
||||
|
||||
struct HiddenActLayer {
|
||||
act: HiddenAct,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl HiddenActLayer {
|
||||
fn new(act: HiddenAct) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
|
||||
Self { act, span }
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for HiddenActLayer {
|
||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match self.act {
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||
HiddenAct::Gelu => xs.gelu(),
|
||||
HiddenAct::Relu => xs.relu(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum PositionEmbeddingType {
|
||||
#[default]
|
||||
Absolute,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
dim: usize,
|
||||
n_layers: usize,
|
||||
n_heads: usize,
|
||||
hidden_dim: usize,
|
||||
activation: HiddenAct,
|
||||
max_position_embeddings: usize,
|
||||
initializer_range: f64,
|
||||
pad_token_id: usize,
|
||||
#[serde(default)]
|
||||
position_embedding_type: PositionEmbeddingType,
|
||||
#[serde(default)]
|
||||
use_cache: bool,
|
||||
model_type: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab_size: 30522,
|
||||
dim: 768,
|
||||
n_layers: 12,
|
||||
n_heads: 12,
|
||||
hidden_dim: 3072,
|
||||
activation: HiddenAct::Gelu,
|
||||
max_position_embeddings: 512,
|
||||
initializer_range: 0.02,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
model_type: Some("distilbert".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Embeddings {
|
||||
word_embeddings: Embedding,
|
||||
position_embeddings: Embedding,
|
||||
layer_norm: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Embeddings {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let word_embeddings =
|
||||
candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
|
||||
let position_embeddings = candle_nn::embedding(
|
||||
config.max_position_embeddings,
|
||||
config.dim,
|
||||
vb.pp("position_embeddings"),
|
||||
)?;
|
||||
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
position_embeddings,
|
||||
layer_norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_bsize, seq_len) = input_ids.dims2()?;
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
||||
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
||||
let embeddings =
|
||||
input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
|
||||
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
struct MultiHeadSelfAttention {
|
||||
q_lin: Linear,
|
||||
k_lin: Linear,
|
||||
v_lin: Linear,
|
||||
out_lin: Linear,
|
||||
n_heads: usize,
|
||||
attention_head_size: usize,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttention {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention_head_size = config.dim / config.n_heads;
|
||||
let all_head_size = config.n_heads * attention_head_size;
|
||||
let dim = config.dim;
|
||||
let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
|
||||
let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
|
||||
let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
|
||||
let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
|
||||
Ok(Self {
|
||||
q_lin,
|
||||
k_lin,
|
||||
v_lin,
|
||||
out_lin,
|
||||
n_heads: config.n_heads,
|
||||
attention_head_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attention"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttention {
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (bs, q_length, _dim) = hidden_states.dims3()?;
|
||||
|
||||
let dim_per_head = self.attention_head_size;
|
||||
let q = self.q_lin.forward(hidden_states)?;
|
||||
let k = self.k_lin.forward(hidden_states)?;
|
||||
let v = self.v_lin.forward(hidden_states)?;
|
||||
|
||||
let q = q
|
||||
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
||||
.transpose(1, 2)?;
|
||||
let k = k
|
||||
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
||||
.transpose(1, 2)?;
|
||||
let v = v
|
||||
.reshape((bs, q_length, self.n_heads, dim_per_head))?
|
||||
.transpose(1, 2)?;
|
||||
|
||||
let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
|
||||
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
|
||||
let mask = attention_mask.broadcast_as(scores.shape())?;
|
||||
|
||||
let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
|
||||
let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
|
||||
|
||||
let context = weights.matmul(&v.contiguous()?)?;
|
||||
let context = context
|
||||
.transpose(1, 2)?
|
||||
.reshape((bs, q_length, self.n_heads * dim_per_head))?
|
||||
.contiguous()?;
|
||||
let context = self.out_lin.forward(&context)?;
|
||||
|
||||
Ok(context)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::upper_case_acronyms)]
|
||||
struct FFN {
|
||||
lin1: Linear,
|
||||
lin2: Linear,
|
||||
activation: HiddenActLayer,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl FFN {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
|
||||
let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
|
||||
Ok(Self {
|
||||
lin1,
|
||||
lin2,
|
||||
activation: HiddenActLayer::new(config.activation),
|
||||
span: tracing::span!(tracing::Level::TRACE, "ffn"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for FFN {
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
hidden_states
|
||||
.apply(&self.lin1)?
|
||||
.apply(&self.activation)?
|
||||
.apply(&self.lin2)
|
||||
}
|
||||
}
|
||||
|
||||
struct TransformerBlock {
|
||||
attention: MultiHeadSelfAttention,
|
||||
sa_layer_norm: LayerNorm,
|
||||
ffn: FFN,
|
||||
output_layer_norm: LayerNorm,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl TransformerBlock {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
|
||||
let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
|
||||
let ffn = FFN::load(vb.pp("ffn"), config)?;
|
||||
let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
|
||||
Ok(Self {
|
||||
attention,
|
||||
sa_layer_norm,
|
||||
ffn,
|
||||
output_layer_norm,
|
||||
span: tracing::span!(tracing::Level::TRACE, "layer"),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TransformerBlock {
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let sa_output = self.attention.forward(hidden_states, attention_mask)?;
|
||||
// TODO: Support cross-attention?
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
||||
// TODO: Support something similar to `apply_chunking_to_forward`?
|
||||
let sa_output = sa_output.broadcast_add(hidden_states)?;
|
||||
let sa_output = self.sa_layer_norm.forward(&sa_output)?;
|
||||
|
||||
let ffn_output = self.ffn.forward(&sa_output)?;
|
||||
let ffn_output = (&ffn_output + sa_output)?;
|
||||
let output = self.output_layer_norm.forward(&ffn_output)?;
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
|
||||
struct Transformer {
|
||||
layers: Vec<TransformerBlock>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Transformer {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.n_layers)
|
||||
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||
Ok(Transformer { layers, span })
|
||||
}
|
||||
}
|
||||
|
||||
impl Transformer {
|
||||
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut hidden_states = hidden_states.clone();
|
||||
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
||||
for layer in self.layers.iter() {
|
||||
hidden_states = layer.forward(&hidden_states, attention_mask)?;
|
||||
}
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DistilBertModel {
|
||||
embeddings: Embeddings,
|
||||
transformer: Transformer,
|
||||
pub device: Device,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl DistilBertModel {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let (embeddings, transformer) = match (
|
||||
Embeddings::load(vb.pp("embeddings"), config),
|
||||
Transformer::load(vb.pp("transformer"), config),
|
||||
) {
|
||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let Some(model_type) = &config.model_type {
|
||||
if let (Ok(embeddings), Ok(encoder)) = (
|
||||
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
|
||||
) {
|
||||
(embeddings, encoder)
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
transformer,
|
||||
device: vb.device().clone(),
|
||||
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let embedding_output = self.embeddings.forward(input_ids)?;
|
||||
let sequence_output = self
|
||||
.transformer
|
||||
.forward(&embedding_output, attention_mask)?;
|
||||
Ok(sequence_output)
|
||||
}
|
||||
}
|
@ -4,6 +4,7 @@ pub mod blip;
|
||||
pub mod blip_text;
|
||||
pub mod convmixer;
|
||||
pub mod dinov2;
|
||||
pub mod distilbert;
|
||||
pub mod efficientnet;
|
||||
pub mod falcon;
|
||||
pub mod jina_bert;
|
||||
|
Reference in New Issue
Block a user