mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Adding support for codellama in examples.
Codellama requires bf16 for now (error to convert from bf16 to f16). Multiprocess demo not functional for it because flash-attn only supports f16 for now.
This commit is contained in:
@ -18,7 +18,7 @@ use clap::Parser;
|
|||||||
use candle::{DType, Tensor};
|
use candle::{DType, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
mod model;
|
mod model;
|
||||||
@ -59,9 +59,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
/// Use f32 computations rather than f16.
|
/// Use different dtype than f16
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_f32: bool,
|
dtype: Option<String>,
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
@ -70,6 +70,9 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
revision: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
v1: bool,
|
v1: bool,
|
||||||
|
|
||||||
@ -97,7 +100,13 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let dtype = if args.use_f32 { DType::F32 } else { DType::F16 };
|
let dtype = match args.dtype.as_deref() {
|
||||||
|
Some("f16") => DType::F16,
|
||||||
|
Some("bf16") => DType::BF16,
|
||||||
|
Some("f32") => DType::F32,
|
||||||
|
Some(dtype) => panic!("Unsupported dtype {dtype}"),
|
||||||
|
None => DType::F16,
|
||||||
|
};
|
||||||
let (llama, tokenizer_filename, cache) = match args.npy {
|
let (llama, tokenizer_filename, cache) = match args.npy {
|
||||||
Some(filename) => {
|
Some(filename) => {
|
||||||
let config = if args.v1 {
|
let config = if args.v1 {
|
||||||
@ -120,7 +129,8 @@ fn main() -> Result<()> {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
println!("loading the model weights from {model_id}");
|
println!("loading the model weights from {model_id}");
|
||||||
let api = api.model(model_id);
|
let revision = args.revision.unwrap_or("main".to_string());
|
||||||
|
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
|
|
||||||
let tokenizer_filename = match &args.local_weights {
|
let tokenizer_filename = match &args.local_weights {
|
||||||
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
Some(path) => (path.to_owned() + "tokenizer.json").into(),
|
||||||
|
@ -15,6 +15,12 @@ pub struct LlamaConfig {
|
|||||||
pub num_attention_heads: usize,
|
pub num_attention_heads: usize,
|
||||||
pub num_key_value_heads: Option<usize>,
|
pub num_key_value_heads: Option<usize>,
|
||||||
pub rms_norm_eps: f64,
|
pub rms_norm_eps: f64,
|
||||||
|
#[serde(default = "default_rope")]
|
||||||
|
pub rope_theta: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rope() -> f32 {
|
||||||
|
10_000.0
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlamaConfig {
|
impl LlamaConfig {
|
||||||
@ -27,6 +33,7 @@ impl LlamaConfig {
|
|||||||
num_attention_heads: self.num_attention_heads,
|
num_attention_heads: self.num_attention_heads,
|
||||||
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
num_key_value_heads: self.num_key_value_heads.unwrap_or(self.num_attention_heads),
|
||||||
rms_norm_eps: self.rms_norm_eps,
|
rms_norm_eps: self.rms_norm_eps,
|
||||||
|
rope_theta: self.rope_theta,
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -41,6 +48,7 @@ pub struct Config {
|
|||||||
pub num_key_value_heads: usize,
|
pub num_key_value_heads: usize,
|
||||||
pub use_flash_attn: bool,
|
pub use_flash_attn: bool,
|
||||||
pub rms_norm_eps: f64,
|
pub rms_norm_eps: f64,
|
||||||
|
pub rope_theta: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -54,6 +62,7 @@ impl Config {
|
|||||||
num_key_value_heads: 32,
|
num_key_value_heads: 32,
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
rms_norm_eps: 1e-6,
|
rms_norm_eps: 1e-6,
|
||||||
|
rope_theta: 10_000.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -67,6 +76,7 @@ impl Config {
|
|||||||
num_key_value_heads: 32,
|
num_key_value_heads: 32,
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
rms_norm_eps: 1e-5,
|
rms_norm_eps: 1e-5,
|
||||||
|
rope_theta: 10_000.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -103,7 +113,7 @@ impl Cache {
|
|||||||
let n_elem = config.hidden_size / config.num_attention_heads;
|
let n_elem = config.hidden_size / config.num_attention_heads;
|
||||||
let theta: Vec<_> = (0..n_elem)
|
let theta: Vec<_> = (0..n_elem)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||||
|
@ -17,7 +17,7 @@ use candle_nn::VarBuilder;
|
|||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use cudarc::driver::safe::CudaDevice;
|
use cudarc::driver::safe::CudaDevice;
|
||||||
use cudarc::nccl::safe::{Comm, Id};
|
use cudarc::nccl::safe::{Comm, Id};
|
||||||
use hf_hub::api::sync::Api;
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
@ -108,6 +108,12 @@ struct Args {
|
|||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
revision: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
dtype: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
@ -115,8 +121,13 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
|
||||||
let config = Config::config_7b();
|
let dtype = match args.dtype.as_deref() {
|
||||||
let dtype = DType::F16;
|
Some("f16") => DType::F16,
|
||||||
|
Some("bf16") => DType::BF16,
|
||||||
|
Some("f32") => DType::F32,
|
||||||
|
Some(dtype) => panic!("Unsupported dtype {dtype}"),
|
||||||
|
None => DType::F16,
|
||||||
|
};
|
||||||
|
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
|
|
||||||
@ -124,7 +135,10 @@ fn main() -> Result<()> {
|
|||||||
.model_id
|
.model_id
|
||||||
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
|
.unwrap_or_else(|| "meta-llama/Llama-2-7b-hf".to_string());
|
||||||
println!("loading the model weights from {model_id}");
|
println!("loading the model weights from {model_id}");
|
||||||
let api = api.model(model_id);
|
let revision = args.revision.unwrap_or("main".to_string());
|
||||||
|
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
||||||
|
let config_filename = api.get("config.json")?;
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
let mut filenames = vec![];
|
let mut filenames = vec![];
|
||||||
for rfilename in [
|
for rfilename in [
|
||||||
@ -185,7 +199,7 @@ fn main() -> Result<()> {
|
|||||||
println!("Rank {rank:?} spawned");
|
println!("Rank {rank:?} spawned");
|
||||||
|
|
||||||
let device = Device::new_cuda(i)?;
|
let device = Device::new_cuda(i)?;
|
||||||
let cache = model::Cache::new(&config, &device)?;
|
let cache = model::Cache::new(dtype, &config, &device)?;
|
||||||
|
|
||||||
println!("building the model");
|
println!("building the model");
|
||||||
let handles = filenames
|
let handles = filenames
|
||||||
|
@ -3,6 +3,7 @@ use candle::{CpuStorage, CustomOp1, DType, Device, IndexOp, Layout, Result, Shap
|
|||||||
use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
|
||||||
use cudarc::nccl::safe::{Comm, ReduceOp};
|
use cudarc::nccl::safe::{Comm, ReduceOp};
|
||||||
use half::f16;
|
use half::f16;
|
||||||
|
use serde::Deserialize;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
@ -110,26 +111,34 @@ impl TensorParallelRowLinear {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
pub hidden_size: usize,
|
pub hidden_size: usize,
|
||||||
pub intermediate_size: usize,
|
pub intermediate_size: usize,
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
pub n_layer: usize,
|
pub num_hidden_layers: usize,
|
||||||
pub n_head: usize,
|
pub num_attention_heads: usize,
|
||||||
pub n_embd: usize,
|
pub num_key_value_heads: usize,
|
||||||
pub n_key_value_head: usize,
|
pub rms_norm_eps: f64,
|
||||||
|
#[serde(default = "default_rope")]
|
||||||
|
pub rope_theta: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_rope() -> f32 {
|
||||||
|
10_000.0
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn config_7b() -> Self {
|
pub fn config_7b() -> Self {
|
||||||
Self {
|
Self {
|
||||||
hidden_size: 4096,
|
|
||||||
intermediate_size: 11008,
|
intermediate_size: 11008,
|
||||||
vocab_size: 32000,
|
vocab_size: 32000,
|
||||||
n_layer: 32,
|
num_hidden_layers: 32,
|
||||||
n_head: 32,
|
num_attention_heads: 32,
|
||||||
n_embd: 4096,
|
hidden_size: 4096,
|
||||||
n_key_value_head: 32,
|
num_key_value_heads: 32,
|
||||||
|
rms_norm_eps: 1e-5,
|
||||||
|
rope_theta: 10_000.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -143,12 +152,12 @@ pub struct Cache {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Cache {
|
impl Cache {
|
||||||
pub fn new(config: &Config, device: &Device) -> Result<Self> {
|
pub fn new(dtype: DType, config: &Config, device: &Device) -> Result<Self> {
|
||||||
// precompute freqs_cis
|
// precompute freqs_cis
|
||||||
let n_elem = config.n_embd / config.n_head;
|
let n_elem = config.hidden_size / config.num_attention_heads;
|
||||||
let theta: Vec<_> = (0..n_elem)
|
let theta: Vec<_> = (0..n_elem)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / 10000f32.powf(i as f32 / n_elem as f32))
|
.map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32))
|
||||||
.collect();
|
.collect();
|
||||||
let theta = Tensor::new(theta.as_slice(), device)?;
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||||
@ -158,10 +167,10 @@ impl Cache {
|
|||||||
// This is different from the paper, see:
|
// This is different from the paper, see:
|
||||||
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
// https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
|
||||||
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
let idx_theta = Tensor::cat(&[&idx_theta, &idx_theta], D::Minus1)?;
|
||||||
let cos = idx_theta.cos()?.to_dtype(DType::F16)?;
|
let cos = idx_theta.cos()?.to_dtype(dtype)?;
|
||||||
let sin = idx_theta.sin()?.to_dtype(DType::F16)?;
|
let sin = idx_theta.sin()?.to_dtype(dtype)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
kvs: Arc::new(Mutex::new(vec![None; config.n_layer])),
|
kvs: Arc::new(Mutex::new(vec![None; config.num_hidden_layers])),
|
||||||
cos,
|
cos,
|
||||||
sin,
|
sin,
|
||||||
})
|
})
|
||||||
@ -185,21 +194,21 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
|
|||||||
struct CausalSelfAttention {
|
struct CausalSelfAttention {
|
||||||
qkv_proj: TensorParallelColumnLinear,
|
qkv_proj: TensorParallelColumnLinear,
|
||||||
o_proj: TensorParallelRowLinear,
|
o_proj: TensorParallelRowLinear,
|
||||||
n_head: usize,
|
num_attention_heads: usize,
|
||||||
n_key_value_head: usize,
|
num_key_value_heads: usize,
|
||||||
head_dim: usize,
|
head_dim: usize,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (b_sz, _, seq_len, n_embd) = x.shape().dims4()?;
|
let (b_sz, _, seq_len, hidden_size) = x.shape().dims4()?;
|
||||||
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
|
||||||
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
|
||||||
let cos = cos.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
let cos = cos.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||||
let sin = sin.broadcast_as((b_sz, 1, seq_len, n_embd))?;
|
let sin = sin.broadcast_as((b_sz, 1, seq_len, hidden_size))?;
|
||||||
let x1 = x.narrow(D::Minus1, 0, n_embd / 2)?;
|
let x1 = x.narrow(D::Minus1, 0, hidden_size / 2)?;
|
||||||
let x2 = x.narrow(D::Minus1, n_embd / 2, n_embd / 2)?;
|
let x2 = x.narrow(D::Minus1, hidden_size / 2, hidden_size / 2)?;
|
||||||
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
let rotate_x = Tensor::cat(&[&x2.neg()?, &x1], D::Minus1)?;
|
||||||
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
let rope = (x.broadcast_mul(&cos)? + rotate_x.broadcast_mul(&sin)?)?;
|
||||||
Ok(rope)
|
Ok(rope)
|
||||||
@ -209,30 +218,31 @@ impl CausalSelfAttention {
|
|||||||
let (b_sz, seq_len, _) = x.shape().dims3()?;
|
let (b_sz, seq_len, _) = x.shape().dims3()?;
|
||||||
|
|
||||||
let qkv = self.qkv_proj.forward(x)?;
|
let qkv = self.qkv_proj.forward(x)?;
|
||||||
let n_embd = self.n_head * self.head_dim;
|
let hidden_size = self.num_attention_heads * self.head_dim;
|
||||||
|
|
||||||
let q = qkv.i((.., .., ..self.n_head * self.head_dim))?;
|
let q = qkv.i((.., .., ..self.num_attention_heads * self.head_dim))?;
|
||||||
let k = qkv.i((
|
let k = qkv.i((
|
||||||
..,
|
..,
|
||||||
..,
|
..,
|
||||||
self.n_head * self.head_dim
|
self.num_attention_heads * self.head_dim
|
||||||
..self.n_head * self.head_dim + self.n_key_value_head * self.head_dim,
|
..self.num_attention_heads * self.head_dim
|
||||||
|
+ self.num_key_value_heads * self.head_dim,
|
||||||
))?;
|
))?;
|
||||||
let v = qkv.i((
|
let v = qkv.i((
|
||||||
..,
|
..,
|
||||||
..,
|
..,
|
||||||
self.n_head * self.head_dim + self.n_key_value_head * self.head_dim..,
|
self.num_attention_heads * self.head_dim + self.num_key_value_heads * self.head_dim..,
|
||||||
))?;
|
))?;
|
||||||
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
|
// todo!("Q {:?} K {:?} V {:?} - x {:?}", q.shape(), k.shape(), v.shape(), x.shape());
|
||||||
|
|
||||||
let q = q
|
let q = q
|
||||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
.reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?;
|
||||||
let k = k
|
let k = k
|
||||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?;
|
||||||
let mut v = v
|
let mut v = v
|
||||||
.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?
|
.reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
let q = self.apply_rotary_emb(&q, index_pos)?;
|
let q = self.apply_rotary_emb(&q, index_pos)?;
|
||||||
@ -266,13 +276,13 @@ impl CausalSelfAttention {
|
|||||||
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
let y = candle_flash_attn::flash_attn(&q, &k, &v, softmax_scale, seq_len > 1)?
|
||||||
.transpose(1, 2)?;
|
.transpose(1, 2)?;
|
||||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?;
|
||||||
let y = self.o_proj.forward(&y)?;
|
let y = self.o_proj.forward(&y)?;
|
||||||
Ok(y)
|
Ok(y)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
|
||||||
let n_rep = self.n_head / self.n_key_value_head;
|
let n_rep = self.num_attention_heads / self.num_key_value_heads;
|
||||||
if n_rep == 1 {
|
if n_rep == 1 {
|
||||||
Ok(x)
|
Ok(x)
|
||||||
} else {
|
} else {
|
||||||
@ -295,9 +305,9 @@ impl CausalSelfAttention {
|
|||||||
Ok(Self {
|
Ok(Self {
|
||||||
qkv_proj,
|
qkv_proj,
|
||||||
o_proj,
|
o_proj,
|
||||||
n_head: cfg.n_head / comm.world_size(),
|
num_attention_heads: cfg.num_attention_heads / comm.world_size(),
|
||||||
n_key_value_head: cfg.n_key_value_head / comm.world_size(),
|
num_key_value_heads: cfg.num_key_value_heads / comm.world_size(),
|
||||||
head_dim: cfg.hidden_size / cfg.n_head,
|
head_dim: cfg.hidden_size / cfg.num_attention_heads,
|
||||||
cache: cache.clone(),
|
cache: cache.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -409,7 +419,7 @@ impl Llama {
|
|||||||
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
|
||||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
||||||
let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
|
let norm = rms_norm(cfg.hidden_size, 1e-5, vb.pp("model.norm"))?;
|
||||||
let blocks: Vec<_> = (0..cfg.n_layer)
|
let blocks: Vec<_> = (0..cfg.num_hidden_layers)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
Block::load(
|
Block::load(
|
||||||
vb.pp(&format!("model.layers.{i}")),
|
vb.pp(&format!("model.layers.{i}")),
|
||||||
|
Reference in New Issue
Block a user