Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 03:54:56 +00:00.

Move some models to candle-transformers so that it's easier to re-use. (#794)

* Move some models to candle-transformers so that they can be shared.
* Also move falcon.
* Move Llama.
* Move whisper (partial).
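For downstream users, the practical effect of this change is that the model definitions can be imported from candle-transformers instead of each example carrying its own `mod model;`. A minimal sketch of the new import paths, taken from the hunks below and assuming a candle-transformers dependency at the matching 0.2.1 version:

```rust
// The per-example `mod model;` declarations go away; the shared crate now
// exposes the same types under candle_transformers::models::*.
use candle_transformers::models::bert::{BertModel, DTYPE};
use candle_transformers::models::bigcode::GPTBigCode;
use candle_transformers::models::falcon::Falcon;
use candle_transformers::models::llama::{Llama, LlamaConfig};
```

The example binaries keep their CLI, Hub download, and tokenizer handling; only the model code moves.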
@@ -13,18 +13,18 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

[dev-dependencies]
anyhow = { workspace = true }
@@ -3,14 +3,13 @@ extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod model;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Parser, Debug)]
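Since the example now builds the model through the shared crate, a quick way to check the wiring is to instantiate `BertModel` with zero-filled weights. This is a hedged smoke-test sketch, not the example itself: it assumes the moved module keeps the API of the deleted `model.rs` below (`Config::default`, `BertModel::load`, `forward`) and uses `candle_nn::VarBuilder::zeros` in place of the safetensors weights the real example downloads from the Hub.

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Zero-initialized weights are enough to exercise the graph shape-wise;
    // the real example maps the downloaded safetensors file instead.
    let vb = VarBuilder::zeros(DTYPE, &device);
    let config = Config::default();
    let model = BertModel::load(vb, &config)?;
    // A dummy batch of 6 token ids (all zeros) with matching token type ids.
    let input_ids = Tensor::zeros((1, 6), DType::U32, &device)?;
    let token_type_ids = input_ids.zeros_like()?;
    let embeddings = model.forward(&input_ids, &token_type_ids)?;
    println!("output shape: {:?}", embeddings.dims()); // (1, 6, 768)
    Ok(())
}
```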
@@ -1,568 +0,0 @@
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum HiddenAct {
|
||||
Gelu,
|
||||
Relu,
|
||||
}
|
||||
|
||||
struct HiddenActLayer {
|
||||
act: HiddenAct,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl HiddenActLayer {
|
||||
fn new(act: HiddenAct) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "hidden-act");
|
||||
Self { act, span }
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
match self.act {
|
||||
// TODO: The all-MiniLM-L6-v2 model uses "gelu" whereas this is "gelu_new", this explains some
|
||||
// small numerical difference.
|
||||
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
|
||||
HiddenAct::Gelu => xs.gelu(),
|
||||
HiddenAct::Relu => xs.relu(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
weight: Tensor,
|
||||
bias: Option<Tensor>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Linear {
|
||||
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Self { weight, bias, span }
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let w = match x.dims() {
|
||||
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
|
||||
_ => self.weight.t()?,
|
||||
};
|
||||
let x = x.matmul(&w)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Self {
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
span,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let (_bsize, _seq_len, hidden_size) = x.dims3()?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum PositionEmbeddingType {
|
||||
#[default]
|
||||
Absolute,
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/configuration_bert.py#L1
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
hidden_size: usize,
|
||||
num_hidden_layers: usize,
|
||||
num_attention_heads: usize,
|
||||
intermediate_size: usize,
|
||||
hidden_act: HiddenAct,
|
||||
hidden_dropout_prob: f64,
|
||||
max_position_embeddings: usize,
|
||||
type_vocab_size: usize,
|
||||
initializer_range: f64,
|
||||
layer_norm_eps: f64,
|
||||
pad_token_id: usize,
|
||||
#[serde(default)]
|
||||
position_embedding_type: PositionEmbeddingType,
|
||||
#[serde(default)]
|
||||
use_cache: bool,
|
||||
classifier_dropout: Option<f64>,
|
||||
model_type: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab_size: 30522,
|
||||
hidden_size: 768,
|
||||
num_hidden_layers: 12,
|
||||
num_attention_heads: 12,
|
||||
intermediate_size: 3072,
|
||||
hidden_act: HiddenAct::Gelu,
|
||||
hidden_dropout_prob: 0.1,
|
||||
max_position_embeddings: 512,
|
||||
type_vocab_size: 2,
|
||||
initializer_range: 0.02,
|
||||
layer_norm_eps: 1e-12,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
classifier_dropout: None,
|
||||
model_type: Some("bert".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn _all_mini_lm_l6_v2() -> Self {
|
||||
// https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/blob/main/config.json
|
||||
Self {
|
||||
vocab_size: 30522,
|
||||
hidden_size: 384,
|
||||
num_hidden_layers: 6,
|
||||
num_attention_heads: 12,
|
||||
intermediate_size: 1536,
|
||||
hidden_act: HiddenAct::Gelu,
|
||||
hidden_dropout_prob: 0.1,
|
||||
max_position_embeddings: 512,
|
||||
type_vocab_size: 2,
|
||||
initializer_range: 0.02,
|
||||
layer_norm_eps: 1e-12,
|
||||
pad_token_id: 0,
|
||||
position_embedding_type: PositionEmbeddingType::Absolute,
|
||||
use_cache: true,
|
||||
classifier_dropout: None,
|
||||
model_type: Some("bert".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = vb.get(size2, "bias")?;
|
||||
Ok(Linear::new(weight, Some(bias)))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
#[allow(dead_code)]
|
||||
pr: f64,
|
||||
}
|
||||
|
||||
impl Dropout {
|
||||
fn new(pr: f64) -> Self {
|
||||
Self { pr }
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
// TODO
|
||||
Ok(x.clone())
|
||||
}
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L180
|
||||
struct BertEmbeddings {
|
||||
word_embeddings: Embedding,
|
||||
position_embeddings: Option<Embedding>,
|
||||
token_type_embeddings: Embedding,
|
||||
layer_norm: LayerNorm,
|
||||
dropout: Dropout,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertEmbeddings {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let word_embeddings = embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
vb.pp("word_embeddings"),
|
||||
)?;
|
||||
let position_embeddings = embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
vb.pp("position_embeddings"),
|
||||
)?;
|
||||
let token_type_embeddings = embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
vb.pp("token_type_embeddings"),
|
||||
)?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
position_embeddings: Some(position_embeddings),
|
||||
token_type_embeddings,
|
||||
layer_norm,
|
||||
dropout: Dropout::new(config.hidden_dropout_prob),
|
||||
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let (_bsize, seq_len) = input_ids.dims2()?;
|
||||
let input_embeddings = self.word_embeddings.forward(input_ids)?;
|
||||
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
|
||||
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
|
||||
if let Some(position_embeddings) = &self.position_embeddings {
|
||||
// TODO: Proper absolute positions?
|
||||
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
|
||||
let position_ids = Tensor::new(&position_ids[..], input_ids.device())?;
|
||||
embeddings = embeddings.broadcast_add(&position_embeddings.forward(&position_ids)?)?
|
||||
}
|
||||
let embeddings = self.layer_norm.forward(&embeddings)?;
|
||||
let embeddings = self.dropout.forward(&embeddings)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
struct BertSelfAttention {
|
||||
query: Linear,
|
||||
key: Linear,
|
||||
value: Linear,
|
||||
dropout: Dropout,
|
||||
num_attention_heads: usize,
|
||||
attention_head_size: usize,
|
||||
span: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertSelfAttention {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention_head_size = config.hidden_size / config.num_attention_heads;
|
||||
let all_head_size = config.num_attention_heads * attention_head_size;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
let hidden_size = config.hidden_size;
|
||||
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
|
||||
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
|
||||
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
|
||||
Ok(Self {
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout,
|
||||
num_attention_heads: config.num_attention_heads,
|
||||
attention_head_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
|
||||
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
|
||||
})
|
||||
}
|
||||
|
||||
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut new_x_shape = xs.dims().to_vec();
|
||||
new_x_shape.pop();
|
||||
new_x_shape.push(self.num_attention_heads);
|
||||
new_x_shape.push(self.attention_head_size);
|
||||
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
|
||||
xs.contiguous()
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let query_layer = self.query.forward(hidden_states)?;
|
||||
let key_layer = self.key.forward(hidden_states)?;
|
||||
let value_layer = self.value.forward(hidden_states)?;
|
||||
|
||||
let query_layer = self.transpose_for_scores(&query_layer)?;
|
||||
let key_layer = self.transpose_for_scores(&key_layer)?;
|
||||
let value_layer = self.transpose_for_scores(&value_layer)?;
|
||||
|
||||
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||
let attention_probs = {
|
||||
let _enter_sm = self.span_softmax.enter();
|
||||
candle_nn::ops::softmax(&attention_scores, candle::D::Minus1)?
|
||||
};
|
||||
let attention_probs = self.dropout.forward(&attention_probs)?;
|
||||
|
||||
let context_layer = attention_probs.matmul(&value_layer)?;
|
||||
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
|
||||
let context_layer = context_layer.flatten_from(candle::D::Minus2)?;
|
||||
Ok(context_layer)
|
||||
}
|
||||
}
|
||||
|
||||
struct BertSelfOutput {
|
||||
dense: Linear,
|
||||
layer_norm: LayerNorm,
|
||||
dropout: Dropout,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertSelfOutput {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
span: tracing::span!(tracing::Level::TRACE, "self-out"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L392
|
||||
struct BertAttention {
|
||||
self_attention: BertSelfAttention,
|
||||
self_output: BertSelfOutput,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertAttention {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let self_attention = BertSelfAttention::load(vb.pp("self"), config)?;
|
||||
let self_output = BertSelfOutput::load(vb.pp("output"), config)?;
|
||||
Ok(Self {
|
||||
self_attention,
|
||||
self_output,
|
||||
span: tracing::span!(tracing::Level::TRACE, "attn"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let self_outputs = self.self_attention.forward(hidden_states)?;
|
||||
let attention_output = self.self_output.forward(&self_outputs, hidden_states)?;
|
||||
Ok(attention_output)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L441
|
||||
struct BertIntermediate {
|
||||
dense: Linear,
|
||||
intermediate_act: HiddenActLayer,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertIntermediate {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.hidden_size, config.intermediate_size, vb.pp("dense"))?;
|
||||
Ok(Self {
|
||||
dense,
|
||||
intermediate_act: HiddenActLayer::new(config.hidden_act),
|
||||
span: tracing::span!(tracing::Level::TRACE, "inter"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let ys = self.intermediate_act.forward(&hidden_states)?;
|
||||
Ok(ys)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L456
|
||||
struct BertOutput {
|
||||
dense: Linear,
|
||||
layer_norm: LayerNorm,
|
||||
dropout: Dropout,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertOutput {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.intermediate_size, config.hidden_size, vb.pp("dense"))?;
|
||||
let layer_norm = layer_norm(
|
||||
config.hidden_size,
|
||||
config.layer_norm_eps,
|
||||
vb.pp("LayerNorm"),
|
||||
)?;
|
||||
let dropout = Dropout::new(config.hidden_dropout_prob);
|
||||
Ok(Self {
|
||||
dense,
|
||||
layer_norm,
|
||||
dropout,
|
||||
span: tracing::span!(tracing::Level::TRACE, "out"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states = self.dense.forward(hidden_states)?;
|
||||
let hidden_states = self.dropout.forward(&hidden_states)?;
|
||||
self.layer_norm.forward(&(hidden_states + input_tensor)?)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L470
|
||||
struct BertLayer {
|
||||
attention: BertAttention,
|
||||
intermediate: BertIntermediate,
|
||||
output: BertOutput,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertLayer {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let attention = BertAttention::load(vb.pp("attention"), config)?;
|
||||
let intermediate = BertIntermediate::load(vb.pp("intermediate"), config)?;
|
||||
let output = BertOutput::load(vb.pp("output"), config)?;
|
||||
Ok(Self {
|
||||
attention,
|
||||
intermediate,
|
||||
output,
|
||||
span: tracing::span!(tracing::Level::TRACE, "layer"),
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let attention_output = self.attention.forward(hidden_states)?;
|
||||
// TODO: Support cross-attention?
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
|
||||
// TODO: Support something similar to `apply_chunking_to_forward`?
|
||||
let intermediate_output = self.intermediate.forward(&attention_output)?;
|
||||
let layer_output = self
|
||||
.output
|
||||
.forward(&intermediate_output, &attention_output)?;
|
||||
Ok(layer_output)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L556
|
||||
struct BertEncoder {
|
||||
layers: Vec<BertLayer>,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertEncoder {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let layers = (0..config.num_hidden_layers)
|
||||
.map(|index| BertLayer::load(vb.pp(&format!("layer.{index}")), config))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "encoder");
|
||||
Ok(BertEncoder { layers, span })
|
||||
}
|
||||
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let mut hidden_states = hidden_states.clone();
|
||||
// Use a loop rather than a fold as it's easier to modify when adding debug/...
|
||||
for layer in self.layers.iter() {
|
||||
hidden_states = layer.forward(&hidden_states)?
|
||||
}
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
|
||||
pub struct BertModel {
|
||||
embeddings: BertEmbeddings,
|
||||
encoder: BertEncoder,
|
||||
pub device: Device,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertModel {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let (embeddings, encoder) = match (
|
||||
BertEmbeddings::load(vb.pp("embeddings"), config),
|
||||
BertEncoder::load(vb.pp("encoder"), config),
|
||||
) {
|
||||
(Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let Some(model_type) = &config.model_type {
|
||||
if let (Ok(embeddings), Ok(encoder)) = (
|
||||
BertEmbeddings::load(vb.pp(&format!("{model_type}.embeddings")), config),
|
||||
BertEncoder::load(vb.pp(&format!("{model_type}.encoder")), config),
|
||||
) {
|
||||
(embeddings, encoder)
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
} else {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
embeddings,
|
||||
encoder,
|
||||
device: vb.device().clone(),
|
||||
span: tracing::span!(tracing::Level::TRACE, "model"),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, token_type_ids: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?;
|
||||
let sequence_output = self.encoder.forward(&embedding_output)?;
|
||||
Ok(sequence_output)
|
||||
}
|
||||
}
|
@@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

mod model;
use model::{Config, GPTBigCode};
use candle_transformers::models::bigcode::{Config, GPTBigCode};

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
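The bigcode example changes the same way. The sketch below exercises `GPTBigCode` end to end with a deliberately tiny, hypothetical configuration (the real example uses `Config::starcoder_1b()` and friends) and zero-filled weights, assuming the moved module keeps the public config fields and the `load`/`forward` signatures of the deleted `model.rs` below.

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bigcode::{Config, GPTBigCode};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Toy configuration so the zero-filled weights stay small.
    let cfg = Config {
        vocab_size: 256,
        max_position_embeddings: 64,
        num_hidden_layers: 2,
        hidden_size: 64,
        layer_norm_epsilon: 1e-5,
        n_inner: Some(256),
        num_attention_heads: 4,
        multi_query: true,
        use_cache: true,
    };
    let vb = VarBuilder::zeros(DType::F32, &device);
    let mut model = GPTBigCode::load(vb, cfg)?;
    // First step: feed 5 prompt tokens with no past key/values.
    let input_ids = Tensor::zeros((1, 5), DType::U32, &device)?;
    let logits = model.forward(&input_ids, 0)?;
    println!("logits shape: {:?}", logits.dims()); // (1, 256)
    Ok(())
}
```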
@@ -1,359 +0,0 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let weight = vb.get(size, "weight")?;
|
||||
let bias = vb.get(size, "bias")?;
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
fn make_causal_mask(t: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j <= i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), device)?;
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
// max_position_embeddings aka n_positions
|
||||
pub max_position_embeddings: usize,
|
||||
// num_hidden_layers aka n_layer
|
||||
pub num_hidden_layers: usize,
|
||||
// hidden_size aka n_embd
|
||||
pub hidden_size: usize,
|
||||
pub layer_norm_epsilon: f64,
|
||||
pub n_inner: Option<usize>,
|
||||
// num_attention_heads aka n_head
|
||||
pub num_attention_heads: usize,
|
||||
pub multi_query: bool,
|
||||
pub use_cache: bool,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
#[allow(dead_code)]
|
||||
pub fn starcoder_1b() -> Self {
|
||||
Self {
|
||||
vocab_size: 49152,
|
||||
max_position_embeddings: 8192,
|
||||
num_hidden_layers: 24,
|
||||
hidden_size: 2048,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
n_inner: Some(8192),
|
||||
num_attention_heads: 16,
|
||||
multi_query: true,
|
||||
use_cache: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn starcoder_3b() -> Self {
|
||||
Self {
|
||||
vocab_size: 49152,
|
||||
max_position_embeddings: 8192,
|
||||
num_hidden_layers: 36,
|
||||
hidden_size: 2816,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
n_inner: Some(11264),
|
||||
num_attention_heads: 22,
|
||||
multi_query: true,
|
||||
use_cache: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn starcoder_7b() -> Self {
|
||||
Self {
|
||||
vocab_size: 49152,
|
||||
max_position_embeddings: 8192,
|
||||
num_hidden_layers: 42,
|
||||
hidden_size: 4096,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
n_inner: Some(16384),
|
||||
num_attention_heads: 32,
|
||||
multi_query: true,
|
||||
use_cache: true,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn starcoder() -> Self {
|
||||
Self {
|
||||
vocab_size: 49152,
|
||||
max_position_embeddings: 8192,
|
||||
num_hidden_layers: 40,
|
||||
hidden_size: 6144,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
n_inner: Some(24576),
|
||||
num_attention_heads: 48,
|
||||
multi_query: true,
|
||||
use_cache: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Attention {
|
||||
c_attn: Linear,
|
||||
c_proj: Linear,
|
||||
kv_cache: Option<Tensor>,
|
||||
use_cache: bool,
|
||||
embed_dim: usize,
|
||||
kv_dim: usize,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
multi_query: bool,
|
||||
}
|
||||
|
||||
impl Attention {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let head_dim = hidden_size / cfg.num_attention_heads;
|
||||
let kv_heads = if cfg.multi_query {
|
||||
1
|
||||
} else {
|
||||
cfg.num_attention_heads
|
||||
};
|
||||
let kv_dim = kv_heads * head_dim;
|
||||
let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?;
|
||||
let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?;
|
||||
Ok(Self {
|
||||
c_proj,
|
||||
c_attn,
|
||||
embed_dim: hidden_size,
|
||||
kv_cache: None,
|
||||
use_cache: cfg.use_cache,
|
||||
kv_dim,
|
||||
head_dim,
|
||||
num_heads: cfg.num_attention_heads,
|
||||
multi_query: cfg.multi_query,
|
||||
})
|
||||
}
|
||||
|
||||
fn attn(
|
||||
&self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
value: &Tensor,
|
||||
attention_mask: &Tensor,
|
||||
) -> Result<Tensor> {
|
||||
if query.dtype() != DType::F32 {
|
||||
// If we start supporting f16 models, we may need the upcasting scaling bits.
|
||||
// https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133
|
||||
candle::bail!("upcasting is not supported {:?}", query.dtype())
|
||||
}
|
||||
let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
|
||||
let initial_query_shape = query.shape();
|
||||
let key_len = key.dim(D::Minus1)?;
|
||||
let (query, key, attn_shape, attn_view) = if self.multi_query {
|
||||
let (b_sz, query_len, _) = query.dims3()?;
|
||||
let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
|
||||
let attn_shape = (b_sz, query_len, self.num_heads, key_len);
|
||||
let attn_view = (b_sz, query_len * self.num_heads, key_len);
|
||||
(query, key.clone(), attn_shape, attn_view)
|
||||
} else {
|
||||
let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
|
||||
let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
|
||||
let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
|
||||
let attn_shape = (b_sz, self.num_heads, query_len, key_len);
|
||||
let attn_view = (b_sz * self.num_heads, query_len, key_len);
|
||||
(query, key, attn_shape, attn_view)
|
||||
};
|
||||
|
||||
let attn_weights =
|
||||
(query.matmul(&key.contiguous()?)? * scale_factor)?.reshape(attn_shape)?;
|
||||
let attention_mask = attention_mask.broadcast_as(attn_shape)?;
|
||||
let mask_value =
|
||||
Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?;
|
||||
let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?;
|
||||
let attn_weights = candle_nn::ops::softmax(&attn_weights, D::Minus1)?;
|
||||
let value = value.contiguous()?;
|
||||
let attn_output = if self.multi_query {
|
||||
attn_weights
|
||||
.reshape(attn_view)?
|
||||
.matmul(&value)?
|
||||
.reshape(initial_query_shape)?
|
||||
} else {
|
||||
attn_weights.matmul(&value)?
|
||||
};
|
||||
Ok(attn_output)
|
||||
}
|
||||
|
||||
fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let qkv = self.c_attn.forward(hidden_states)?;
|
||||
let (query, key_value) = if self.multi_query {
|
||||
let query = qkv.i((.., .., ..self.embed_dim))?;
|
||||
let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?;
|
||||
(query, key_value)
|
||||
} else {
|
||||
let mut dims = qkv.dims().to_vec();
|
||||
dims.pop();
|
||||
dims.push(self.embed_dim);
|
||||
dims.push(self.head_dim * 3);
|
||||
let qkv = qkv.reshape(dims)?.transpose(1, 2)?;
|
||||
let query = qkv.i((.., .., .., ..self.head_dim))?;
|
||||
let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?;
|
||||
(query, key_value)
|
||||
};
|
||||
let mut key_value = key_value;
|
||||
if self.use_cache {
|
||||
if let Some(kv_cache) = &self.kv_cache {
|
||||
// TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
|
||||
// arbitrarily large sizes.
|
||||
key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?;
|
||||
}
|
||||
self.kv_cache = Some(key_value.clone())
|
||||
}
|
||||
|
||||
let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
|
||||
let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
|
||||
let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
|
||||
let attn_output = if self.multi_query {
|
||||
attn_output
|
||||
} else {
|
||||
attn_output
|
||||
.transpose(1, 2)?
|
||||
.reshape(hidden_states.shape())?
|
||||
};
|
||||
let attn_output = self.c_proj.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
struct Mlp {
|
||||
c_fc: Linear,
|
||||
c_proj: Linear,
|
||||
}
|
||||
|
||||
impl Mlp {
|
||||
fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?;
|
||||
let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?;
|
||||
Ok(Self { c_fc, c_proj })
|
||||
}
|
||||
|
||||
fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?;
|
||||
let hidden_states = self.c_proj.forward(&hidden_states)?;
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add cross-attention?
|
||||
struct Block {
|
||||
ln_1: LayerNorm,
|
||||
attn: Attention,
|
||||
ln_2: LayerNorm,
|
||||
mlp: Mlp,
|
||||
}
|
||||
|
||||
impl Block {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size);
|
||||
let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?;
|
||||
let attn = Attention::load(vb.pp("attn"), cfg)?;
|
||||
let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?;
|
||||
let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?;
|
||||
Ok(Self {
|
||||
ln_1,
|
||||
attn,
|
||||
ln_2,
|
||||
mlp,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let residual = hidden_states;
|
||||
let hidden_states = self.ln_1.forward(hidden_states)?;
|
||||
let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
|
||||
let hidden_states = (&attn_outputs + residual)?;
|
||||
let residual = &hidden_states;
|
||||
let hidden_states = self.ln_2.forward(&hidden_states)?;
|
||||
let hidden_states = self.mlp.forward(&hidden_states)?;
|
||||
let hidden_states = (&hidden_states + residual)?;
|
||||
Ok(hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GPTBigCode {
|
||||
wte: Embedding,
|
||||
wpe: Embedding,
|
||||
blocks: Vec<Block>,
|
||||
ln_f: LayerNorm,
|
||||
lm_head: Linear,
|
||||
bias: Tensor,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
impl GPTBigCode {
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let vb_t = vb.pp("transformer");
|
||||
let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?;
|
||||
let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?;
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?;
|
||||
let lm_head = linear(hidden_size, cfg.vocab_size, false, vb_t.pp("wte"))?;
|
||||
let bias = make_causal_mask(cfg.max_position_embeddings, vb.device())?;
|
||||
Ok(Self {
|
||||
wte,
|
||||
wpe,
|
||||
blocks,
|
||||
lm_head,
|
||||
ln_f,
|
||||
bias,
|
||||
config: cfg,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result<Tensor> {
|
||||
let dev = input_ids.device();
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
|
||||
let key_len = past_len + seq_len;
|
||||
let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?;
|
||||
// MQA models: (batch_size, query_length, n_heads, key_length)
|
||||
// MHA models: (batch_size, n_heads, query_length, key_length)
|
||||
let seq_len_dim = if self.config.multi_query { 2 } else { 1 };
|
||||
let attention_mask = attention_mask.unsqueeze(seq_len_dim)?;
|
||||
|
||||
let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?;
|
||||
let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?;
|
||||
let input_embeds = self.wte.forward(input_ids)?;
|
||||
let position_embeds = self.wpe.forward(&position_ids)?;
|
||||
|
||||
let mut hidden_states = (&input_embeds + &position_embeds)?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
hidden_states = block.forward(&hidden_states, &attention_mask)?;
|
||||
}
|
||||
let hidden_states = self.ln_f.forward(&hidden_states)?;
|
||||
let hidden_states = hidden_states
|
||||
.reshape((b_sz, seq_len, self.config.hidden_size))?
|
||||
.narrow(1, seq_len - 1, 1)?;
|
||||
let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
@@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

mod model;
use model::{Config, Falcon};
use candle_transformers::models::falcon::{Config, Falcon};

struct TextGeneration {
    model: Falcon,
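Likewise for falcon. The sketch assumes the moved module keeps the deleted file's API (`Config` with public fields, `validate`, `Falcon::load`, `forward`) and uses a toy, hypothetical configuration instead of `Config::falcon7b()` so the zero-filled weights stay cheap to allocate on CPU.

```rust
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::falcon::{Config, Falcon};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Toy configuration; only the unsupported options (alibi, n_head_kv,
    // new_decoder_architecture) need to stay off for validate() to pass.
    let cfg = Config {
        vocab_size: 256,
        hidden_size: 64,
        num_hidden_layers: 2,
        num_attention_heads: 4,
        layer_norm_epsilon: 1e-5,
        initializer_range: 0.02,
        use_cache: true,
        bos_token_id: 11,
        eos_token_id: 11,
        hidden_dropout: 0.0,
        attention_dropout: 0.0,
        n_head_kv: None,
        alibi: false,
        new_decoder_architecture: false,
        multi_query: true,
        parallel_attn: true,
        bias: false,
    };
    cfg.validate()?;
    let vb = VarBuilder::zeros(DType::F32, &device);
    let mut model = Falcon::load(vb, cfg)?;
    let input_ids = Tensor::zeros((1, 5), DType::U32, &device)?;
    let logits = model.forward(&input_ids)?;
    println!("logits shape: {:?}", logits.dims()); // (1, 256)
    Ok(())
}
```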
@@ -1,485 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
|
||||
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
|
||||
let weight = vb.get((size2, size1), "weight")?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(size2, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Linear::new(weight, bias))
|
||||
}
|
||||
|
||||
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
|
||||
(Ok(weight), Ok(bias)) => (weight, bias),
|
||||
(Err(err), _) | (_, Err(err)) => {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(LayerNorm::new(weight, bias, eps))
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
// https://raw.githubusercontent.com/huggingface/transformers/030c863aaa0165e98352b61697430bf69bf33755/src/transformers/models/falcon/configuration_falcon.py
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
pub vocab_size: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub layer_norm_epsilon: f64,
|
||||
pub initializer_range: f64,
|
||||
pub use_cache: bool,
|
||||
pub bos_token_id: u32,
|
||||
pub eos_token_id: u32,
|
||||
pub hidden_dropout: f64,
|
||||
pub attention_dropout: f64,
|
||||
pub n_head_kv: Option<usize>,
|
||||
pub alibi: bool,
|
||||
pub new_decoder_architecture: bool,
|
||||
pub multi_query: bool,
|
||||
pub parallel_attn: bool,
|
||||
pub bias: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
vocab_size: 65024,
|
||||
hidden_size: 4544,
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 71,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
initializer_range: 0.02,
|
||||
use_cache: true,
|
||||
bos_token_id: 11,
|
||||
eos_token_id: 11,
|
||||
hidden_dropout: 0.0,
|
||||
attention_dropout: 0.0,
|
||||
n_head_kv: None,
|
||||
alibi: false,
|
||||
new_decoder_architecture: false,
|
||||
multi_query: true,
|
||||
parallel_attn: true,
|
||||
bias: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.alibi {
|
||||
anyhow::bail!("alibi is not supported");
|
||||
}
|
||||
if self.new_decoder_architecture {
|
||||
anyhow::bail!("new_decoder_architecture is not supported");
|
||||
}
|
||||
if self.n_head_kv.is_some() {
|
||||
anyhow::bail!("n_head_kv is not supported");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json
|
||||
pub fn falcon7b() -> Self {
|
||||
// This is currently on par with the defaults, the defaults come from the Python default
|
||||
// arguments for the config initialization whereas the following come from the json config.
|
||||
Self {
|
||||
vocab_size: 65024,
|
||||
hidden_size: 4544,
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 71,
|
||||
layer_norm_epsilon: 1e-5,
|
||||
initializer_range: 0.02,
|
||||
use_cache: true,
|
||||
bos_token_id: 11,
|
||||
eos_token_id: 11,
|
||||
hidden_dropout: 0.,
|
||||
attention_dropout: 0.,
|
||||
n_head_kv: None,
|
||||
alibi: false,
|
||||
new_decoder_architecture: false,
|
||||
multi_query: true,
|
||||
parallel_attn: true,
|
||||
bias: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn head_dim(&self) -> usize {
|
||||
self.hidden_size / self.num_attention_heads
|
||||
}
|
||||
|
||||
fn rotary(&self) -> bool {
|
||||
!self.alibi
|
||||
}
|
||||
}
|
||||
|
||||
fn rotate_half(x: &Tensor) -> Result<Tensor> {
|
||||
let l = x.dim(D::Minus1)?;
|
||||
let x1 = x.narrow(D::Minus1, 0, l / 2)?;
|
||||
let x2 = x.narrow(D::Minus1, l / 2, l - l / 2)?;
|
||||
let x21 = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||
Ok(x21)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FalconRotaryEmbedding {
|
||||
inv_freq: Tensor,
|
||||
cache: Option<(usize, Tensor, Tensor)>,
|
||||
}
|
||||
|
||||
impl FalconRotaryEmbedding {
|
||||
fn load(device: &Device, cfg: &Config) -> Result<Self> {
|
||||
let head_dim = cfg.head_dim();
|
||||
let inv_freq: Vec<_> = (0..head_dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / head_dim as f32))
|
||||
.collect();
|
||||
Ok(Self {
|
||||
inv_freq: Tensor::new(inv_freq.as_slice(), device)?,
|
||||
cache: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn cos_sin(
|
||||
&mut self,
|
||||
seq_len: usize,
|
||||
device: &Device,
|
||||
dtype: DType,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
match &self.cache {
|
||||
Some((s, cos, sin)) if *s == seq_len => {
|
||||
return Ok((cos.clone(), sin.clone()));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
let t = Tensor::arange(0, seq_len as u32, device)?.to_dtype(dtype)?;
|
||||
let inv_freq = self.inv_freq.to_dtype(dtype)?;
|
||||
let freqs = t.unsqueeze(1)?.matmul(&inv_freq.unsqueeze(0)?)?;
|
||||
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
|
||||
let cos = emb.cos()?;
|
||||
let sin = emb.sin()?;
|
||||
self.cache = Some((seq_len, cos.clone(), sin.clone()));
|
||||
Ok((cos, sin))
|
||||
}
|
||||
|
||||
fn forward(
|
||||
&mut self,
|
||||
query: &Tensor,
|
||||
key: &Tensor,
|
||||
past_kv_len: usize,
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
let (_batch, seq_len, _head_dim) = query.dims3()?;
|
||||
let (cos, sin) = self.cos_sin(MAX_SEQ_LEN, query.device(), query.dtype())?;
|
||||
let cos = cos.narrow(0, past_kv_len, seq_len)?;
|
||||
let sin = sin.narrow(0, past_kv_len, seq_len)?;
|
||||
let qs = (query.broadcast_mul(&cos)? + &rotate_half(query)?.broadcast_mul(&sin)?)?;
|
||||
let ks = (key.broadcast_mul(&cos)? + &rotate_half(key)?.broadcast_mul(&sin)?)?;
|
||||
Ok((qs, ks))
|
||||
}
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FalconAttention {
|
||||
query_key_value: Linear,
|
||||
dense: Linear,
|
||||
maybe_rotary: Option<FalconRotaryEmbedding>,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
inv_norm_factor: f64,
|
||||
multi_query: bool,
|
||||
use_cache: bool,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
n_head_kv: usize,
|
||||
}
|
||||
|
||||
impl FalconAttention {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let maybe_rotary = if cfg.rotary() {
|
||||
let rotary = FalconRotaryEmbedding::load(vb.device(), cfg)?;
|
||||
Some(rotary)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let head_dim = cfg.head_dim();
|
||||
let hidden_size = cfg.hidden_size;
|
||||
let qkv_out_dim = if cfg.multi_query {
|
||||
hidden_size + 2 * head_dim
|
||||
} else {
|
||||
3 * hidden_size
|
||||
};
|
||||
let query_key_value = linear(hidden_size, qkv_out_dim, cfg.bias, vb.pp("query_key_value"))?;
|
||||
let dense = linear(hidden_size, hidden_size, cfg.bias, vb.pp("dense"))?;
|
||||
Ok(Self {
|
||||
query_key_value,
|
||||
dense,
|
||||
maybe_rotary,
|
||||
kv_cache: None,
|
||||
inv_norm_factor: 1. / (head_dim as f64).sqrt(),
|
||||
multi_query: cfg.multi_query,
|
||||
use_cache: cfg.use_cache,
|
||||
num_heads: cfg.num_attention_heads,
|
||||
n_head_kv: cfg.n_head_kv.unwrap_or(1),
|
||||
head_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn split_heads(&self, fused_qkv: &Tensor) -> Result<(Tensor, Tensor, Tensor)> {
|
||||
let (b_sz, seq_len, _) = fused_qkv.dims3()?;
|
||||
if !self.multi_query {
|
||||
let fused_qkv = fused_qkv.reshape((b_sz, seq_len, self.num_heads, 3, self.head_dim))?;
|
||||
let q = fused_qkv.narrow(D::Minus2, 0, 1)?.squeeze(D::Minus2)?;
|
||||
let k = fused_qkv.narrow(D::Minus2, 1, 1)?.squeeze(D::Minus2)?;
|
||||
let v = fused_qkv.narrow(D::Minus2, 2, 1)?.squeeze(D::Minus2)?;
|
||||
Ok((q, k, v))
|
||||
} else {
|
||||
let fused_qkv =
|
||||
fused_qkv.reshape((b_sz, seq_len, self.num_heads + 2, self.head_dim))?;
|
||||
let d = fused_qkv.dim(D::Minus2)?;
|
||||
let q = fused_qkv.narrow(D::Minus2, 0, d - 2)?;
|
||||
let k = fused_qkv.narrow(D::Minus2, d - 2, 1)?;
|
||||
let v = fused_qkv.narrow(D::Minus2, d - 1, 1)?;
|
||||
Ok((q, k, v))
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
|
||||
let fused_qkv = self.query_key_value.forward(x)?;
|
||||
let head_dim = self.head_dim;
|
||||
let (query, key, value) = self.split_heads(&fused_qkv)?;
|
||||
let (b_sz, seq_len, _, _) = query.dims4()?;
|
||||
let query = query
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||
let key = key
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
|
||||
let value = value
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz * self.n_head_kv, seq_len, head_dim))?;
|
||||
let (query, key) = if let Some(r) = &mut self.maybe_rotary {
|
||||
r.forward(&query, &key, past_kv_len)?
|
||||
} else {
|
||||
(query, key)
|
||||
};
|
||||
let (mut key, mut value) = (key, value);
|
||||
let mask = masked_fill(&mask.to_dtype(DType::F32)?, mask, -1e9)?.to_dtype(query.dtype())?;
|
||||
if self.use_cache {
|
||||
if let Some((cache_k, cache_v)) = &self.kv_cache {
|
||||
// TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for
|
||||
// arbitrarily large sizes.
|
||||
key = Tensor::cat(&[cache_k, &key], 1)?.contiguous()?;
|
||||
value = Tensor::cat(&[cache_v, &value], 1)?.contiguous()?;
|
||||
}
|
||||
self.kv_cache = Some((key.clone(), value.clone()))
|
||||
}
|
||||
let query = query.reshape((b_sz * self.num_heads, seq_len, head_dim))?;
|
||||
let all_len = past_kv_len + seq_len;
|
||||
let key = key.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
|
||||
let value = value.reshape((b_sz * self.n_head_kv, all_len, head_dim))?;
|
||||
|
||||
let (key, value) = if self.n_head_kv == 1 {
|
||||
(
|
||||
key.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
|
||||
value.broadcast_as((b_sz * self.num_heads, all_len, head_dim))?,
|
||||
)
|
||||
} else {
|
||||
(key, value)
|
||||
};
|
||||
|
||||
// Only handle the case where alibi is None here, and non-flash attention.
|
||||
let attention_scores = (query.matmul(&key.t()?)? * self.inv_norm_factor)?;
|
||||
let attention_scores = candle_nn::ops::softmax(
|
||||
&attention_scores
|
||||
.broadcast_add(&mask.squeeze(1)?)?
|
||||
.to_dtype(DType::F32)?,
|
||||
D::Minus1,
|
||||
)?
|
||||
.to_dtype(x.dtype())?;
|
||||
let attn_output = attention_scores
|
||||
.matmul(&value)?
|
||||
.reshape((b_sz, self.num_heads, seq_len, head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.reshape((b_sz, seq_len, self.num_heads * head_dim))?;
|
||||
let attn_output = self.dense.forward(&attn_output)?;
|
||||
Ok(attn_output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FalconMlp {
|
||||
dense_h_to_4h: Linear,
|
||||
dense_4h_to_h: Linear,
|
||||
}
|
||||
|
||||
impl FalconMlp {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = cfg.hidden_size;
|
||||
let b = cfg.bias;
|
||||
let dense_h_to_4h = linear(h, 4 * h, b, vb.pp("dense_h_to_4h"))?;
|
||||
let dense_4h_to_h = linear(4 * h, h, b, vb.pp("dense_4h_to_h"))?;
|
||||
Ok(Self {
|
||||
dense_h_to_4h,
|
||||
dense_4h_to_h,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = self.dense_h_to_4h.forward(x)?.gelu()?;
|
||||
let x = self.dense_4h_to_h.forward(&x)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FalconDecoderLayer {
|
||||
inp_layernorm: LayerNorm,
|
||||
self_attention: FalconAttention,
|
||||
post_attention_layernorm: Option<LayerNorm>,
|
||||
mlp: FalconMlp,
|
||||
parallel_attn: bool,
|
||||
}
|
||||
|
||||
impl FalconDecoderLayer {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mlp = FalconMlp::load(vb.pp("mlp"), cfg)?;
|
||||
let inp_layernorm = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
vb.pp("input_layernorm"),
|
||||
)?;
|
||||
let self_attention = FalconAttention::load(vb.pp("self_attention"), cfg)?;
|
||||
let post_attention_layernorm = if cfg.parallel_attn {
|
||||
None
|
||||
} else {
|
||||
let ln = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
vb.pp("post_attention_layernorm"),
|
||||
)?;
|
||||
Some(ln)
|
||||
};
|
||||
Ok(Self {
|
||||
inp_layernorm,
|
||||
self_attention,
|
||||
post_attention_layernorm,
|
||||
mlp,
|
||||
parallel_attn: cfg.parallel_attn,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&mut self, x: &Tensor, mask: &Tensor, past_kv_len: usize) -> Result<Tensor> {
|
||||
let residual = x.clone();
|
||||
let ln_attn = self.inp_layernorm.forward(x)?;
|
||||
let attn_output = self.self_attention.forward(&ln_attn, mask, past_kv_len)?;
|
||||
let (residual, ln_mlp) = match &self.post_attention_layernorm {
|
||||
None => (residual, ln_attn),
|
||||
Some(pal) => {
|
||||
// This should include some dropout.
|
||||
let residual = (&attn_output + &residual)?;
|
||||
let ln_mlp = pal.forward(&residual)?;
|
||||
(residual, ln_mlp)
|
||||
}
|
||||
};
|
||||
let mlp_output = self.mlp.forward(&ln_mlp)?;
|
||||
|
||||
let mlp_output = if self.parallel_attn {
|
||||
(mlp_output + attn_output)?
|
||||
} else {
|
||||
mlp_output
|
||||
};
|
||||
let output = (mlp_output + residual)?;
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Falcon {
|
||||
word_embeddings: Embedding,
|
||||
blocks: Vec<FalconDecoderLayer>,
|
||||
ln_f: LayerNorm,
|
||||
lm_head: Linear,
|
||||
config: Config,
|
||||
}
|
||||
|
||||
fn make_causal_mask(t: usize) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..t)
|
||||
.flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
fn prepare_attn_mask(b_sz: usize, seq_len: usize) -> Result<Tensor> {
|
||||
// let mask = Tensor::ones((b_sz, seq_len), DType::U32, &Device::Cpu)?;
|
||||
let mask = make_causal_mask(seq_len)?;
|
||||
let mask = mask.broadcast_as((b_sz, 1, seq_len, seq_len))?;
|
||||
Ok(mask)
|
||||
}
|
||||
|
||||
impl Falcon {
|
||||
pub fn config(&self) -> &Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: Config) -> Result<Self> {
|
||||
let word_embeddings = embedding(
|
||||
cfg.vocab_size,
|
||||
cfg.hidden_size,
|
||||
vb.pp("transformer.word_embeddings"),
|
||||
)?;
|
||||
let blocks = (0..cfg.num_hidden_layers)
|
||||
.map(|i| FalconDecoderLayer::load(vb.pp(&format!("transformer.h.{i}")), &cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let ln_f = layer_norm(
|
||||
cfg.hidden_size,
|
||||
cfg.layer_norm_epsilon,
|
||||
vb.pp("transformer.ln_f"),
|
||||
)?;
|
||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?;
|
||||
Ok(Self {
|
||||
word_embeddings,
|
||||
blocks,
|
||||
ln_f,
|
||||
lm_head,
|
||||
config: cfg,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
|
||||
let (b_sz, seq_len) = input_ids.dims2()?;
|
||||
let mut hidden_state = self.word_embeddings.forward(input_ids)?;
|
||||
let past_kv_len = match &self.blocks[0].self_attention.kv_cache {
|
||||
Some((k, _)) => k.dim(1)?,
|
||||
None => 0,
|
||||
};
|
||||
let causal_mask = prepare_attn_mask(b_sz, seq_len)?.to_device(input_ids.device())?;
|
||||
for block in self.blocks.iter_mut() {
|
||||
hidden_state = block.forward(&hidden_state, &causal_mask, past_kv_len)?;
|
||||
}
|
||||
let hidden_state = self.ln_f.forward(&hidden_state)?;
|
||||
let hidden_state = hidden_state.narrow(1, seq_len - 1, 1)?;
|
||||
let logits = self.lm_head.forward(&hidden_state)?.squeeze(1)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;

mod model;
use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig};

const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";

#[derive(Parser, Debug)]
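For llama, only the pieces visible in this diff are used below: `Config::config_7b_v2` and `Cache::new`, which precomputes the rotary cos/sin tables and allocates one empty KV slot per hidden layer. The sketch assumes those items are re-exported unchanged from `candle_transformers::models::llama`; no checkpoint weights are touched.

```rust
use candle::{DType, Device};
use candle_transformers::models::llama::{Cache, Config};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Hyper-parameters only; flash-attn is left disabled.
    let config = Config::config_7b_v2(false);
    // Builds the cos/sin tables up to MAX_SEQ_LEN and one KV-cache entry
    // per layer, shared across the attention blocks at inference time.
    let cache = Cache::new(true, DType::F32, &config, &device)?;
    println!("kv cache enabled: {}", cache.use_kv_cache);
    Ok(())
}
```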
@@ -1,446 +0,0 @@
|
||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct LlamaConfig {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: Option<usize>,
|
||||
pub rms_norm_eps: f64,
|
||||
#[serde(default = "default_rope")]
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
fn default_rope() -> f32 {
|
||||
10_000.0
|
||||
}
|
||||
|
||||
impl LlamaConfig {
|
||||
pub fn into_config(self, use_flash_attn: bool) -> Config {
|
||||
Config {
|
||||
hidden_size: self.hidden_size,
|
||||
intermediate_size: self.intermediate_size,
|
||||
vocab_size: self.vocab_size,
|
||||
num_hidden_layers: self.num_hidden_layers,
|
||||
num_attention_heads: self.num_attention_heads,
|
||||
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
||||
rms_norm_eps: self.rms_norm_eps,
|
||||
rope_theta: self.rope_theta,
|
||||
use_flash_attn,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
pub hidden_size: usize,
|
||||
pub intermediate_size: usize,
|
||||
pub vocab_size: usize,
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_attention_heads: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub use_flash_attn: bool,
|
||||
pub rms_norm_eps: f64,
|
||||
pub rope_theta: f32,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn config_7b_v1(use_flash_attn: bool) -> Self {
|
||||
Self {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 32,
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-6,
|
||||
rope_theta: 10_000.0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn config_7b_v2(use_flash_attn: bool) -> Self {
|
||||
Self {
|
||||
hidden_size: 4096,
|
||||
intermediate_size: 11008,
|
||||
vocab_size: 32000,
|
||||
num_hidden_layers: 32,
|
||||
num_attention_heads: 32,
|
||||
num_key_value_heads: 32,
|
||||
use_flash_attn,
|
||||
rms_norm_eps: 1e-5,
|
||||
rope_theta: 10_000.0,
|
||||
}
|
||||
}
|
||||
}

// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
    #[allow(clippy::type_complexity)]
    kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
    cos: Tensor,
    sin: Tensor,
    device: Device,
}

impl Cache {
    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
        // precompute freqs_cis
        let n_elem = config.hidden_size / config.num_attention_heads;
        let theta: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        // This is different from the paper, see:
        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
        let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;
        Ok(Self {
            masks: Arc::new(Mutex::new(HashMap::new())),
            use_kv_cache,
            kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
            device: device.clone(),
            cos,
            sin,
        })
    }

    fn mask(&self, t: usize) -> Result<Tensor> {
        let mut masks = self.masks.lock().unwrap();
        if let Some(mask) = masks.get(&t) {
            Ok(mask.clone())
        } else {
            let mask: Vec<_> = (0..t)
                .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                .collect();
            let mask = Tensor::from_slice(&mask, (t, t), &self.device)?;
            masks.insert(t, mask.clone());
            Ok(mask)
        }
    }
}
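
// The `cos`/`sin` tables built above have shape (MAX_SEQ_LEN, head_dim); the frequencies are
// duplicated across both halves of the head dimension by the final `Tensor::cat`, matching the
// layout expected by `apply_rotary_emb` below. `mask(t)` builds and memoizes a (t, t) u8 matrix
// where 1 marks a future position that must be hidden from the attention scores.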

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, cfg.hidden_size))
}

struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let inner = candle_nn::rms_norm(size, eps, vb)?;
        Ok(Self { inner, span })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_attention_heads: usize,
    num_key_value_heads: usize,
    head_dim: usize,
    cache: Cache,
    use_flash_attn: bool,
    span: tracing::Span,
    span_rot: tracing::Span,
}

#[cfg(feature = "flash-attn")]
fn flash_attn(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    softmax_scale: f32,
    causal: bool,
) -> Result<Tensor> {
    candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}

#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
    unimplemented!("compile with '--features flash-attn'")
}
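
// Only one of the two `flash_attn` definitions above gets compiled in: with the `flash-attn`
// cargo feature enabled the call goes to the fused CUDA kernel from `candle-flash-attn`,
// otherwise the stub panics at runtime, which is why the fast path has to be selected at build
// time with `--features flash-attn`.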

impl CausalSelfAttention {
    fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let _enter = self.span_rot.enter();
        let (b_sz, _, seq_len, hidden_size) = x.dims4()?;
        let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
        let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
        let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
        let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
        let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
        let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
        Ok(rope)
    }
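
    // `apply_rotary_emb` is the "rotate half" form of RoPE: the last dimension (the per-head size,
    // despite being named `hidden_size` above) is split in two and recombined as
    // x * cos + rotate_half(x) * sin, using the tables precomputed in `Cache::new`. `index_pos`
    // offsets into those tables so that cached generation keeps absolute positions.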

    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let (b_sz, seq_len, hidden_size) = x.dims3()?;
        let q = self.q_proj.forward(x)?;
        let k = self.k_proj.forward(x)?;
        let v = self.v_proj.forward(x)?;

        let q = q
            .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;
        let mut v = v
            .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
            .transpose(1, 2)?;

        let q = self.apply_rotary_emb(&q, index_pos)?;
        let mut k = self.apply_rotary_emb(&k, index_pos)?;

        if self.cache.use_kv_cache {
            let mut cache = self.cache.kvs.lock().unwrap();
            if let Some((cache_k, cache_v)) = &cache[block_idx] {
                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
                // k and v have shape (b_sz, n_kv_head, total_seq_len, head_dim) at this point, so
                // the sequence length is dimension 2; only the last MAX_SEQ_LEN positions are kept.
                let k_seq_len = k.dims()[2];
                if k_seq_len > MAX_SEQ_LEN {
                    k = k
                        .narrow(2, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                        .contiguous()?;
                }
                let v_seq_len = v.dims()[2];
                if v_seq_len > MAX_SEQ_LEN {
                    v = v
                        .narrow(2, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
                        .contiguous()?;
                }
            }
            cache[block_idx] = Some((k.clone(), v.clone()));
        }

        let k = self.repeat_kv(k)?;
        let v = self.repeat_kv(v)?;

        let y = if self.use_flash_attn {
            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
            let q = q.transpose(1, 2)?;
            let k = k.transpose(1, 2)?;
            let v = v.transpose(1, 2)?;
            let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
            flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?.transpose(1, 2)?
        } else {
            let in_dtype = q.dtype();
            let q = q.to_dtype(DType::F32)?;
            let k = k.to_dtype(DType::F32)?;
            let v = v.to_dtype(DType::F32)?;
            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
            let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
            let att = candle_nn::ops::softmax(&att, D::Minus1)?;
            // Convert to contiguous as matmul doesn't support strided vs for now.
            att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)?
        };
        let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
        let y = self.o_proj.forward(&y)?;
        Ok(y)
    }

    fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
        let n_rep = self.num_attention_heads / self.num_key_value_heads;
        if n_rep == 1 {
            Ok(x)
        } else {
            let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
            let x = x
                .unsqueeze(2)?
                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
            Ok(x)
        }
    }
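
    // `repeat_kv` implements grouped-query attention: the key/value heads are expanded by
    // n_rep = num_attention_heads / num_key_value_heads so that every query head has a matching
    // key/value head. For the 7B presets above the two counts are equal and this is a no-op.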

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "attn");
        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
        let size_in = cfg.hidden_size;
        let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
        let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
        let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
        let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
        let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
        let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_attention_heads: cfg.num_attention_heads,
            num_key_value_heads: cfg.num_key_value_heads,
            head_dim: cfg.hidden_size / cfg.num_attention_heads,
            cache: cache.clone(),
            use_flash_attn: cfg.use_flash_attn,
            span,
            span_rot,
        })
    }
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}
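
// `masked_fill` pairs with the u8 mask from `Cache::mask`: `where_cond` keeps the original score
// where the mask is 0 and substitutes `on_true` (f32::NEG_INFINITY here) where it is 1, so future
// positions drop to zero probability after the softmax.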

struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
    c_proj: Linear,
    span: tracing::Span,
}

impl Mlp {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
        self.c_proj.forward(&x)
    }

    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "mlp");
        let h_size = cfg.hidden_size;
        let i_size = cfg.intermediate_size;
        let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
        let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
        let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
        Ok(Self {
            c_fc1,
            c_fc2,
            c_proj,
            span,
        })
    }
}
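
// The MLP is the SwiGLU variant used by Llama: down_proj(silu(gate_proj(x)) * up_proj(x)), with
// `intermediate_size` (11008 in the 7B presets) as the width of the gate and up projections.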

struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
    rms_2: RmsNorm,
    mlp: Mlp,
    span: tracing::Span,
}

impl Block {
    fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
        let _enter = self.span.enter();
        let residual = x;
        let x = self.rms_1.forward(x)?;
        let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
        let residual = &x;
        let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
        Ok(x)
    }

    fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "block");
        let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
        let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
        let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let rms_2 = RmsNorm::load(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            rms_1,
            attn,
            rms_2,
            mlp,
            span,
        })
    }
}

pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
    ln_f: RmsNorm,
    lm_head: Linear,
}

impl Llama {
    pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (_b_sz, seq_len) = x.dims2()?;
        let mut x = self.wte.forward(x)?;
        for (block_idx, block) in self.blocks.iter().enumerate() {
            x = block.forward(&x, index_pos, block_idx)?;
        }
        let x = self.ln_f.forward(&x)?;
        let x = x.i((.., seq_len - 1, ..))?;
        let logits = self.lm_head.forward(&x)?;
        logits.to_dtype(DType::F32)
    }

    pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
        let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
        let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        let ln_f = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?;
        let blocks: Vec<_> = (0..cfg.num_hidden_layers)
            .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cache, cfg).unwrap())
            .collect();

        Ok(Self {
            wte,
            blocks,
            ln_f,
            lm_head,
        })
    }
}
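
// `Llama::forward` returns only the logits for the last input position, already cast to f32,
// which is what a sampling loop consumes. A rough decoding sketch (hypothetical `context`,
// `index_pos` and sampling step, all managed by the caller):
//
//     let input = Tensor::new(context, &device)?.unsqueeze(0)?;
//     let logits = llama.forward(&input, index_pos)?; // shape (1, vocab_size)
//     // sample the next token from `logits`, append it to the context, advance `index_pos`.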

@ -1,214 +0,0 @@
// Audio processing code, adapted from whisper.cpp
// https://github.com/ggerganov/whisper.cpp

pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {}

impl Float for f32 {}
impl Float for f64 {}

// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357
fn fft<T: Float>(inp: &[T]) -> Vec<T> {
    let n = inp.len();
    let zero = T::zero();
    if n == 1 {
        return vec![inp[0], zero];
    }
    if n % 2 == 1 {
        return dft(inp);
    }
    let mut out = vec![zero; n * 2];

    let mut even = Vec::with_capacity(n / 2);
    let mut odd = Vec::with_capacity(n / 2);

    for (i, &inp) in inp.iter().enumerate() {
        if i % 2 == 0 {
            even.push(inp)
        } else {
            odd.push(inp);
        }
    }

    let even_fft = fft(&even);
    let odd_fft = fft(&odd);

    let two_pi = T::PI() + T::PI();
    let n_t = T::from(n).unwrap();
    for k in 0..n / 2 {
        let k_t = T::from(k).unwrap();
        let theta = two_pi * k_t / n_t;
        let re = theta.cos();
        let im = -theta.sin();

        let re_odd = odd_fft[2 * k];
        let im_odd = odd_fft[2 * k + 1];

        out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd;
        out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd;

        out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd;
        out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd;
    }
    out
}
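
// `fft` returns the spectrum as interleaved complex values [re0, im0, re1, im1, ...], so an
// n-sample input produces 2 * n outputs. Even lengths are split recursively (radix-2); once the
// length becomes odd the quadratic `dft` below takes over, which is what happens for the tail of
// the 400-sample Whisper frames (400 -> 200 -> 100 -> 50 -> 25).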

// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337
fn dft<T: Float>(inp: &[T]) -> Vec<T> {
    let zero = T::zero();
    let n = inp.len();
    let two_pi = T::PI() + T::PI();

    let mut out = Vec::new();
    out.reserve(2 * n);
    let n_t = T::from(n).unwrap();
    for k in 0..n {
        let k_t = T::from(k).unwrap();
        let mut re = zero;
        let mut im = zero;

        for (j, &inp) in inp.iter().enumerate() {
            let j_t = T::from(j).unwrap();
            let angle = two_pi * k_t * j_t / n_t;
            re += inp * angle.cos();
            im -= inp * angle.sin();
        }

        out.push(re);
        out.push(im);
    }
    out
}

#[allow(clippy::too_many_arguments)]
// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414
fn log_mel_spectrogram_w<T: Float>(
    ith: usize,
    hann: &[T],
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    speed_up: bool,
    n_len: usize,
    n_mel: usize,
    n_threads: usize,
) -> Vec<T> {
    let n_fft = if speed_up {
        1 + fft_size / 4
    } else {
        1 + fft_size / 2
    };

    let zero = T::zero();
    let half = T::from(0.5).unwrap();
    let mut fft_in = vec![zero; fft_size];
    let mut mel = vec![zero; n_len * n_mel];

    for i in (ith..n_len).step_by(n_threads) {
        let offset = i * fft_step;

        // apply Hanning window
        for j in 0..fft_size {
            fft_in[j] = if offset + j < samples.len() {
                hann[j] * samples[offset + j]
            } else {
                zero
            }
        }

        // FFT -> mag^2
        let mut fft_out: Vec<T> = fft(&fft_in);

        for j in 0..fft_size {
            fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1];
        }
        for j in 1..fft_size / 2 {
            let v = fft_out[fft_size - j];
            fft_out[j] += v;
        }

        if speed_up {
            // scale down in the frequency domain results in a speed up in the time domain
            for j in 0..n_fft {
                fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]);
            }
        }

        // mel spectrogram
        for j in 0..n_mel {
            let mut sum = zero;
            for k in 0..n_fft {
                sum += fft_out[k] * filters[j * n_fft + k];
            }
            mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10();
        }
    }
    mel
}

fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
    samples: &[T],
    filters: &[T],
    fft_size: usize,
    fft_step: usize,
    n_mel: usize,
    speed_up: bool,
) -> Vec<T> {
    let zero = T::zero();
    let two_pi = T::PI() + T::PI();
    let half = T::from(0.5).unwrap();
    let one = T::from(1.0).unwrap();
    let four = T::from(4.0).unwrap();
    let fft_size_t = T::from(fft_size).unwrap();

    let hann: Vec<T> = (0..fft_size)
        .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos()))
        .collect();
    let n_len = samples.len() / fft_step;

    // pad audio with at least one extra chunk of zeros
    let pad = 100 * super::CHUNK_LENGTH / 2;
    let n_len = if n_len % pad != 0 {
        (n_len / pad + 1) * pad
    } else {
        n_len
    };
    let n_len = n_len + pad;
    let samples = {
        let mut samples_padded = samples.to_vec();
        let to_add = n_len * fft_step - samples.len();
        samples_padded.extend(std::iter::repeat(zero).take(to_add));
        samples_padded
    };

    // Use a single thread for now.
    let mut mel = log_mel_spectrogram_w(
        0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1,
    );
    let mmax = mel
        .iter()
        .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater))
        .copied()
        .unwrap_or(zero)
        - T::from(8).unwrap();
    for m in mel.iter_mut() {
        let v = T::max(*m, mmax);
        *m = v / four + one
    }
    mel
}

pub fn pcm_to_mel<T: Float + std::fmt::Display>(
    samples: &[T],
    filters: &[T],
) -> anyhow::Result<Vec<T>> {
    let mel = log_mel_spectrogram_(
        samples,
        filters,
        super::N_FFT,
        super::HOP_LENGTH,
        super::N_MELS,
        false,
    );
    Ok(mel)
}
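
// The returned buffer is mel-major: `n_mel * n_len` values with `mel[j * n_len + i]` holding
// filterbank `j` at frame `i`, which is why the caller reshapes it into a
// (1, N_MELS, mel_len / N_MELS) tensor before feeding the encoder.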

@ -10,41 +10,16 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod audio;
mod model;
use model::{Config, Whisper};
mod multilingual;

const DTYPE: DType = DType::F32;

// Audio parameters.
const SAMPLE_RATE: usize = 16000;
const N_FFT: usize = 400;
const N_MELS: usize = 80;
const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input

const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;

// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
use candle_transformers::models::whisper::{self as m, audio, model};
use model::{Config, Whisper};

#[allow(dead_code)]
#[derive(Debug, Clone)]
@ -94,7 +69,7 @@ impl Decoder {
        timestamps: bool,
        verbose: bool,
    ) -> Result<Self> {
        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
        // Suppress the notimestamps token when in timestamps mode.
        // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
        let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@ -109,11 +84,11 @@ impl Decoder {
            })
            .collect();
        let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
        Ok(Self {
            model,
            rng: rand::rngs::StdRng::seed_from_u64(seed),
@ -220,17 +195,17 @@ impl Decoder {
    }

    fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
        for (i, &t) in TEMPERATURES.iter().enumerate() {
        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
            let dr: Result<DecodingResult> = self.decode(segment, t);
            if i == TEMPERATURES.len() - 1 {
            if i == m::TEMPERATURES.len() - 1 {
                return dr;
            }
            // On errors, we try again with a different temperature.
            match dr {
                Ok(dr) => {
                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
                        || dr.avg_logprob < LOGPROB_THRESHOLD;
                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                        return Ok(dr);
                    }
                }
@ -248,13 +223,13 @@ impl Decoder {
        let mut segments = vec![];
        while seek < content_frames {
            let start = std::time::Instant::now();
            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
            let segment_size = usize::min(content_frames - seek, N_FRAMES);
            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
            let mel_segment = mel.narrow(2, seek, segment_size)?;
            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
            let dr = self.decode_with_fallback(&mel_segment)?;
            seek += segment_size;
            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                println!("no speech detected, skipping {seek} {dr:?}");
                continue;
            }
@ -492,8 +467,8 @@ fn main() -> Result<()> {
    let mut input = std::fs::File::open(input)?;
    let (header, data) = wav::read(&mut input)?;
    println!("loaded wav data: {header:?}");
    if header.sampling_rate != SAMPLE_RATE as u32 {
        anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
    if header.sampling_rate != m::SAMPLE_RATE as u32 {
        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
    }
    let data = data.as_sixteen().expect("expected 16 bit wav file");
    let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@ -501,14 +476,14 @@ fn main() -> Result<()> {
        .map(|v| *v as f32 / 32768.)
        .collect();
    println!("pcm data loaded {}", pcm_data.len());
    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
    let mel_len = mel.len();
    let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
    println!("loaded mel: {:?}", mel.dims());

    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
    let mut model = Whisper::load(&vb, config)?;

@ -1,416 +0,0 @@
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
use serde::Deserialize;

// The names in comments correspond to the original implementation:
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub num_mel_bins: usize,            // n_mels
    pub max_source_positions: usize,    // n_audio_ctx
    pub d_model: usize,                 // n_audio_state
    pub encoder_attention_heads: usize, // n_audio_head
    pub encoder_layers: usize,          // n_audio_layer
    pub vocab_size: usize,              // n_vocab
    pub max_target_positions: usize,    // n_text_ctx
    // pub n_text_state: usize,
    pub decoder_attention_heads: usize, // n_text_head
    pub decoder_layers: usize,          // n_text_layer
    pub suppress_tokens: Vec<u32>,
}

fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
    let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
    Ok(Embedding::new(embeddings, hidden_size))
}

// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
    Ok(Linear { inner, span })
}

fn conv1d(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    config: Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
    let bias = vb.get(out_channels, "bias")?;
    Ok(Conv1d::new(weight, Some(bias), config))
}

fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
    let weight = vb.get(size, "weight")?;
    let bias = vb.get(size, "bias")?;
    Ok(LayerNorm::new(weight, bias, 1e-5))
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
struct MultiHeadAttention {
    query: Linear,
    key: Linear,
    value: Linear,
    out: Linear,
    n_head: usize,
    span: tracing::Span,
    softmax_span: tracing::Span,
    matmul_span: tracing::Span,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl MultiHeadAttention {
    fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
        let query = linear(n_state, n_state, vb.pp("q_proj"))?;
        let value = linear(n_state, n_state, vb.pp("v_proj"))?;
        let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
        let out = linear(n_state, n_state, vb.pp("out_proj"))?;
        Ok(Self {
            query,
            key,
            value,
            out,
            n_head,
            span,
            softmax_span,
            matmul_span,
            kv_cache: None,
        })
    }

    fn forward(
        &mut self,
        x: &Tensor,
        xa: Option<&Tensor>,
        mask: Option<&Tensor>,
        flush_cache: bool,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let q = self.query.forward(x)?;
        let (k, v) = match xa {
            None => {
                let k = self.key.forward(x)?;
                let v = self.value.forward(x)?;
                (k, v)
            }
            Some(x) => {
                if flush_cache {
                    self.kv_cache = None;
                }
                if let Some((k, v)) = &self.kv_cache {
                    (k.clone(), v.clone())
                } else {
                    let k = self.key.forward(x)?;
                    let v = self.value.forward(x)?;
                    self.kv_cache = Some((k.clone(), v.clone()));
                    (k, v)
                }
            }
        };
        let wv = self.qkv_attention(&q, &k, &v, mask)?;
        let out = self.out.forward(&wv)?;
        Ok(out)
    }

    fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
        let (n_batch, n_ctx, n_state) = x.dims3()?;
        let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
        x.reshape(target_dims)?.transpose(1, 2)
    }

    fn qkv_attention(
        &self,
        q: &Tensor,
        k: &Tensor,
        v: &Tensor,
        mask: Option<&Tensor>,
    ) -> Result<Tensor> {
        let (_, n_ctx, n_state) = q.dims3()?;
        let scale = ((n_state / self.n_head) as f64).powf(-0.25);
        let q = (self.reshape_head(q)? * scale)?;
        let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
        let v = self.reshape_head(v)?.contiguous()?;
        let mut qk = {
            let _enter = self.matmul_span.enter();
            q.matmul(&k)?
        };
        if let Some(mask) = mask {
            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
            qk = qk.broadcast_add(&mask)?
        }
        let w = {
            let _enter = self.softmax_span.enter();
            softmax(&qk, D::Minus1)?
        };
        let wv = {
            let _enter = self.matmul_span.enter();
            w.matmul(&v)?
        }
        .transpose(1, 2)?
        .flatten_from(2)?;
        Ok(wv)
    }
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
struct ResidualAttentionBlock {
    attn: MultiHeadAttention,
    attn_ln: LayerNorm,
    cross_attn: Option<(MultiHeadAttention, LayerNorm)>,
    mlp_linear1: Linear,
    mlp_linear2: Linear,
    mlp_ln: LayerNorm,
    span: tracing::Span,
}

impl ResidualAttentionBlock {
    fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
        let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
        let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
        let cross_attn = if ca {
            let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?;
            let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?;
            Some((cross_attn, cross_attn_ln))
        } else {
            None
        };
        let n_mlp = n_state * 4;
        let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?;
        let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?;
        let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?;
        Ok(Self {
            attn,
            attn_ln,
            cross_attn,
            mlp_linear1,
            mlp_linear2,
            mlp_ln,
            span,
        })
    }

    fn forward(
        &mut self,
        x: &Tensor,
        xa: Option<&Tensor>,
        mask: Option<&Tensor>,
        flush_kv_cache: bool,
    ) -> Result<Tensor> {
        let _enter = self.span.enter();
        let attn = self
            .attn
            .forward(&self.attn_ln.forward(x)?, None, mask, flush_kv_cache)?;
        let mut x = (x + attn)?;
        if let Some((attn, ln)) = &mut self.cross_attn {
            x = (&x + attn.forward(&ln.forward(&x)?, xa, None, flush_kv_cache)?)?;
        }
        let mlp = self.mlp_linear2.forward(
            &self
                .mlp_linear1
                .forward(&self.mlp_ln.forward(&x)?)?
                .gelu()?,
        )?;
        x + mlp
    }
}

fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
    let max_timescale = 10000f32;
    let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
    let inv_timescales: Vec<_> = (0..channels / 2)
        .map(|i| (i as f32 * (-log_timescale_increment)).exp())
        .collect();
    let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
    let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
        .to_dtype(candle::DType::F32)?
        .unsqueeze(1)?;
    let sh = (length, channels / 2);
    let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
    let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
    Ok(sincos)
}
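
// `sinusoids` reproduces Whisper's fixed audio positional embedding: the sin values for the
// `channels / 2` frequencies fill the first half of the channel axis and the cos values the
// second half, giving a (length, channels) table that gets broadcast-added to the conv features.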

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
pub struct AudioEncoder {
    conv1: Conv1d,
    conv2: Conv1d,
    positional_embedding: Tensor,
    blocks: Vec<ResidualAttentionBlock>,
    ln_post: LayerNorm,
    span: tracing::Span,
    conv1_span: tracing::Span,
    conv2_span: tracing::Span,
}

impl AudioEncoder {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
        let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
        let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
        let n_state = cfg.d_model;
        let n_head = cfg.encoder_attention_heads;
        let n_ctx = cfg.max_source_positions;
        let cfg1 = Conv1dConfig {
            padding: 1,
            stride: 1,
            groups: 1,
            dilation: 1,
        };
        let cfg2 = Conv1dConfig {
            padding: 1,
            stride: 2,
            groups: 1,
            dilation: 1,
        };
        let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?;
        let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?;
        let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?;
        let blocks = (0..cfg.encoder_layers)
            .map(|i| {
                ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}")))
            })
            .collect::<Result<Vec<_>>>()?;
        let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?;
        Ok(Self {
            conv1,
            conv2,
            positional_embedding,
            blocks,
            ln_post,
            conv1_span,
            conv2_span,
            span,
        })
    }

    pub fn forward(&mut self, x: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
        let _enter = self.span.enter();
        let x = {
            let _enter = self.conv1_span.enter();
            self.conv1.forward(x)?.gelu()?
        };
        let x = {
            let _enter = self.conv2_span.enter();
            self.conv2.forward(&x)?.gelu()?
        };
        let x = x.transpose(1, 2)?;
        let (_bsize, seq_len, _hidden) = x.dims3()?;
        let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
        let mut x = x.broadcast_add(&positional_embedding)?;
        for block in self.blocks.iter_mut() {
            x = block.forward(&x, None, None, flush_kv_cache)?
        }
        let x = self.ln_post.forward(&x)?;
        Ok(x)
    }
}

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
pub struct TextDecoder {
    token_embedding: Embedding,
    positional_embedding: Tensor,
    blocks: Vec<ResidualAttentionBlock>,
    ln: LayerNorm,
    mask: Tensor,
    span: tracing::Span,
    span_final: tracing::Span,
}

impl TextDecoder {
    fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
        let span_final = tracing::span!(tracing::Level::TRACE, "text-decoder-final");
        let n_state = cfg.d_model;
        let n_head = cfg.decoder_attention_heads;
        let n_ctx = cfg.max_target_positions;
        let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?;
        let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?;
        let blocks = (0..cfg.decoder_layers)
            .map(|i| {
                ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}")))
            })
            .collect::<Result<Vec<_>>>()?;
        let ln = layer_norm(n_state, vb.pp("layer_norm"))?;
        let mask: Vec<_> = (0..n_ctx)
            .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
            .collect();
        let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?;
        Ok(Self {
            token_embedding,
            positional_embedding,
            blocks,
            ln,
            mask,
            span,
            span_final,
        })
    }

    pub fn forward(&mut self, x: &Tensor, xa: &Tensor, flush_kv_cache: bool) -> Result<Tensor> {
        let _enter = self.span.enter();
        let last = x.dim(D::Minus1)?;
        let token_embedding = self.token_embedding.forward(x)?;
        let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
        for block in self.blocks.iter_mut() {
            x = block.forward(&x, Some(xa), Some(&self.mask), flush_kv_cache)?;
        }
        self.ln.forward(&x)
    }

    pub fn final_linear(&self, x: &Tensor) -> Result<Tensor> {
        let b_size = x.dim(0)?;
        let w = self.token_embedding.embeddings().broadcast_left(b_size)?;
        let logits = {
            let _enter = self.span_final.enter();
            x.matmul(&w.t()?)?
        };
        Ok(logits)
    }
}
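
// `final_linear` reuses the token embedding matrix as the output projection (weight tying): the
// decoder has no separate lm_head and the logits are computed as x @ embeddings^T, broadcast over
// the batch dimension.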

// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
pub struct Whisper {
    pub encoder: AudioEncoder,
    pub decoder: TextDecoder,
    pub config: Config,
}

impl Whisper {
    pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
        let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?;
        let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?;
        Ok(Self {
            encoder,
            decoder,
            config,
        })
    }
}
@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
        .iter()
        .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
        .collect::<Result<Vec<_>>>()?;
    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
    let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
    let audio_features = model.encoder.forward(&mel, true)?;
    let tokens = Tensor::new(&[[sot_token]], device)?;
    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;