mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the helium model. (#2715)
This commit is contained in:
11
candle-examples/examples/helium/README.md
Normal file
11
candle-examples/examples/helium/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# candle-helium: 2b LLM with CC-BY licensed weights
|
||||||
|
|
||||||
|
- [Model card](https://huggingface.co/kyutai/helium-1-preview) on the HuggingFace Hub.
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example helium --release --features cuda -- --prompt 'Write helloworld code in Rust' --sample-len 150
|
||||||
|
```
|
||||||
|
|
||||||
|
|
292
candle-examples/examples/helium/main.rs
Normal file
292
candle-examples/examples/helium/main.rs
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::{Error as E, Result};
|
||||||
|
use clap::Parser;
|
||||||
|
|
||||||
|
use candle_transformers::models::helium::{Config, Model};
|
||||||
|
|
||||||
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
struct TextGeneration {
|
||||||
|
model: Model,
|
||||||
|
device: Device,
|
||||||
|
tokenizer: TokenOutputStream,
|
||||||
|
logits_processor: LogitsProcessor,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TextGeneration {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
model: Model,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
seed: u64,
|
||||||
|
temp: Option<f64>,
|
||||||
|
top_p: Option<f64>,
|
||||||
|
top_k: Option<usize>,
|
||||||
|
repeat_penalty: f32,
|
||||||
|
repeat_last_n: usize,
|
||||||
|
config: Config,
|
||||||
|
device: &Device,
|
||||||
|
) -> Self {
|
||||||
|
let logits_processor = {
|
||||||
|
let temperature = temp.unwrap_or(0.);
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (top_k, top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
|
logits_processor,
|
||||||
|
repeat_penalty,
|
||||||
|
repeat_last_n,
|
||||||
|
device: device.clone(),
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||||
|
use std::io::Write;
|
||||||
|
self.tokenizer.clear();
|
||||||
|
let mut tokens = self
|
||||||
|
.tokenizer
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt, true)
|
||||||
|
.map_err(E::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
for &t in tokens.iter() {
|
||||||
|
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||||
|
print!("{t}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut generated_tokens = 0usize;
|
||||||
|
let start_gen = std::time::Instant::now();
|
||||||
|
for index in 0..sample_len {
|
||||||
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
|
let start_pos = tokens.len().saturating_sub(context_size);
|
||||||
|
let ctxt = &tokens[start_pos..];
|
||||||
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
|
let logits = self.model.forward(&input, start_pos)?;
|
||||||
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
|
let logits = if self.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
self.repeat_penalty,
|
||||||
|
&tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
|
||||||
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
|
tokens.push(next_token);
|
||||||
|
generated_tokens += 1;
|
||||||
|
if next_token == self.config.bos_token_id || next_token == self.config.eos_token_id {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let dt = start_gen.elapsed();
|
||||||
|
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
println!(
|
||||||
|
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||||
|
generated_tokens as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "v1-preview")]
|
||||||
|
V1Preview,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
use_flash_attn: bool,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples.
|
||||||
|
#[arg(long, default_value_t = 0.7)]
|
||||||
|
temperature: f64,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "v1-preview")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "main")]
|
||||||
|
revision: String,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
config: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
weights: Option<String>,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_id = match args.model_id {
|
||||||
|
Some(model_id) => model_id,
|
||||||
|
None => {
|
||||||
|
let name = match args.which {
|
||||||
|
Which::V1Preview => "kyutai/helium-1-preview",
|
||||||
|
};
|
||||||
|
name.to_string()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let repo = api.repo(Repo::with_revision(
|
||||||
|
model_id,
|
||||||
|
RepoType::Model,
|
||||||
|
args.revision,
|
||||||
|
));
|
||||||
|
let tokenizer_filename = match args.tokenizer {
|
||||||
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
|
None => repo.get("tokenizer.json")?,
|
||||||
|
};
|
||||||
|
let filenames = match args.weights {
|
||||||
|
Some(files) => files
|
||||||
|
.split(',')
|
||||||
|
.map(std::path::PathBuf::from)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors")?,
|
||||||
|
};
|
||||||
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let config: Config = match args.config {
|
||||||
|
Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
|
||||||
|
None => {
|
||||||
|
let config_file = repo.get("config.json")?;
|
||||||
|
serde_json::from_slice(&std::fs::read(config_file)?)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let (model, device) = {
|
||||||
|
let dtype = if device.is_cuda() {
|
||||||
|
DType::BF16
|
||||||
|
} else {
|
||||||
|
DType::F32
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
(model, device)
|
||||||
|
};
|
||||||
|
|
||||||
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
|
||||||
|
let mut pipeline = TextGeneration::new(
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
args.seed,
|
||||||
|
Some(args.temperature),
|
||||||
|
args.top_p,
|
||||||
|
args.top_k,
|
||||||
|
args.repeat_penalty,
|
||||||
|
args.repeat_last_n,
|
||||||
|
config,
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
pipeline.run(&args.prompt, args.sample_len)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
395
candle-transformers/src/models/helium.rs
Normal file
395
candle-transformers/src/models/helium.rs
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
//! Helium inference implementation.
|
||||||
|
//!
|
||||||
|
//! See the model card on Hugging Face's [hub](https://huggingface.co/kmhf/helium-2b).
|
||||||
|
|
||||||
|
use super::with_tracing::{linear_b as linear, Linear, RmsNorm};
|
||||||
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
|
use candle_nn::{Module, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
fn default_use_flash_attn() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub attention_bias: bool,
|
||||||
|
pub bos_token_id: u32,
|
||||||
|
pub eos_token_id: u32,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub hidden_act: candle_nn::Activation,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub mlp_bias: bool,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
pub vocab_size: usize,
|
||||||
|
#[serde(default = "default_use_flash_attn")]
|
||||||
|
pub use_flash_attn: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn config_2b(use_flash_attn: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
attention_bias: false,
|
||||||
|
bos_token_id: 1,
|
||||||
|
eos_token_id: 2,
|
||||||
|
head_dim: 128,
|
||||||
|
hidden_act: candle_nn::Activation::Silu,
|
||||||
|
hidden_size: 2560,
|
||||||
|
intermediate_size: 7040,
|
||||||
|
max_position_embeddings: 4096,
|
||||||
|
mlp_bias: false,
|
||||||
|
num_attention_heads: 20,
|
||||||
|
num_hidden_layers: 24,
|
||||||
|
num_key_value_heads: 20,
|
||||||
|
rms_norm_eps: 1e-08,
|
||||||
|
rope_theta: 100000.0,
|
||||||
|
tie_word_embeddings: false,
|
||||||
|
vocab_size: 48000,
|
||||||
|
use_flash_attn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let rope_theta = cfg.rope_theta as f32;
|
||||||
|
let dim = cfg.head_dim;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope_theta.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(DType::F32)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?.to_dtype(dtype)?,
|
||||||
|
cos: freqs.cos()?.to_dtype(dtype)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope_i(q, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope_i(k, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: candle_nn::Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let intermediate_sz = cfg.intermediate_size;
|
||||||
|
let bias = cfg.mlp_bias;
|
||||||
|
let gate_proj = linear(hidden_sz, intermediate_sz, bias, vb.pp("gate_proj"))?;
|
||||||
|
let up_proj = linear(hidden_sz, intermediate_sz, bias, vb.pp("up_proj"))?;
|
||||||
|
let down_proj = linear(intermediate_sz, hidden_sz, bias, vb.pp("down_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = xs.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "flash-attn")]
|
||||||
|
fn flash_attn(
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
v: &Tensor,
|
||||||
|
softmax_scale: f32,
|
||||||
|
causal: bool,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "flash-attn"))]
|
||||||
|
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
|
||||||
|
unimplemented!("compile with '--features flash-attn'")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
use_flash_attn: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let bias = cfg.attention_bias;
|
||||||
|
let q_proj = linear(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache: None,
|
||||||
|
use_flash_attn: cfg.use_flash_attn,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, q_len, _) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
|
||||||
|
let (query_states, key_states) =
|
||||||
|
self.rotary_emb
|
||||||
|
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
|
||||||
|
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;
|
||||||
|
|
||||||
|
let attn_output = if self.use_flash_attn {
|
||||||
|
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
|
||||||
|
let q = query_states.transpose(1, 2)?;
|
||||||
|
let k = key_states.transpose(1, 2)?;
|
||||||
|
let v = value_states.transpose(1, 2)?;
|
||||||
|
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
|
||||||
|
flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
|
||||||
|
} else {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&value_states)?
|
||||||
|
};
|
||||||
|
attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, q_len, self.num_heads * self.head_dim))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
mlp: MLP,
|
||||||
|
input_layernorm: RmsNorm,
|
||||||
|
post_attention_layernorm: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
let input_layernorm =
|
||||||
|
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||||
|
let post_attention_layernorm = RmsNorm::new(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
input_layernorm,
|
||||||
|
post_attention_layernorm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.input_layernorm.forward(xs)?;
|
||||||
|
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
|
||||||
|
residual + xs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
lm_head: Linear,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("model");
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb_m.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::from_weights(embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear(cfg.hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
lm_head,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
&self,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((1, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_tokens(&self) -> &candle_nn::Embedding {
|
||||||
|
&self.embed_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (_b_size, seq_len) = input_ids.dims2()?;
|
||||||
|
let attention_mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mask = self.prepare_decoder_attention_mask(seq_len, seqlen_offset)?;
|
||||||
|
Some(mask)
|
||||||
|
};
|
||||||
|
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||||
|
}
|
||||||
|
xs.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.norm)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -43,6 +43,7 @@ pub mod gemma;
|
|||||||
pub mod gemma2;
|
pub mod gemma2;
|
||||||
pub mod glm4;
|
pub mod glm4;
|
||||||
pub mod granite;
|
pub mod granite;
|
||||||
|
pub mod helium;
|
||||||
pub mod hiera;
|
pub mod hiera;
|
||||||
pub mod jina_bert;
|
pub mod jina_bert;
|
||||||
pub mod llama;
|
pub mod llama;
|
||||||
|
Reference in New Issue
Block a user