Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Add support for Llama 3.1 (#2359)
* Add Llama 3.1 rope
* Clippy
* Format
* Clippy
* Add support for multiple eos tokens
* Untagged either
* Remove either dep and fix settings.json
* Make the max positional embeddings configurable
@@ -12,7 +12,7 @@ fn run_affine_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name:
     let m = 1024;
     let k = 1024;
 
-    let tensor = Tensor::zeros((b, m, k), dtype, &device).unwrap();
+    let tensor = Tensor::zeros((b, m, k), dtype, device).unwrap();
 
     let flops = b * m * k * dtype.size_in_bytes();
 
@@ -7,7 +7,7 @@ use criterion::{black_box, criterion_group, Criterion, Throughput};
 use std::time::Instant;
 
 fn run(matmul: &QMatMul, x: &Tensor) {
-    matmul.forward(&x).unwrap();
+    matmul.forward(x).unwrap();
 }
 
 fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
@@ -50,7 +50,7 @@ fn run_bench(c: &mut Criterion, device: &Device, dtype: GgmlDType) {
 fn criterion_benchmark(c: &mut Criterion) {
     let handler = BenchDeviceHandler::new().unwrap();
     for device in handler.devices {
-        for dtype in vec![
+        for dtype in [
             GgmlDType::F32,
             GgmlDType::F16,
             GgmlDType::Q4_0,
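The benchmark hunks above are Clippy cleanups rather than behavior changes. A minimal standalone sketch of the two lints involved (assuming candle-core as a dependency; the function name is illustrative):

use candle_core::{DType, Device, Result, Tensor};

// `device` is already a `&Device`; writing `&device` would pass a `&&Device`,
// which Clippy flags as `needless_borrow`.
fn zeros_on(device: &Device) -> Result<Tensor> {
    Tensor::zeros((2, 2), DType::F32, device)
}

fn main() -> Result<()> {
    let t = zeros_on(&Device::Cpu)?;
    // `useless_vec`: a fixed set of values can be iterated as an array by
    // value, avoiding the heap allocation that `vec![...]` would perform.
    for dt in [DType::F32, DType::F16] {
        println!("{:?} -> {:?}", dt, t.to_dtype(dt)?.dtype());
    }
    Ok(())
}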
@@ -12,7 +12,7 @@ fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &
     let m = 1024;
     let k = 1024;
 
-    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, &device)
+    let tensor = Tensor::arange(0.0f32, (b * m * k) as f32, device)
         .unwrap()
         .to_dtype(dtype)
         .unwrap()
@@ -25,9 +25,9 @@ const SIZE: usize = B * M * K;
 const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
 
 fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
-    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), &device).unwrap();
-    let on_true = Tensor::ones((B, M, K), dtype, &device).unwrap();
-    let on_false = Tensor::zeros((B, M, K), dtype, &device).unwrap();
+    let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
+    let on_true = Tensor::ones((B, M, K), dtype, device).unwrap();
+    let on_false = Tensor::zeros((B, M, K), dtype, device).unwrap();
 
     let elements = B * M * K;
     // E.g. 2 f32 tensors + 1 u8 tensor
@@ -35,7 +35,7 @@ serde = { workspace = true }
 serde_json = { workspace = true }
 symphonia = { version = "0.5.3", features = ["all"], optional = true }
 tokenizers = { workspace = true, features = ["onig"] }
-cpal= { version = "0.15.2", optional = true }
+cpal = { version = "0.15.2", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -32,7 +32,9 @@ enum Which {
     V1,
     V2,
     V3,
+    V31,
     V3Instruct,
+    V31Instruct,
     #[value(name = "solar-10.7b")]
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
@@ -133,6 +135,8 @@ fn main() -> Result<()> {
             Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
             Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
             Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
+            Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(),
+            Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(),
             Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
             Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
         });
@@ -146,7 +150,13 @@ fn main() -> Result<()> {
         let config = config.into_config(args.use_flash_attn);
 
         let filenames = match args.which {
-            Which::V1 | Which::V2 | Which::V3 | Which::V3Instruct | Which::Solar10_7B => {
+            Which::V1
+            | Which::V2
+            | Which::V3
+            | Which::V3Instruct
+            | Which::V31
+            | Which::V31Instruct
+            | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
             Which::TinyLlama1_1BChat => vec![api.get("model.safetensors")?],
@@ -157,9 +167,11 @@ fn main() -> Result<()> {
         (Llama::load(vb, &config)?, tokenizer_filename, cache, config)
     };
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-    let eos_token_id = config
-        .eos_token_id
-        .or_else(|| tokenizer.token_to_id(EOS_TOKEN));
+    let eos_token_id = config.eos_token_id.or_else(|| {
+        tokenizer
+            .token_to_id(EOS_TOKEN)
+            .map(model::LlamaEosToks::Single)
+    });
     let prompt = args.prompt.as_ref().map_or(DEFAULT_PROMPT, |p| p.as_str());
     let mut tokens = tokenizer
         .encode(prompt, true)
@@ -217,9 +229,15 @@ fn main() -> Result<()> {
         token_generated += 1;
         tokens.push(next_token);
 
-        if Some(next_token) == eos_token_id {
-            break;
-        }
+        match eos_token_id {
+            Some(model::LlamaEosToks::Single(eos_tok_id)) if next_token == eos_tok_id => {
+                break;
+            }
+            Some(model::LlamaEosToks::Multiple(ref eos_ids)) if eos_ids.contains(&next_token) => {
+                break;
+            }
+            _ => (),
+        }
         if let Some(t) = tokenizer.next_token(next_token)? {
             print!("{t}");
             std::io::stdout().flush()?;
@@ -272,7 +272,7 @@ impl Darknet {
         let mut prev_channels: usize = 3;
         for (index, block) in self.blocks.iter().enumerate() {
             let channels_and_bl = match block.block_type.as_str() {
-                "convolutional" => conv(vb.pp(&index.to_string()), index, prev_channels, block)?,
+                "convolutional" => conv(vb.pp(index.to_string()), index, prev_channels, block)?,
                 "upsample" => upsample(prev_channels)?,
                 "shortcut" => shortcut(index, prev_channels, block)?,
                 "route" => route(index, &blocks, block)?,
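The recurring `.pp(&i.to_string())` to `.pp(i.to_string())` change in this and the following hunks is the same lint applied across the model zoo: `VarBuilder::pp` is generic over `ToString`, so the owned `String` can be handed over directly instead of borrowed only to be stringified again. A minimal sketch, assuming candle-core and candle-nn as dependencies:

use candle_core::{DType, Device};
use candle_nn::{VarBuilder, VarMap};

fn main() {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    for i in 0..3 {
        // `pp` accepts any `ToString`, so passing the `String` by value is
        // fine; `&i.to_string()` only adds the borrow Clippy complains about.
        let block_vb = vb.pp(i.to_string()); // prefixes become "0", "1", "2"
        drop(block_vb);
    }
}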
@@ -264,6 +264,7 @@ impl SimpleBackend for VarMap {
     }
 }
 
+#[allow(dead_code)]
 pub struct SafeTensorWithRouting<'a> {
     routing: HashMap<String, usize>,
     safetensors: Vec<SafeTensors<'a>>,
@@ -288,7 +288,7 @@ impl BeitVisionTransformer {
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
@@ -249,7 +249,7 @@ impl ClipEncoder {
         let vs = vs.pp("layers");
         let mut layers: Vec<ClipEncoderLayer> = Vec::new();
         for index in 0..c.num_hidden_layers() {
-            let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
             layers.push(layer)
         }
         Ok(ClipEncoder { layers })
@@ -214,7 +214,7 @@ impl DinoVisionTransformer {
         let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
@@ -212,7 +212,7 @@ impl DinoVisionTransformer {
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
@@ -571,7 +571,7 @@ impl<'a> Layer<'a> {
     }
 
     fn next(&mut self) -> VarBuilder {
-        let vb = self.vb.pp(&self.cnt.to_string());
+        let vb = self.vb.pp(self.cnt.to_string());
         self.cnt += 1;
         vb
     }
@@ -255,14 +255,7 @@ impl EVA2VisionTransformer {
         let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?;
         let vb_b = vb.pp("blocks");
         let blocks = (0..depth)
-            .map(|i| {
-                Block::new(
-                    vb_b.pp(&i.to_string()),
-                    embed_dim,
-                    num_heads,
-                    &rot_pos_embed,
-                )
-            })
+            .map(|i| Block::new(vb_b.pp(i.to_string()), embed_dim, num_heads, &rot_pos_embed))
             .collect::<Result<Vec<_>>>()?;
         Ok(Self {
             patch_embed,
@@ -1,9 +1,33 @@
 use super::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{embedding, Embedding, Module, VarBuilder};
-use std::collections::HashMap;
+use std::{collections::HashMap, f32::consts::PI};
 
-pub const MAX_SEQ_LEN: usize = 4096;
+pub const DEFAULT_MAX_SEQ_LEN: usize = 4096;
 
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub enum Llama3RopeType {
+    #[serde(rename = "llama3")]
+    Llama3,
+    #[default]
+    #[serde(rename = "default")]
+    Default,
+}
+
+#[derive(Debug, Clone, serde::Deserialize, Default)]
+pub struct Llama3RopeConfig {
+    pub factor: f32,
+    pub low_freq_factor: f32,
+    pub high_freq_factor: f32,
+    pub original_max_position_embeddings: usize,
+    pub rope_type: Llama3RopeType,
+}
+#[derive(Debug, Clone, serde::Deserialize)]
+#[serde(untagged)]
+pub enum LlamaEosToks {
+    Single(u32),
+    Multiple(Vec<u32>),
+}
+
 #[derive(Debug, Clone, serde::Deserialize)]
 pub struct LlamaConfig {
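The `#[serde(untagged)]` attribute is what the commit message's "Untagged either" refers to: it lets `eos_token_id` deserialize from either a bare integer (Llama 3 and earlier configs) or a list of integers (Llama 3.1 configs) without pulling in the `either` crate. A minimal sketch, assuming serde and serde_json as dependencies (the token ids are illustrative):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum EosToks {
    Single(u32),
    Multiple(Vec<u32>),
}

fn main() {
    // An untagged enum tries each variant in order until one matches the JSON.
    let single: EosToks = serde_json::from_str("128001").unwrap();
    let multiple: EosToks = serde_json::from_str("[128001, 128008, 128009]").unwrap();
    println!("{single:?}");   // Single(128001)
    println!("{multiple:?}"); // Multiple([128001, 128008, 128009])
}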
@@ -17,7 +41,9 @@ pub struct LlamaConfig {
     #[serde(default = "default_rope")]
     pub rope_theta: f32,
     pub bos_token_id: Option<u32>,
-    pub eos_token_id: Option<u32>,
+    pub eos_token_id: Option<LlamaEosToks>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
+    pub max_position_embeddings: usize,
 }
 
 impl LlamaConfig {
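These new fields line up with the `rope_scaling` block that Llama 3.1 checkpoints ship in their `config.json`. A sketch of deserializing just that block, assuming serde and serde_json as dependencies and using a locally mirrored struct (the field values are illustrative of the 3.1 format, not read from any specific checkpoint):

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct RopeScaling {
    factor: f32,
    low_freq_factor: f32,
    high_freq_factor: f32,
    original_max_position_embeddings: usize,
    rope_type: String,
}

fn main() {
    let json = r#"{
        "factor": 8.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_max_position_embeddings": 8192,
        "rope_type": "llama3"
    }"#;
    let rs: RopeScaling = serde_json::from_str(json).unwrap();
    println!("{rs:?}");
}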
@@ -44,6 +70,8 @@ impl LlamaConfig {
             use_flash_attn,
             bos_token_id: self.bos_token_id,
             eos_token_id: self.eos_token_id,
+            rope_scaling: self.rope_scaling,
+            max_position_embeddings: self.max_position_embeddings,
         }
     }
 }
@@ -60,7 +88,9 @@ pub struct Config {
     pub rms_norm_eps: f64,
     pub rope_theta: f32,
     pub bos_token_id: Option<u32>,
-    pub eos_token_id: Option<u32>,
+    pub eos_token_id: Option<LlamaEosToks>,
+    pub rope_scaling: Option<Llama3RopeConfig>,
+    pub max_position_embeddings: usize,
 }
 
 impl Config {
@@ -77,6 +107,8 @@ impl Config {
             rope_theta: 10_000.0,
             bos_token_id: None,
             eos_token_id: None,
+            rope_scaling: None,
+            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
         }
     }
 
@@ -93,6 +125,8 @@ impl Config {
             rope_theta: 10_000.0,
             bos_token_id: None,
             eos_token_id: None,
+            rope_scaling: None,
+            max_position_embeddings: DEFAULT_MAX_SEQ_LEN,
         }
     }
 }
@@ -107,18 +141,54 @@ pub struct Cache {
     device: Device,
 }
 
+fn calculate_default_inv_freq(cfg: &Config) -> Vec<f32> {
+    let head_dim = cfg.hidden_size / cfg.num_attention_heads;
+    (0..head_dim)
+        .step_by(2)
+        .map(|i| 1f32 / cfg.rope_theta.powf(i as f32 / head_dim as f32))
+        .collect()
+}
+
 impl Cache {
     pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
         // precompute freqs_cis
-        let n_elem = config.hidden_size / config.num_attention_heads;
-        let theta: Vec<_> = (0..n_elem)
-            .step_by(2)
-            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
-            .collect();
-        let theta = Tensor::new(theta.as_slice(), device)?;
-        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+        let theta = match &config.rope_scaling {
+            None
+            | Some(Llama3RopeConfig {
+                rope_type: Llama3RopeType::Default,
+                ..
+            }) => calculate_default_inv_freq(config),
+            Some(rope_scaling) => {
+                let low_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.low_freq_factor;
+                let high_freq_wavelen = rope_scaling.original_max_position_embeddings as f32
+                    / rope_scaling.high_freq_factor;
+
+                calculate_default_inv_freq(config)
+                    .into_iter()
+                    .map(|freq| {
+                        let wavelen = 2. * PI / freq;
+                        if wavelen < high_freq_wavelen {
+                            freq
+                        } else if wavelen > low_freq_wavelen {
+                            freq / rope_scaling.factor
+                        } else {
+                            let smooth = (rope_scaling.original_max_position_embeddings as f32
+                                / wavelen
+                                - rope_scaling.low_freq_factor)
+                                / (rope_scaling.high_freq_factor - rope_scaling.low_freq_factor);
+                            (1. - smooth) * freq / rope_scaling.factor + smooth * freq
+                        }
+                    })
+                    .collect::<Vec<_>>()
+            }
+        };
+
+        let theta = Tensor::new(theta, device)?;
+
+        let idx_theta = Tensor::arange(0, config.max_position_embeddings as u32, device)?
             .to_dtype(DType::F32)?
-            .reshape((MAX_SEQ_LEN, 1))?
+            .reshape((config.max_position_embeddings, 1))?
             .matmul(&theta.reshape((1, theta.elem_count()))?)?;
         // This is different from the paper, see:
         // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
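The frequency rule implemented above can be stated compactly. Writing $s$ for `factor`, $r_{\text{low}}$ for `low_freq_factor`, $r_{\text{high}}$ for `high_freq_factor`, and $L$ for `original_max_position_embeddings`, each default inverse frequency $f$ has wavelength $\lambda = 2\pi/f$ and is rescaled as

f' =
\begin{cases}
  f & \lambda < L / r_{\text{high}} \\
  f / s & \lambda > L / r_{\text{low}} \\
  (1 - \alpha)\,\dfrac{f}{s} + \alpha f & \text{otherwise,}
\end{cases}
\qquad
\alpha = \frac{L/\lambda - r_{\text{low}}}{r_{\text{high}} - r_{\text{low}}}

High-frequency (short-wavelength) components are kept as-is, low-frequency ones are slowed down by the full factor, and the band in between is linearly interpolated; $\alpha$ corresponds to the `smooth` variable in the hunk.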
@@ -160,6 +230,7 @@ struct CausalSelfAttention {
     use_flash_attn: bool,
     span: tracing::Span,
     span_rot: tracing::Span,
+    max_position_embeddings: usize,
 }
 
 #[cfg(feature = "flash-attn")]
@@ -220,15 +291,23 @@ impl CausalSelfAttention {
             k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
             v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
             let k_seq_len = k.dims()[1];
-            if k_seq_len > MAX_SEQ_LEN {
+            if k_seq_len > self.max_position_embeddings {
                 k = k
-                    .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .narrow(
+                        D::Minus1,
+                        k_seq_len - self.max_position_embeddings,
+                        self.max_position_embeddings,
+                    )?
                     .contiguous()?
             }
             let v_seq_len = v.dims()[1];
-            if v_seq_len > 2 * MAX_SEQ_LEN {
+            if v_seq_len > 2 * self.max_position_embeddings {
                 v = v
-                    .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)?
+                    .narrow(
+                        D::Minus1,
+                        v_seq_len - self.max_position_embeddings,
+                        self.max_position_embeddings,
+                    )?
                     .contiguous()?
             }
         }
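The `narrow` calls above trim the KV cache so it never grows past the configured window (previously the hard-coded `MAX_SEQ_LEN`). A minimal sketch of the trimming idea on a 1-D tensor, assuming candle-core as a dependency:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let max_len = 4;
    let t = Tensor::arange(0f32, 8f32, &Device::Cpu)?; // [0., 1., ..., 7.]
    let seq_len = t.dims()[0];
    // `narrow(dim, start, len)` keeps `len` elements starting at `start`,
    // so this retains only the last `max_len` positions.
    let trimmed = t.narrow(0, seq_len - max_len, max_len)?;
    println!("{:?}", trimmed.to_vec1::<f32>()?); // [4.0, 5.0, 6.0, 7.0]
    Ok(())
}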
@@ -291,6 +370,7 @@ impl CausalSelfAttention {
             use_flash_attn: cfg.use_flash_attn,
             span,
             span_rot,
+            max_position_embeddings: cfg.max_position_embeddings,
         })
     }
 }
@@ -2,7 +2,7 @@ use std::collections::HashMap;
 
 use crate::models::{
     clip::{text_model::Activation, vision_model::ClipVisionConfig},
-    llama::Config,
+    llama::{Config, LlamaEosToks},
 };
 use serde::{Deserialize, Serialize};
 
@@ -73,8 +73,10 @@ impl LLaVAConfig {
             rms_norm_eps: self.rms_norm_eps as f64,
             rope_theta: self.rope_theta,
             bos_token_id: Some(self.bos_token_id as u32),
-            eos_token_id: Some(self.eos_token_id as u32),
+            eos_token_id: Some(LlamaEosToks::Single(self.eos_token_id as u32)),
             use_flash_attn: false,
+            rope_scaling: None, // Assume we don't have LLaVA for Llama 3.1
+            max_position_embeddings: self.max_position_embeddings,
         }
     }
 }
@@ -358,7 +358,7 @@ impl SpatialTransformer {
         let vs_tb = vs.pp("transformer_blocks");
         for index in 0..config.depth {
             let tb = BasicTransformerBlock::new(
-                vs_tb.pp(&index.to_string()),
+                vs_tb.pp(index.to_string()),
                 inner_dim,
                 n_heads,
                 d_head,
@@ -322,7 +322,7 @@ impl ClipEncoder {
         let vs = vs.pp("layers");
         let mut layers: Vec<ClipEncoderLayer> = Vec::new();
         for index in 0..c.num_hidden_layers {
-            let layer = ClipEncoderLayer::new(vs.pp(&index.to_string()), c)?;
+            let layer = ClipEncoderLayer::new(vs.pp(index.to_string()), c)?;
             layers.push(layer)
         }
         Ok(ClipEncoder { layers })
@@ -161,7 +161,7 @@ impl UNet2DConditionModel {
                 transformer_layers_per_block,
             };
             let block = CrossAttnDownBlock2D::new(
-                vs_db.pp(&i.to_string()),
+                vs_db.pp(i.to_string()),
                 in_channels,
                 out_channels,
                 Some(time_embed_dim),
@@ -171,7 +171,7 @@ impl UNet2DConditionModel {
             Ok(UNetDownBlock::CrossAttn(block))
         } else {
             let block = DownBlock2D::new(
-                vs_db.pp(&i.to_string()),
+                vs_db.pp(i.to_string()),
                 in_channels,
                 out_channels,
                 Some(time_embed_dim),
@@ -251,7 +251,7 @@ impl UNet2DConditionModel {
                 transformer_layers_per_block,
             };
             let block = CrossAttnUpBlock2D::new(
-                vs_ub.pp(&i.to_string()),
+                vs_ub.pp(i.to_string()),
                 in_channels,
                 prev_out_channels,
                 out_channels,
@@ -262,7 +262,7 @@ impl UNet2DConditionModel {
             Ok(UNetUpBlock::CrossAttn(block))
         } else {
             let block = UpBlock2D::new(
-                vs_ub.pp(&i.to_string()),
+                vs_ub.pp(i.to_string()),
                 in_channels,
                 prev_out_channels,
                 out_channels,
@@ -146,7 +146,7 @@ impl DownEncoderBlock2D {
             (0..(config.num_layers))
                 .map(|i| {
                     let in_channels = if i == 0 { in_channels } else { out_channels };
-                    ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+                    ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
                 })
                 .collect::<Result<Vec<_>>>()?
         };
@@ -235,7 +235,7 @@ impl UpDecoderBlock2D {
             (0..(config.num_layers))
                 .map(|i| {
                     let in_channels = if i == 0 { in_channels } else { out_channels };
-                    ResnetBlock2D::new(vs.pp(&i.to_string()), in_channels, conv_cfg)
+                    ResnetBlock2D::new(vs.pp(i.to_string()), in_channels, conv_cfg)
                 })
                 .collect::<Result<Vec<_>>>()?
         };
@@ -328,9 +328,9 @@ impl UNetMidBlock2D {
         };
         let mut attn_resnets = vec![];
         for index in 0..config.num_layers {
-            let attn = AttentionBlock::new(vs_attns.pp(&index.to_string()), in_channels, attn_cfg)?;
+            let attn = AttentionBlock::new(vs_attns.pp(index.to_string()), in_channels, attn_cfg)?;
             let resnet = ResnetBlock2D::new(
-                vs_resnets.pp(&(index + 1).to_string()),
+                vs_resnets.pp((index + 1).to_string()),
                 in_channels,
                 resnet_cfg,
             )?;
@@ -425,7 +425,7 @@ impl UNetMidBlock2DCrossAttn {
         let mut attn_resnets = vec![];
         for index in 0..config.num_layers {
             let attn = SpatialTransformer::new(
-                vs_attns.pp(&index.to_string()),
+                vs_attns.pp(index.to_string()),
                 in_channels,
                 n_heads,
                 in_channels / n_heads,
@@ -433,7 +433,7 @@ impl UNetMidBlock2DCrossAttn {
                 attn_cfg,
             )?;
             let resnet = ResnetBlock2D::new(
-                vs_resnets.pp(&(index + 1).to_string()),
+                vs_resnets.pp((index + 1).to_string()),
                 in_channels,
                 resnet_cfg,
             )?;
@@ -515,7 +515,7 @@ impl DownBlock2D {
         let resnets = (0..config.num_layers)
             .map(|i| {
                 let in_channels = if i == 0 { in_channels } else { out_channels };
-                ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+                ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
             })
             .collect::<Result<Vec<_>>>()?;
         let downsampler = if config.add_downsample {
@@ -619,7 +619,7 @@ impl CrossAttnDownBlock2D {
         let attentions = (0..config.downblock.num_layers)
             .map(|i| {
                 SpatialTransformer::new(
-                    vs_attn.pp(&i.to_string()),
+                    vs_attn.pp(i.to_string()),
                     out_channels,
                     n_heads,
                     out_channels / n_heads,
@@ -724,7 +724,7 @@ impl UpBlock2D {
                     out_channels
                };
                let in_channels = resnet_in_channels + res_skip_channels;
-                ResnetBlock2D::new(vs_resnets.pp(&i.to_string()), in_channels, resnet_cfg)
+                ResnetBlock2D::new(vs_resnets.pp(i.to_string()), in_channels, resnet_cfg)
             })
             .collect::<Result<Vec<_>>>()?;
         let upsampler = if config.add_upsample {
@@ -826,7 +826,7 @@ impl CrossAttnUpBlock2D {
         let attentions = (0..config.upblock.num_layers)
             .map(|i| {
                 SpatialTransformer::new(
-                    vs_attn.pp(&i.to_string()),
+                    vs_attn.pp(i.to_string()),
                     out_channels,
                     n_heads,
                     out_channels / n_heads,
@@ -80,7 +80,7 @@ impl Encoder {
                 ..Default::default()
             };
             let down_block = DownEncoderBlock2D::new(
-                vs_down_blocks.pp(&index.to_string()),
+                vs_down_blocks.pp(index.to_string()),
                 in_channels,
                 out_channels,
                 cfg,
@@ -222,7 +222,7 @@ impl Decoder {
                 ..Default::default()
             };
             let up_block = UpDecoderBlock2D::new(
-                vs_up_blocks.pp(&index.to_string()),
+                vs_up_blocks.pp(index.to_string()),
                 in_channels,
                 out_channels,
                 cfg,
@@ -601,7 +601,7 @@ impl T5Block {
             None
         };
         let ff_i = if cross_attn.is_some() { 2 } else { 1 };
-        let ff = T5LayerFF::load(vb.pp(&ff_i.to_string()), cfg)?;
+        let ff = T5LayerFF::load(vb.pp(ff_i.to_string()), cfg)?;
         Ok(Self {
             self_attn,
             cross_attn,