* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* distilbert files

* Apply various cleanups.

* More cleanups.

* More polish.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
Odunayo
2023-11-24 10:09:14 -05:00
committed by GitHub
parent ca19a9af62
commit 762e996ce6
4 changed files with 500 additions and 0 deletions

@@ -0,0 +1,22 @@
# candle-distilbert

DistilBert is a distilled version of the Bert model.

## Sentence embeddings

DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
> ...
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
> Tensor[[1, 7, 768], f32]
```
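
The model prints per-token embeddings (a `[1, 7, 768]` tensor for this prompt). Below is a minimal sketch, assuming a `ys` tensor shaped like the output above, of mean-pooling the tokens into a single sentence embedding and L2-normalizing it; this helper is not part of the example binary.

```rust
use candle::Tensor;

/// Mean-pools token embeddings into one sentence embedding and L2-normalizes it.
/// `ys` is assumed to have shape [batch, n_tokens, hidden].
fn sentence_embedding(ys: &Tensor) -> candle::Result<Tensor> {
    let (_batch, n_tokens, _hidden) = ys.dims3()?;
    // Average over the token dimension -> [batch, hidden].
    let pooled = (ys.sum(1)? / n_tokens as f64)?;
    // Divide each row by its L2 norm, mirroring `normalize_l2` in the example.
    pooled.broadcast_div(&pooled.sqr()?.sum_keepdim(1)?.sqrt()?)
}
```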

@@ -0,0 +1,135 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The model to use; see the available models at https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: String,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}
impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}
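// Builds a (size, size) mask where entries with column index j > row index i are 1;
// those positions are later replaced with -inf in the attention scores. For size = 3:
// [[0, 1, 1],
//  [0, 0, 1],
//  [0, 0, 0]]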
fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());
let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");
Ok(())
}
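/// Divides each row of a 2D tensor by its L2 norm (sqrt of the sum of squares along dim 1).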
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}

@@ -0,0 +1,342 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
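// Replaces the elements of `on_false` where `mask` is non-zero with `on_true`
// (used below with -inf to mask out attention scores).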
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
Gelu,
Relu,
}
struct HiddenActLayer {
act: HiddenAct,
span: tracing::Span,
}
impl HiddenActLayer {
fn new(act: HiddenAct) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
Self { act, span }
}
}
impl Module for HiddenActLayer {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
match self.act {
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
HiddenAct::Gelu => xs.gelu(),
HiddenAct::Relu => xs.relu(),
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
enum PositionEmbeddingType {
#[default]
Absolute,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
vocab_size: usize,
dim: usize,
n_layers: usize,
n_heads: usize,
hidden_dim: usize,
activation: HiddenAct,
max_position_embeddings: usize,
initializer_range: f64,
pad_token_id: usize,
#[serde(default)]
position_embedding_type: PositionEmbeddingType,
#[serde(default)]
use_cache: bool,
model_type: Option<String>,
}
impl Default for Config {
fn default() -> Self {
Self {
vocab_size: 30522,
dim: 768,
n_layers: 12,
n_heads: 12,
hidden_dim: 3072,
activation: HiddenAct::Gelu,
max_position_embeddings: 512,
initializer_range: 0.02,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Absolute,
use_cache: true,
model_type: Some("distilbert".to_string()),
}
}
}
struct Embeddings {
word_embeddings: Embedding,
position_embeddings: Embedding,
layer_norm: LayerNorm,
span: tracing::Span,
}
impl Embeddings {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let word_embeddings =
candle_nn::embedding(config.vocab_size, config.dim, vb.pp("word_embeddings"))?;
let position_embeddings = candle_nn::embedding(
config.max_position_embeddings,
config.dim,
vb.pp("position_embeddings"),
)?;
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("LayerNorm"))?;
Ok(Self {
word_embeddings,
position_embeddings,
layer_norm,
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (_bsize, seq_len) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
let embeddings =
input_embeddings.broadcast_add(&self.position_embeddings.forward(&position_ids)?)?;
let embeddings = self.layer_norm.forward(&embeddings)?;
Ok(embeddings)
}
}
struct MultiHeadSelfAttention {
q_lin: Linear,
k_lin: Linear,
v_lin: Linear,
out_lin: Linear,
n_heads: usize,
attention_head_size: usize,
span: tracing::Span,
}
impl MultiHeadSelfAttention {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention_head_size = config.dim / config.n_heads;
let all_head_size = config.n_heads * attention_head_size;
let dim = config.dim;
let q_lin = linear(dim, all_head_size, vb.pp("q_lin"))?;
let v_lin = linear(dim, all_head_size, vb.pp("v_lin"))?;
let k_lin = linear(dim, all_head_size, vb.pp("k_lin"))?;
let out_lin = linear(all_head_size, dim, vb.pp("out_lin"))?;
Ok(Self {
q_lin,
k_lin,
v_lin,
out_lin,
n_heads: config.n_heads,
attention_head_size,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
}
impl MultiHeadSelfAttention {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (bs, q_length, _dim) = hidden_states.dims3()?;
let dim_per_head = self.attention_head_size;
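// Project to q/k/v and split into heads: (bs, n_heads, q_length, dim_per_head).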
let q = self.q_lin.forward(hidden_states)?;
let k = self.k_lin.forward(hidden_states)?;
let v = self.v_lin.forward(hidden_states)?;
let q = q
.reshape((bs, q_length, self.n_heads, dim_per_head))?
.transpose(1, 2)?;
let k = k
.reshape((bs, q_length, self.n_heads, dim_per_head))?
.transpose(1, 2)?;
let v = v
.reshape((bs, q_length, self.n_heads, dim_per_head))?
.transpose(1, 2)?;
let q: Tensor = (q / (dim_per_head as f64).sqrt())?;
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
let mask = attention_mask.broadcast_as(scores.shape())?;
let scores = masked_fill(&scores.to_dtype(DType::F32)?, &mask, f32::NEG_INFINITY)?;
let weights = candle_nn::ops::softmax(&scores, candle::D::Minus1)?;
let context = weights.matmul(&v.contiguous()?)?;
let context = context
.transpose(1, 2)?
.reshape((bs, q_length, self.n_heads * dim_per_head))?
.contiguous()?;
let context = self.out_lin.forward(&context)?;
Ok(context)
}
}
#[allow(clippy::upper_case_acronyms)]
struct FFN {
lin1: Linear,
lin2: Linear,
activation: HiddenActLayer,
span: tracing::Span,
}
impl FFN {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let lin1 = linear(config.dim, config.hidden_dim, vb.pp("lin1"))?;
let lin2 = linear(config.hidden_dim, config.dim, vb.pp("lin2"))?;
Ok(Self {
lin1,
lin2,
activation: HiddenActLayer::new(config.activation),
span: tracing::span!(tracing::Level::TRACE, "ffn"),
})
}
}
impl Module for FFN {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
hidden_states
.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
struct TransformerBlock {
attention: MultiHeadSelfAttention,
sa_layer_norm: LayerNorm,
ffn: FFN,
output_layer_norm: LayerNorm,
span: tracing::Span,
}
impl TransformerBlock {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let attention = MultiHeadSelfAttention::load(vb.pp("attention"), config)?;
let sa_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("sa_layer_norm"))?;
let ffn = FFN::load(vb.pp("ffn"), config)?;
let output_layer_norm = layer_norm(config.dim, 1e-12, vb.pp("output_layer_norm"))?;
Ok(Self {
attention,
sa_layer_norm,
ffn,
output_layer_norm,
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
}
impl TransformerBlock {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let sa_output = self.attention.forward(hidden_states, attention_mask)?;
// TODO: Support cross-attention?
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
// TODO: Support something similar to `apply_chunking_to_forward`?
let sa_output = sa_output.broadcast_add(hidden_states)?;
let sa_output = self.sa_layer_norm.forward(&sa_output)?;
let ffn_output = self.ffn.forward(&sa_output)?;
let ffn_output = (&ffn_output + sa_output)?;
let output = self.output_layer_norm.forward(&ffn_output)?;
Ok(output)
}
}
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
struct Transformer {
layers: Vec<TransformerBlock>,
span: tracing::Span,
}
impl Transformer {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let layers = (0..config.n_layers)
.map(|index| TransformerBlock::load(vb.pp(&format!("layer.{index}")), config))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(Transformer { layers, span })
}
}
impl Transformer {
fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();
// Use a loop rather than a fold as it's easier to modify when adding debug/...
for layer in self.layers.iter() {
hidden_states = layer.forward(&hidden_states, attention_mask)?;
}
Ok(hidden_states)
}
}
pub struct DistilBertModel {
embeddings: Embeddings,
transformer: Transformer,
pub device: Device,
span: tracing::Span,
}
impl DistilBertModel {
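// Try the flat weight layout first; if that fails, fall back to names prefixed with
// `model_type` (e.g. `distilbert.embeddings`, `distilbert.transformer`).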
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let (embeddings, transformer) = match (
Embeddings::load(vb.pp("embeddings"), config),
Transformer::load(vb.pp("transformer"), config),
) {
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
(Err(err), _) | (_, Err(err)) => {
if let Some(model_type) = &config.model_type {
if let (Ok(embeddings), Ok(encoder)) = (
Embeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
Transformer::load(vb.pp(&format!("{model_type}.transformer")), config),
) {
(embeddings, encoder)
} else {
return Err(err);
}
} else {
return Err(err);
}
}
};
Ok(Self {
embeddings,
transformer,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids)?;
let sequence_output = self
.transformer
.forward(&embedding_output, attention_mask)?;
Ok(sequence_output)
}
}

@@ -4,6 +4,7 @@ pub mod blip;
pub mod blip_text;
pub mod convmixer;
pub mod dinov2;
pub mod distilbert;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;