Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Adds support for stella_en_v5 embedding model -400M variant (#2608)
* Adds support for stella_en_v5 embedding model -400M variant
* Unified stella
* WIP: Unified Stella
* Combined stella for both 1.5B and 400M variants
* Cargo fmt for the CI
* removed redundant stella-400m model and example after merge into stella-en-v5
* cargo fmt --all

Co-authored-by: Anubhab Bandyopadhyay <4890833+AnubhabB@users.noreply.github.com>
Co-authored-by: laurent <laurent.mazare@gmail.com>
@@ -21,7 +21,7 @@ Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling
 The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example.
 
 ```bash
-$ cargo run --example stella-en-v5 --release --features <metal | cuda>
+$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 1.5b
 
 >
 > Score: 0.8178786
@@ -37,9 +37,29 @@ $ cargo run --example stella-en-v5 --release --features <metal | cuda>
 > caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
 > of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
 >
 
+$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m
+
+>
+> Score: 0.8397539
+> Query: What are some ways to reduce stress?
+> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending
+> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent
+> stress from building up.
+>
+>
+>
+> Score: 0.809545
+> Query: What are the benefits of drinking green tea?
+> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage
+> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types
+> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.
+>
 ```
 
 ## Supported options:
-- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
+- `Stella_en_v5` has 2 model variants published - a 1.5B variant and 400M variant. This is enabled through the flag `--which`. E.g. `--which 400m` or `--which 1.5b`.
+
+- `Stella_en_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`.
 
 - As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option.
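For instance, the two flags documented above can be combined in a single run; a hypothetical invocation (flag values taken from the bullets above, with `metal` or `cuda` substituted as before):

```bash
$ cargo run --example stella-en-v5 --release --features <metal | cuda> -- --which 400m --embed-dim 256
```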
@@ -212,6 +212,14 @@ impl EncodeTask {
     }
 }
 
+#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
+enum Which {
+    #[value(name = "1.5b")]
+    Large,
+    #[value(name = "400m")]
+    Small,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -219,6 +227,9 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
+    #[arg(long)]
+    which: Which,
+
    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,
@@ -250,24 +261,33 @@ struct Args {
 
 // Tokenizer creation is super critical in our case.
 // We are going to be `padding: Left` for each batch
-fn create_tokenizer(tokenizer_file: &Path) -> Result<Tokenizer> {
+fn create_tokenizer(tokenizer_file: &Path, which: Which) -> Result<Tokenizer> {
     let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
-    let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
-        pad_id
-    } else {
-        return Err(anyhow!(
-            "Tokenizer doesn't contain expected `<|endoftext|>` token"
-        ));
-    };
-
-    // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
-    tokenizer.with_padding(Some(PaddingParams {
-        strategy: PaddingStrategy::BatchLongest,
-        direction: PaddingDirection::Left,
-        pad_id,
-        pad_token: "<|endoftext|>".to_string(),
-        ..Default::default()
-    }));
+
+    if which == Which::Large {
+        let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") {
+            pad_id
+        } else {
+            return Err(anyhow!(
+                "Tokenizer doesn't contain expected `<|endoftext|>` token"
+            ));
+        };
+
+        // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding
+        tokenizer.with_padding(Some(PaddingParams {
+            strategy: PaddingStrategy::BatchLongest,
+            direction: PaddingDirection::Left,
+            pad_id,
+            pad_token: "<|endoftext|>".to_string(),
+            ..Default::default()
+        }));
+    } else {
+        tokenizer.with_padding(Some(PaddingParams {
+            strategy: PaddingStrategy::BatchLongest,
+            direction: PaddingDirection::Right,
+            ..Default::default()
+        }));
+    }
 
     Ok(tokenizer)
 }
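As a side note on the padding change above: the 1.5B tokenizer is padded on the left so that the real tokens of every row in a batch end at the last position, while the 400M tokenizer keeps the usual right padding. A minimal, standalone sketch of the same `tokenizers` API (the input strings and the `anyhow` error handling are illustrative only):

```rust
use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer};

// Pads a batch to the longest sequence, prepending `<|endoftext|>` ids so that the
// real tokens of every row end at the same (last) position.
fn pad_batch_left(mut tokenizer: Tokenizer) -> anyhow::Result<Vec<Vec<u32>>> {
    let pad_id = tokenizer
        .token_to_id("<|endoftext|>")
        .ok_or_else(|| anyhow::anyhow!("tokenizer has no `<|endoftext|>` token"))?;
    tokenizer.with_padding(Some(PaddingParams {
        strategy: PaddingStrategy::BatchLongest,
        direction: PaddingDirection::Left,
        pad_id,
        pad_token: "<|endoftext|>".to_string(),
        ..Default::default()
    }));
    let encodings = tokenizer
        .encode_batch(vec!["a short query", "a much longer query about green tea"], true)
        .map_err(anyhow::Error::msg)?;
    Ok(encodings.iter().map(|e| e.get_ids().to_vec()).collect())
}
```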
@@ -298,7 +318,19 @@ fn main() -> Result<()> {
         Some(d) => d,
         None => EmbedDim::Dim1024,
     };
-    let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string()));
+
+    let (repo, cfg) = match args.which {
+        Which::Large => (
+            "dunzhang/stella_en_1.5B_v5",
+            Config::new_1_5_b_v5(embed_dim.embed_dim()),
+        ),
+        Which::Small => (
+            "dunzhang/stella_en_400M_v5",
+            Config::new_400_m_v5(embed_dim.embed_dim()),
+        ),
+    };
+
+    let repo = api.repo(Repo::model(repo.to_string()));
     let tokenizer_filename = match args.tokenizer_file {
         Some(file) => std::path::PathBuf::from(file),
         None => repo.get("tokenizer.json")?,
@@ -330,7 +362,7 @@ fn main() -> Result<()> {
     println!("retrieved the files in {:?}", start.elapsed());
 
     // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding
-    let tokenizer = create_tokenizer(tokenizer_filename.as_path())?;
+    let tokenizer = create_tokenizer(tokenizer_filename.as_path(), args.which)?;
 
     let start = std::time::Instant::now();
 
@@ -343,11 +375,7 @@ fn main() -> Result<()> {
     let embed_vb =
         unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? };
 
-    let model = EmbeddingModel::new(
-        &Config::new_1_5_b_v5(embed_dim.embed_dim()),
-        base_vb,
-        embed_vb,
-    )?;
+    let model = EmbeddingModel::new(&cfg, base_vb, embed_vb)?;
 
     println!("loaded the model in {:?}", start.elapsed());
 
@@ -16,33 +16,49 @@
 //!
 
 use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
-use candle::{DType, Device, IndexOp, Module, Result, Tensor};
-use candle_nn::{Activation, VarBuilder};
+use candle::{DType, Device, Error, IndexOp, Module, Result, Tensor, D};
+use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
 use std::sync::Arc;
 
+// internal representation for identifying which model is being used
+#[derive(Debug, Copy, Clone, PartialEq, serde::Deserialize)]
+pub enum ModelVariant {
+    Large, // 1.5B
+    Small, // 400M
+}
+
+impl Default for ModelVariant {
+    fn default() -> Self {
+        Self::Large
+    }
+}
+
 // Same as `qwen2` family of models with the exception being the `embed_head`
 // The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head`
-#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
 pub struct Config {
+    pub variant: ModelVariant,
     pub vocab_size: usize,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_hidden_layers: usize,
     pub num_attention_heads: usize,
-    pub num_key_value_heads: usize,
     pub max_position_embeddings: usize,
-    pub max_window_layers: usize,
-    pub tie_word_embeddings: bool,
     pub rope_theta: f64,
-    pub rms_norm_eps: f64,
-    pub hidden_act: Activation,
     pub embed_head: EmbedHead,
+    pub norm_eps: f64,             // RMSNorm for 1.5B || LayerNorm for 400M
+    pub activation_fn: Activation, // Silu for 1.5B || Gelu for 400M
+    // Unique to 1.5B
+    pub num_key_value_heads: usize,
+    // Unique to 400M
+    pub type_vocab_size: usize,
+    pub scaling_factor: f64,
 }
 
 // Excerpt from `stella` model card:
 // `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions
 // Embed head represents the config for various embedding dims supported
-#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[derive(Debug, Default, Clone, PartialEq, serde::Deserialize)]
 pub struct EmbedHead {
     pub in_features: usize,
     pub out_features: usize,
@@ -68,9 +84,9 @@ impl Default for EmbedDim {
 }
 
 impl EmbedDim {
-    pub fn config(&self) -> EmbedHead {
+    pub fn config(&self, in_features: usize) -> EmbedHead {
         EmbedHead {
-            in_features: 1536,
+            in_features,
             out_features: match &self {
                 Self::Dim256 => 256,
                 Self::Dim768 => 768,
@@ -91,7 +107,8 @@ impl Config {
         // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json
         // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here
         Self {
-            hidden_act: candle_nn::Activation::Silu,
+            variant: ModelVariant::Large,
+            activation_fn: candle_nn::Activation::Silu,
             vocab_size: 151646,
             hidden_size: 1536,
             intermediate_size: 8960,
@@ -99,11 +116,30 @@ impl Config {
             num_attention_heads: 12,
             num_key_value_heads: 2,
             max_position_embeddings: 131072,
-            max_window_layers: 21,
-            tie_word_embeddings: false,
             rope_theta: 1000000.,
-            rms_norm_eps: 1e-06,
-            embed_head: embed_dim.config(),
+            norm_eps: 1e-06,
+            embed_head: embed_dim.config(1536),
+            ..Default::default()
+        }
+    }
+
+    /// Initialize new `stella_en_400M_v5`
+    pub fn new_400_m_v5(embed_dim: EmbedDim) -> Self {
+        Self {
+            variant: ModelVariant::Small,
+            vocab_size: 30528,
+            hidden_size: 1024,
+            intermediate_size: 4096,
+            num_hidden_layers: 24,
+            num_attention_heads: 16,
+            max_position_embeddings: 8192,
+            type_vocab_size: 2,
+            norm_eps: 1e-12,
+            scaling_factor: 2.0,
+            rope_theta: 160000.0,
+            activation_fn: Activation::Gelu,
+            embed_head: embed_dim.config(1024),
+            ..Default::default()
         }
     }
 }
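A small usage sketch for the two constructors above (the import path assumes the usual `candle_transformers` re-exports; the helper name is illustrative):

```rust
use candle_transformers::models::stella_en_v5::{Config, EmbedDim};

// Pick the config that matches the checkpoint being loaded.
fn stella_config(use_400m: bool, dim: EmbedDim) -> Config {
    if use_400m {
        Config::new_400_m_v5(dim) // 1024 hidden size, GELU, LayerNorm with eps 1e-12
    } else {
        Config::new_1_5_b_v5(dim) // 1536 hidden size, SiLU, RMSNorm with eps 1e-6
    }
}
```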
@@ -117,27 +153,57 @@ struct RotaryEmbedding {
 impl RotaryEmbedding {
     fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
         let dim = cfg.hidden_size / cfg.num_attention_heads;
-        let max_seq_len = cfg.max_position_embeddings;
+        // Factoring in `scaling factor` for `400M` variant
+        let max_seq_len = if cfg.scaling_factor == 0. {
+            cfg.max_position_embeddings
+        } else {
+            ((cfg.max_position_embeddings as f64) * cfg.scaling_factor) as usize
+        };
+
+        // let rot_dim = if cfg.variant == ModelVariant::Small { dim / 2 } else { dim };
         let inv_freq: Vec<_> = (0..dim)
             .step_by(2)
-            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
+            .map(|i| {
+                // Scaled rope_theta for 400M variant
+                let rope_theta = if cfg.scaling_factor == 0. {
+                    cfg.rope_theta
+                } else {
+                    cfg.rope_theta * cfg.scaling_factor
+                };
+                let mut freq = 1. / rope_theta.powf(i as f64 / dim as f64);
+
+                if cfg.scaling_factor != 0. {
+                    freq /= cfg.scaling_factor.powf(2.0 / (dim as f64))
+                }
+
+                freq as f32
+            })
             .collect();
 
         let inv_freq_len = inv_freq.len();
         let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
 
+        // Calculate position embeddings with scaled sequence length
         let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
             .to_dtype(dtype)?
             .reshape((max_seq_len, 1))?;
         let freqs = t.matmul(&inv_freq)?;
+        // if cfg.variant == ModelVariant::Small {
+        //     freqs = Tensor::cat(&[&freqs, &freqs], 1)?
+        // }
+
         Ok(Self {
             sin: freqs.sin()?,
             cos: freqs.cos()?,
         })
     }
 
+    // TODO: re-visit this
     fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
         let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
         let cos = self.cos.narrow(0, 0, seq_len)?;
         let sin = self.sin.narrow(0, 0, seq_len)?;
 
         let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
         let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
         Ok((q_embed, k_embed))
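The frequency schedule above can be read in isolation: with a non-zero `scaling_factor` the base theta is multiplied by the factor, every inverse frequency is additionally divided by `scaling_factor^(2/dim)`, and the cached table covers `max_position_embeddings * scaling_factor` positions. A standalone sketch of just that computation (plain Rust, mirroring the closure above):

```rust
// Reproduces the inverse-frequency schedule used by RotaryEmbedding::new above.
fn inv_freqs(dim: usize, rope_theta: f64, scaling_factor: f64) -> Vec<f32> {
    (0..dim)
        .step_by(2)
        .map(|i| {
            let theta = if scaling_factor == 0. {
                rope_theta
            } else {
                rope_theta * scaling_factor
            };
            let mut freq = 1. / theta.powf(i as f64 / dim as f64);
            if scaling_factor != 0. {
                freq /= scaling_factor.powf(2.0 / dim as f64);
            }
            freq as f32
        })
        .collect()
}

// For the 400M config above: dim = 1024 / 16 = 64, rope_theta = 160000.0,
// scaling_factor = 2.0, and the position table grows to 8192 * 2 = 16384 entries.
```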
@@ -147,8 +213,9 @@ impl RotaryEmbedding {
 #[derive(Debug, Clone)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
+    variant: ModelVariant,
     gate_proj: Linear,
-    up_proj: Linear,
+    up_proj: Option<Linear>, // `up_proj` only for 1.5B variant
     down_proj: Linear,
     act_fn: Activation,
 }
@@ -157,31 +224,65 @@ impl MLP {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden_sz = cfg.hidden_size;
         let intermediate_sz = cfg.intermediate_size;
-        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
-        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
-        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
+
+        let (gate_proj, up_proj, down_proj) = match cfg.variant {
+            ModelVariant::Large => (
+                linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?,
+                Some(linear_no_bias(
+                    hidden_sz,
+                    intermediate_sz,
+                    vb.pp("up_proj"),
+                )?),
+                linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
+            ),
+            ModelVariant::Small => (
+                linear_no_bias(hidden_sz, intermediate_sz * 2, vb.pp("up_gate_proj"))?,
+                None,
+                linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?,
+            ),
+        };
+
         Ok(Self {
+            variant: cfg.variant,
             gate_proj,
             up_proj,
             down_proj,
-            act_fn: cfg.hidden_act,
+            act_fn: cfg.activation_fn,
         })
     }
 }
 
 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
-        let rhs = xs.apply(&self.up_proj)?;
+        let up = self.gate_proj.forward(xs)?;
+
+        let (lhs, rhs) = match self.variant {
+            ModelVariant::Large => {
+                let lhs = up.apply(&self.act_fn)?;
+                let rhs = xs.apply(self.up_proj.as_ref().unwrap())?;
+
+                (lhs, rhs)
+            }
+            ModelVariant::Small => {
+                // Get the dimensions
+                let (_batch_size, _seq_len, hidden_dim) = up.dims3()?;
+                let split_size = hidden_dim / 2;
+
+                // Split along the last dimension (hidden_dim)
+                let up_states = up.narrow(2, 0, split_size)?;
+                let gate = up.narrow(2, split_size, split_size)?.apply(&self.act_fn)?;
+
+                (up_states, gate)
+            }
+        };
+
         (lhs * rhs)?.apply(&self.down_proj)
     }
 }
 
 #[derive(Debug, Clone)]
 struct Attention {
-    q_proj: Linear,
-    k_proj: Linear,
-    v_proj: Linear,
+    qkv_proj: Linear,
     o_proj: Linear,
     num_heads: usize,
     num_kv_heads: usize,
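The 400M branch of `MLP::forward` above relies on the packed `up_gate_proj` output being split down the middle of its last dimension; a self-contained sketch of that split with made-up shapes, using the same `candle` crate alias as the file in this diff:

```rust
use candle::{Device, Result, Tensor};

// Splits a packed `up_gate` activation into its `up` and `gate` halves along dim 2.
fn split_up_gate(up_gate: &Tensor) -> Result<(Tensor, Tensor)> {
    let (_b, _t, hidden) = up_gate.dims3()?;
    let half = hidden / 2;
    let up = up_gate.narrow(2, 0, half)?; // first half: the "up" states
    let gate = up_gate.narrow(2, half, half)?; // second half: the "gate" (GELU is applied to this)
    Ok((up, gate))
}

fn main() -> Result<()> {
    // 2 * intermediate_size = 8192 for the 400M variant.
    let x = Tensor::randn(0f32, 1f32, (2, 3, 8192), &Device::Cpu)?;
    let (up, gate) = split_up_gate(&x)?;
    assert_eq!(up.dims(), &[2, 3, 4096]);
    assert_eq!(gate.dims(), &[2, 3, 4096]);
    Ok(())
}
```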
@@ -189,6 +290,7 @@ struct Attention {
     head_dim: usize,
     hidden_size: usize,
     rotary_emb: Arc<RotaryEmbedding>,
+    variant: ModelVariant,
 }
 
 impl Attention {
@@ -196,16 +298,47 @@ impl Attention {
         let hidden_sz = cfg.hidden_size;
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let num_kv_groups = num_heads / num_kv_heads;
+        let num_kv_groups = if num_kv_heads > 0 {
+            num_heads / num_kv_heads
+        } else {
+            0
+        };
         let head_dim = hidden_sz / num_heads;
-        let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
-        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
+
+        let (qkv_proj, o_proj) = match cfg.variant {
+            ModelVariant::Large => {
+                // The 1.5B variant comes with separate `q, k, v` layers, let's merge it and standardize
+                // Weights
+                let q_w = vb
+                    .pp("q_proj")
+                    .get((num_heads * head_dim, hidden_sz), "weight")?;
+                let k_w = vb
+                    .pp("k_proj")
+                    .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
+                let v_w = vb
+                    .pp("v_proj")
+                    .get((num_kv_heads * head_dim, hidden_sz), "weight")?;
+                // Biases
+                let q_b = vb.pp("q_proj").get(num_heads * head_dim, "bias")?;
+                let k_b = vb.pp("k_proj").get(num_kv_heads * head_dim, "bias")?;
+                let v_b = vb.pp("v_proj").get(num_kv_heads * head_dim, "bias")?;
+
+                let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?;
+                let qkv_b = Tensor::cat(&[&q_b, &k_b, &v_b], 0)?;
+
+                (
+                    Linear::from_weights(qkv_w, Some(qkv_b)),
+                    linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
+                )
+            }
+            ModelVariant::Small => (
+                linear(hidden_sz, 3 * num_heads * head_dim, vb.pp("qkv_proj"))?,
+                linear(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?,
+            ),
+        };
+
         Ok(Self {
-            q_proj,
-            k_proj,
-            v_proj,
+            qkv_proj,
             o_proj,
             num_heads,
             num_kv_heads,
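The weight merge in `Attention::new` above leans on a simple identity: projecting with the row-wise concatenation of the Q/K/V matrices and then slicing the output is the same as running the three projections separately. A small check of that identity with arbitrary shapes (standalone, not part of the diff):

```rust
use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (hidden, q_sz, kv_sz) = (8, 8, 4);
    let x = Tensor::randn(0f32, 1f32, (2, hidden), &dev)?;
    let q_w = Tensor::randn(0f32, 1f32, (q_sz, hidden), &dev)?;
    let k_w = Tensor::randn(0f32, 1f32, (kv_sz, hidden), &dev)?;
    let v_w = Tensor::randn(0f32, 1f32, (kv_sz, hidden), &dev)?;

    // Concatenate along the output dimension, exactly like `qkv_w` above.
    let qkv_w = Tensor::cat(&[&q_w, &k_w, &v_w], 0)?;
    let qkv = x.matmul(&qkv_w.t()?)?;

    // The first q_sz columns of the merged projection equal the separate Q projection.
    let q = qkv.narrow(D::Minus1, 0, q_sz)?;
    let q_ref = x.matmul(&q_w.t()?)?;
    let diff = (&q - &q_ref)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert!(diff < 1e-4);
    Ok(())
}
```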
@@ -213,45 +346,90 @@ impl Attention {
             head_dim,
             hidden_size: hidden_sz,
             rotary_emb,
+            variant: cfg.variant,
         })
     }
 
     fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
         let (b_sz, q_len, _) = xs.dims3()?;
 
-        let query_states = self.q_proj.forward(xs)?;
-        let key_states = self.k_proj.forward(xs)?;
-        let value_states = self.v_proj.forward(xs)?;
-
-        let query_states = query_states
-            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
-            .transpose(1, 2)?;
-        let key_states = key_states
-            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
-            .transpose(1, 2)?;
-        let value_states = value_states
-            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
-            .transpose(1, 2)?;
+        let qkv = self.qkv_proj.forward(xs)?;
+
+        let n_kv_heads = match self.variant {
+            ModelVariant::Large => self.num_kv_heads,
+            ModelVariant::Small => self.num_heads,
+        };
+
+        let (query_states, key_states, value_states) = match self.variant {
+            ModelVariant::Large => {
+                let q_sz = self.num_heads * self.head_dim;
+                let kv_sz = n_kv_heads * self.head_dim;
+
+                let q = qkv.narrow(D::Minus1, 0, q_sz)?.reshape((
+                    b_sz,
+                    q_len,
+                    self.num_heads,
+                    self.head_dim,
+                ))?;
+                let k = qkv.narrow(D::Minus1, q_sz, kv_sz)?.reshape((
+                    b_sz,
+                    q_len,
+                    n_kv_heads,
+                    self.head_dim,
+                ))?;
+                let v = qkv.narrow(D::Minus1, q_sz + kv_sz, kv_sz)?.reshape((
+                    b_sz,
+                    q_len,
+                    n_kv_heads,
+                    self.head_dim,
+                ))?;
+
+                (q, k, v)
+            }
+            ModelVariant::Small => {
+                // Split into Q, K, V and reshape to match PyTorch shapes
+                let qkv = qkv.reshape((b_sz, q_len, 3, self.num_heads, self.head_dim))?;
+
+                (
+                    qkv.i((.., .., 0, .., ..))?,
+                    qkv.i((.., .., 1, .., ..))?,
+                    qkv.i((.., .., 2, .., ..))?,
+                )
+            }
+        };
+
+        let query_states = query_states.transpose(1, 2)?.contiguous()?;
+        let key_states = key_states.transpose(1, 2)?.contiguous()?;
+        let value_states = value_states.transpose(1, 2)?.contiguous()?;
 
         let (query_states, key_states) = self
             .rotary_emb
             .apply_rotary_emb_qkv(&query_states, &key_states)?;
 
-        let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
-        let value_states =
-            crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
+        // The 1.5B is expected to have grouped query attention
+        let (key_states, value_states) = if self.variant == ModelVariant::Large {
+            (
+                crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?,
+                crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?,
+            )
+        } else {
+            (key_states, value_states)
+        };
 
         let attn_output = {
             let scale = 1f64 / f64::sqrt(self.head_dim as f64);
-            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
+            let attn_weights = query_states.matmul(&key_states.transpose(2, 3)?)?;
+            let attn_weights = (attn_weights * scale)?;
 
             let attn_weights = match attention_mask {
                 None => attn_weights,
                 Some(mask) => attn_weights.broadcast_add(mask)?,
             };
             let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
 
             attn_weights.matmul(&value_states)?
         };
 
         attn_output
             .transpose(1, 2)?
             .reshape((b_sz, q_len, self.hidden_size))?
@@ -260,70 +438,282 @@ impl Attention {
 }
 
 #[derive(Debug, Clone)]
-struct DecoderLayer {
-    self_attn: Attention,
-    mlp: MLP,
-    input_layernorm: RmsNorm,
-    post_attention_layernorm: RmsNorm,
+enum NormType {
+    Layer(LayerNorm),
+    Rms(RmsNorm),
 }
 
-impl DecoderLayer {
+#[derive(Debug, Clone)]
+struct Layer {
+    variant: ModelVariant,
+    attention: Attention,
+    mlp: MLP,
+    // For 1.5B: this is `input_layernorm`
+    // For 400M: this is `output_layernorm`
+    layernorm: NormType,
+    post_attention_layernorm: NormType,
+}
+
+impl Layer {
     fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
-        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
-        let input_layernorm =
-            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
-        let post_attention_layernorm = RmsNorm::new(
-            cfg.hidden_size,
-            cfg.rms_norm_eps,
-            vb.pp("post_attention_layernorm"),
+        let attention = Attention::new(
+            rotary_emb,
+            cfg,
+            vb.pp(if cfg.variant == ModelVariant::Large {
+                "self_attn"
+            } else {
+                "attention"
+            }),
         )?;
+        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
+        let (layernorm, post_attention_layernorm) = match cfg.variant {
+            ModelVariant::Large => (
+                NormType::Rms(RmsNorm::new(
+                    cfg.hidden_size,
+                    cfg.norm_eps,
+                    vb.pp("input_layernorm"),
+                )?),
+                NormType::Rms(RmsNorm::new(
+                    cfg.hidden_size,
+                    cfg.norm_eps,
+                    vb.pp("post_attention_layernorm"),
+                )?),
+            ),
+            ModelVariant::Small => (
+                NormType::Layer(layer_norm(
+                    cfg.hidden_size,
+                    candle_nn::LayerNormConfig {
+                        eps: cfg.norm_eps,
+                        ..Default::default()
+                    },
+                    vb.pp("mlp_ln"),
+                )?),
+                NormType::Layer(layer_norm(
+                    cfg.hidden_size,
+                    candle_nn::LayerNormConfig {
+                        eps: cfg.norm_eps,
+                        ..Default::default()
+                    },
+                    vb.pp("attn_ln"),
+                )?),
+            ),
+        };
+
         Ok(Self {
-            self_attn,
+            variant: cfg.variant,
+            attention,
             mlp,
-            input_layernorm,
+            layernorm,
             post_attention_layernorm,
         })
     }
 
     fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result<Tensor> {
+        // Here, the application of normalizations and activation calculations differ
+        // For Large [1.5B]:
+        // residual = x
+        // state = other_layernorm(xs)
+        // state = attention(state)
+        // state += residual
+        // residual = state
+        // state = mlp(attention_layernorm(state))
+        // -> residual + state
+        // For Small [400M]:
+        // residual = x;
+        // state = attention(x)
+        // state += residual
+        // state = attention_layernorm(state)
+        // residual = state
+        // state = mlp(state)
+        // state += residual
+        // -> other_layernorm(state)
         let residual = xs;
-        let xs = self.input_layernorm.forward(xs)?;
-        let xs = self.self_attn.forward(&xs, attention_mask)?;
-        let xs = (xs + residual)?;
-        let residual = &xs;
-        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
-        residual + xs
+
+        match self.variant {
+            ModelVariant::Large => {
+                let (attn_ln, input_ln) = if let (NormType::Rms(attn_ln), NormType::Rms(input_ln)) =
+                    (&self.post_attention_layernorm, &self.layernorm)
+                {
+                    (attn_ln, input_ln)
+                } else {
+                    return Err(candle::error::Error::Msg(
+                        "Stella 1.5B expects RMSNorm".to_string(),
+                    ));
+                };
+
+                let xs = input_ln.forward(xs)?;
+                let xs = (self.attention.forward(&xs, attention_mask)? + residual)?;
+
+                let residual = &xs;
+                let xs = xs.apply(attn_ln)?.apply(&self.mlp)?;
+
+                residual + xs
+            }
+            ModelVariant::Small => {
+                let (attn_ln, output_ln) =
+                    if let (NormType::Layer(attn_ln), NormType::Layer(input_ln)) =
+                        (&self.post_attention_layernorm, &self.layernorm)
+                    {
+                        (attn_ln, input_ln)
+                    } else {
+                        return Err(candle::error::Error::Msg(
+                            "Stella 400M expects RMSNorm".to_string(),
+                        ));
+                    };
+
+                let xs = (self.attention.forward(xs, attention_mask)? + residual)?;
+                let xs = attn_ln.forward(&xs)?;
+
+                let residual = &xs;
+                let xs = (self.mlp.forward(&xs)? + residual)?;
+
+                output_ln.forward(&xs)
+            }
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct Embeddings {
+    variant: ModelVariant,
+    // For 1.5B: this is the `embed_tokens`
+    // For 400M: this is the `word_embeddings`
+    embeddings: candle_nn::Embedding,
+    // following are specifically for 400M
+    token_type_embeddings: Option<candle_nn::Embedding>,
+    layer_norm: Option<LayerNorm>,
+    position_ids: Option<Tensor>,
+}
+
+impl Embeddings {
+    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let (embeddings, token_type_embeddings, layer_norm, position_ids) = match cfg.variant {
+            ModelVariant::Large => (
+                candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("embed_tokens"))?,
+                None,
+                None,
+                None,
+            ),
+            ModelVariant::Small => {
+                let vb = vb.pp("embeddings");
+                let weight = vb.pp("LayerNorm").get_with_hints(
+                    cfg.hidden_size,
+                    "weight",
+                    candle_nn::Init::Const(1.0),
+                )?;
+                let bias = vb.pp("LayerNorm").get_with_hints(
+                    cfg.hidden_size,
+                    "bias",
+                    candle_nn::Init::Const(0.0),
+                )?;
+                let dev = bias.device().clone();
+
+                let layer_norm = candle_nn::LayerNorm::new(weight, bias, cfg.norm_eps);
+
+                (
+                    candle_nn::embedding(
+                        cfg.vocab_size,
+                        cfg.hidden_size,
+                        vb.pp("word_embeddings"),
+                    )?,
+                    Some(candle_nn::embedding(
+                        cfg.type_vocab_size,
+                        cfg.hidden_size,
+                        vb.pp("token_type_embeddings"),
+                    )?),
+                    Some(layer_norm),
+                    Some(Tensor::arange(
+                        0u32,
+                        cfg.max_position_embeddings as u32,
+                        &dev,
+                    )?),
+                )
+            }
+        };
+
+        Ok(Self {
+            variant: cfg.variant,
+            embeddings,
+            token_type_embeddings,
+            layer_norm,
+            position_ids,
+        })
+    }
+}
+
+impl Module for Embeddings {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let embd = self.embeddings.forward(xs)?;
+        // For 1.5B just forward the embeddings
+        if self.variant == ModelVariant::Large {
+            return Ok(embd);
+        }
+
+        let (token_type_embed, layer_norm, pos_ids) =
+            if let (Some(token_type_embd), Some(layer_norm), Some(position_ids)) = (
+                &self.token_type_embeddings,
+                &self.layer_norm,
+                &self.position_ids,
+            ) {
+                (token_type_embd, layer_norm, position_ids)
+            } else {
+                return Err(Error::Msg(
+                    "Stella 400M requires `token_type_embeddings`, `layer_norm` and `position_ids`"
+                        .to_string(),
+                ));
+            };
+
+        let (batch_size, seq_length) = xs.dims2()?;
+
+        let pos_ids = pos_ids
+            .as_ref()
+            .narrow(0, 0, seq_length)?
+            .expand((batch_size, seq_length))?;
+
+        layer_norm.forward(&embd.add(&token_type_embed.forward(&pos_ids.zeros_like()?)?)?)
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct Model {
-    embed_tokens: candle_nn::Embedding,
-    layers: Vec<DecoderLayer>,
-    norm: RmsNorm,
+    embeddings: Embeddings,
+    layers: Vec<Layer>,
+    norm: Option<RmsNorm>,
     device: Device,
     dtype: DType,
 }
 
 impl Model {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let vb_m = vb.pp("model");
-        let embed_tokens =
-            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let vb_m = match cfg.variant {
+            ModelVariant::Large => vb.pp("model"),
+            ModelVariant::Small => vb.pp("new"),
+        };
+        // let embed_tokens =
+        //     candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
+        let embeddings = Embeddings::new(cfg, vb_m.clone())?;
         let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
         let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
-        let vb_l = vb_m.pp("layers");
+        let vb_l = match cfg.variant {
+            ModelVariant::Large => vb_m.pp("layers"),
+            ModelVariant::Small => vb_m.pp("encoder").pp("layer"),
+        };
         for layer_idx in 0..cfg.num_hidden_layers {
-            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
+            let layer = Layer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
             layers.push(layer)
         }
-        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
+        let norm = match cfg.variant {
+            ModelVariant::Large => Some(RmsNorm::new(
+                cfg.hidden_size,
+                cfg.norm_eps,
+                vb_m.pp("norm"),
+            )?),
+            ModelVariant::Small => None,
+        };
         Ok(Self {
-            embed_tokens,
+            embeddings,
             layers,
             norm,
-            // sliding_window: 0,
             device: vb.device().clone(),
             dtype: vb.dtype(),
         })
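The two orderings spelled out in the `Layer::forward` comments above amount to a pre-norm block for the 1.5B variant and a post-norm block for the 400M variant; a schematic sketch with closures standing in for the attention, MLP and norm sub-modules (purely illustrative):

```rust
// Pre-norm (1.5B): normalize, transform, then add the residual.
fn pre_norm_block(
    x: f32,
    input_ln: impl Fn(f32) -> f32,
    attn: impl Fn(f32) -> f32,
    attn_ln: impl Fn(f32) -> f32,
    mlp: impl Fn(f32) -> f32,
) -> f32 {
    let x = attn(input_ln(x)) + x;
    mlp(attn_ln(x)) + x
}

// Post-norm (400M): transform, add the residual, then normalize.
fn post_norm_block(
    x: f32,
    attn_ln: impl Fn(f32) -> f32,
    attn: impl Fn(f32) -> f32,
    output_ln: impl Fn(f32) -> f32,
    mlp: impl Fn(f32) -> f32,
) -> f32 {
    let x = attn_ln(attn(x) + x);
    output_ln(mlp(x) + x)
}
```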
@@ -352,15 +742,20 @@ impl Model {
             Some(self.prepare_attention_mask(mask)?)
         };
 
-        let mut xs = self.embed_tokens.forward(input_ids)?;
+        let mut xs = self.embeddings.forward(input_ids)?;
         for layer in self.layers.iter_mut() {
             xs = layer.forward(&xs, attention_mask.as_ref())?
         }
-        xs.apply(&self.norm)
+
+        if let Some(n) = &self.norm {
+            xs.apply(n)
+        } else {
+            Ok(xs)
+        }
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct EmbeddingModel {
     base_model: Model,
     lm_head: Linear,