mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Tmp.
This commit is contained in:
@ -27,6 +27,9 @@ anyhow = "1"
|
|||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
tokenizers = { version = "0.13.3", default-features=false, features=["onig"] }
|
||||||
|
tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] }
|
||||||
|
candle-hub = { path = "../candle-hub" }
|
||||||
|
memmap2 = "0.7.1"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["cuda"]
|
default = ["cuda"]
|
||||||
|
@ -15,11 +15,14 @@ use anyhow::{Error as E, Result};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
|
use candle_hub::{Repo, api::Api, RepoType};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
mod var_store;
|
// mod var_store;
|
||||||
use var_store::VarBuilder;
|
// use var_store::VarBuilder;
|
||||||
|
|
||||||
|
mod weights;
|
||||||
|
|
||||||
const CONTEXT_SIZE: usize = 512;
|
const CONTEXT_SIZE: usize = 512;
|
||||||
const START_PROMPT: &str = r"
|
const START_PROMPT: &str = r"
|
||||||
@ -131,9 +134,8 @@ struct Embedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Embedding {
|
impl Embedding {
|
||||||
fn new(mut vb: VarBuilder, vocab_size: usize, n_embd: usize) -> Result<Self> {
|
fn new(embeddings: Tensor) -> Self {
|
||||||
let embeddings = vb.var("weight", (vocab_size, n_embd))?;
|
Self { embeddings }
|
||||||
Ok(Self { embeddings })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||||
@ -145,42 +147,27 @@ impl Embedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Linear {
|
struct Linear {
|
||||||
ws: Tensor,
|
weight: Tensor,
|
||||||
bs: Option<Tensor>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Linear {
|
impl Linear {
|
||||||
#[allow(dead_code)]
|
fn new(weight: Tensor) -> Self {
|
||||||
fn new(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
|
Self { weight }
|
||||||
let ws = vb.var("weight", (in_size, out_size))?;
|
|
||||||
let bs = vb.var("bias", out_size)?;
|
|
||||||
Ok(Self { ws, bs: Some(bs) })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn new_no_bias(mut vb: VarBuilder, in_size: usize, out_size: usize) -> Result<Self> {
|
|
||||||
let ws = vb.var("weight", (in_size, out_size))?;
|
|
||||||
Ok(Self { ws, bs: None })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = x.matmul(&self.ws.to_dtype(DType::F32)?)?;
|
let x = x.matmul(&self.weight.to_dtype(DType::F32)?.t()?)?;
|
||||||
let y = match &self.bs {
|
Ok(x)
|
||||||
None => x,
|
|
||||||
Some(bs) => x.broadcast_add(&bs.to_dtype(DType::F32)?)?,
|
|
||||||
};
|
|
||||||
Ok(y)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
scale: Tensor,
|
scale: Tensor,
|
||||||
size: usize,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RmsNorm {
|
impl RmsNorm {
|
||||||
fn new(mut vb: VarBuilder, size: usize) -> Result<Self> {
|
fn new(scale: Tensor) -> Self {
|
||||||
let scale = vb.var("scale", &[size])?;
|
Self { scale }
|
||||||
Ok(Self { scale, size })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
@ -188,10 +175,11 @@ impl RmsNorm {
|
|||||||
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
|
||||||
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
|
||||||
|
let size = self.scale.shape().r1()?;
|
||||||
let scale = self
|
let scale = self
|
||||||
.scale
|
.scale
|
||||||
.to_dtype(DType::F32)?
|
.to_dtype(DType::F32)?
|
||||||
.broadcast_as((seq_len, self.size))?;
|
.broadcast_as((seq_len, size))?;
|
||||||
Ok((scale * x_normed)?)
|
Ok((scale * x_normed)?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -207,17 +195,17 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Mlp {
|
impl Mlp {
|
||||||
fn new(vb: VarBuilder, n_embd: usize) -> Result<Self> {
|
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
|
||||||
let n_hidden = 8 * n_embd / 3;
|
// let n_hidden = 8 * n_embd / 3;
|
||||||
let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
|
// let n_hidden = (n_hidden - 1) / 256 * 256 + 256;
|
||||||
let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
|
// let c_fc1 = Linear::new_no_bias(&vb / "c_fc1", n_embd, n_hidden)?;
|
||||||
let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
|
// let c_fc2 = Linear::new_no_bias(&vb / "c_fc2", n_embd, n_hidden)?;
|
||||||
let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
|
// let c_proj = Linear::new_no_bias(&vb / "c_proj", n_hidden, n_embd)?;
|
||||||
Ok(Self {
|
Self {
|
||||||
c_fc1,
|
c_fc1,
|
||||||
c_fc2,
|
c_fc2,
|
||||||
c_proj,
|
c_proj,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
@ -256,7 +244,7 @@ impl Cache {
|
|||||||
let mask: Vec<_> = (0..t)
|
let mask: Vec<_> = (0..t)
|
||||||
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
.flat_map(|i| (0..t).map(move |j| u32::from(j > i)))
|
||||||
.collect();
|
.collect();
|
||||||
// Once lower_triangle is available, use the following:
|
// Once lower_triangle is available, use the followig:
|
||||||
//let mask = Tensor::new(1u32, &device)?
|
//let mask = Tensor::new(1u32, &device)?
|
||||||
// .broadcast_as(&[t, t])?
|
// .broadcast_as(&[t, t])?
|
||||||
// .lower_triangle()?
|
// .lower_triangle()?
|
||||||
@ -271,21 +259,21 @@ struct CausalSelfAttention {
|
|||||||
c_attn: Linear,
|
c_attn: Linear,
|
||||||
c_proj: Linear,
|
c_proj: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
n_embd: usize,
|
// n_embd: usize,
|
||||||
cache: Cache,
|
cache: Cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl CausalSelfAttention {
|
impl CausalSelfAttention {
|
||||||
fn new(vb: VarBuilder, n_head: usize, n_embd: usize, cache: &Cache) -> Result<Self> {
|
fn new(c_attn: Linear, c_proj: Linear, n_head: usize, cache: &Cache) -> Self {
|
||||||
let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
|
// let c_attn = Linear::new_no_bias(&vb / "c_attn", n_embd, 3 * n_embd)?;
|
||||||
let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
|
// let c_proj = Linear::new_no_bias(&vb / "c_proj", n_embd, n_embd)?;
|
||||||
Ok(Self {
|
Self {
|
||||||
c_attn,
|
c_attn,
|
||||||
c_proj,
|
c_proj,
|
||||||
n_head,
|
n_head,
|
||||||
n_embd,
|
// n_embd,
|
||||||
cache: cache.clone(),
|
cache: cache.clone(),
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
@ -313,7 +301,7 @@ impl CausalSelfAttention {
|
|||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
let (t, c) = x.shape().r2()?;
|
let (t, c) = x.shape().r2()?;
|
||||||
let qkv = self.c_attn.forward(x)?;
|
let qkv = self.c_attn.forward(x)?;
|
||||||
let n_embd = self.n_embd;
|
let n_embd = c;
|
||||||
let q = qkv.narrow(1, 0, n_embd)?;
|
let q = qkv.narrow(1, 0, n_embd)?;
|
||||||
let k = qkv.narrow(1, n_embd, n_embd)?;
|
let k = qkv.narrow(1, n_embd, n_embd)?;
|
||||||
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
|
let v = qkv.narrow(1, 2 * n_embd, n_embd)?;
|
||||||
@ -344,17 +332,13 @@ struct Block {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Block {
|
impl Block {
|
||||||
fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
|
||||||
let rms_1 = RmsNorm::new(&vb / "rms_1", config.n_embd)?;
|
Self {
|
||||||
let attn = CausalSelfAttention::new(&vb / "attn", config.n_head, config.n_embd, cache)?;
|
|
||||||
let rms_2 = RmsNorm::new(&vb / "rms_2", config.n_embd)?;
|
|
||||||
let mlp = Mlp::new(&vb / "mlp", config.n_embd)?;
|
|
||||||
Ok(Self {
|
|
||||||
rms_1,
|
rms_1,
|
||||||
attn,
|
attn,
|
||||||
rms_2,
|
rms_2,
|
||||||
mlp,
|
mlp,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
@ -372,23 +356,13 @@ struct Llama {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Llama {
|
impl Llama {
|
||||||
fn new(vb: VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
|
fn new(wte: Embedding, blocks: Vec<Block>, ln_f: RmsNorm, lm_head: Linear) -> Self {
|
||||||
let lm_head = Linear::new_no_bias(&vb / "lm_head", config.n_embd, config.vocab_size)?;
|
Self {
|
||||||
let wte = Embedding::new(
|
|
||||||
&vb / "transformer" / "wte",
|
|
||||||
config.vocab_size,
|
|
||||||
config.n_embd,
|
|
||||||
)?;
|
|
||||||
let blocks = (0..config.n_layer)
|
|
||||||
.map(|i| Block::new(&vb / "transformer" / "h" / i, cache, config))
|
|
||||||
.collect::<Result<Vec<_>>>()?;
|
|
||||||
let ln_f = RmsNorm::new(&vb / "transformer" / "ln_f", config.n_embd)?;
|
|
||||||
Ok(Self {
|
|
||||||
wte,
|
wte,
|
||||||
blocks,
|
blocks,
|
||||||
ln_f,
|
ln_f,
|
||||||
lm_head,
|
lm_head,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, freqs_cis: &Tensor) -> Result<Tensor> {
|
||||||
@ -443,7 +417,8 @@ struct Args {
|
|||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<()> {
|
||||||
use rand::prelude::*;
|
use rand::prelude::*;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
@ -453,32 +428,39 @@ fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
println!("loading tokenizer config");
|
let api = Api::new()?;
|
||||||
let tokenizer = Tokenizer::from_file("llama-tokenizer.json").map_err(E::msg)?;
|
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||||
|
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
let mut tokens = tokenizer
|
let mut tokens = tokenizer
|
||||||
.encode(START_PROMPT, true)
|
.encode(START_PROMPT, true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let weight_path = std::path::Path::new("llama.npz");
|
let mut filenames = vec![];
|
||||||
let weights = if weight_path.exists() {
|
for rfilename in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]{
|
||||||
println!("loading weights from {weight_path:?}");
|
let filename = api.get(&repo, rfilename).await?;
|
||||||
let start_load = std::time::Instant::now();
|
filenames.push(filename);
|
||||||
let tensors = Tensor::read_npz(weight_path)?;
|
}
|
||||||
println!("loaded weights in {:?}", start_load.elapsed());
|
// let weight_path = std::path::Path::new("llama.npz");
|
||||||
let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
// let weights = if weight_path.exists() {
|
||||||
Some(tensors)
|
// println!("loading weights from {weight_path:?}");
|
||||||
} else {
|
// let start_load = std::time::Instant::now();
|
||||||
println!("cannot find {weight_path:?}, using zero weights");
|
// let tensors = Tensor::read_npz(weight_path)?;
|
||||||
None
|
// println!("loaded weights in {:?}", start_load.elapsed());
|
||||||
};
|
// let tensors: std::collections::HashMap<String, Tensor> = tensors.into_iter().collect();
|
||||||
let vb = VarBuilder::new::<f32>(&device, weights);
|
// Some(tensors)
|
||||||
|
// } else {
|
||||||
|
// println!("cannot find {weight_path:?}, using zero weights");
|
||||||
|
// None
|
||||||
|
// };
|
||||||
|
// let vb = VarBuilder::new::<f32>(&device, weights);
|
||||||
|
|
||||||
println!("building the model");
|
println!("building the model");
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = Cache::new(&device);
|
let cache = Cache::new(&device);
|
||||||
let llama = Llama::new(vb, &cache, &config)?;
|
let llama = Llama::load(&device, &filenames, &cache, &config)?;
|
||||||
|
|
||||||
println!("pre-computing the positional embeddings");
|
println!("pre-computing the positional embeddings");
|
||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||||
|
144
candle-core/examples/llama/weights.rs
Normal file
144
candle-core/examples/llama/weights.rs
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
use memmap2::MmapOptions;
|
||||||
|
use candle::{Device, Result, Shape, Tensor, WithDType};
|
||||||
|
use std::fs::File;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use super::*;
|
||||||
|
use safetensors::{SafeTensors, tensor::{Dtype, TensorView}};
|
||||||
|
use half::f16;
|
||||||
|
|
||||||
|
fn convert<'a>(view: TensorView<'a>, device: &Device) -> Result<Tensor>{
|
||||||
|
match view.dtype(){
|
||||||
|
Dtype::F16 => {
|
||||||
|
let v = view.data();
|
||||||
|
if (v.as_ptr() as usize) % 2 == 0 {
|
||||||
|
// SAFETY This is safe because we just checked that this
|
||||||
|
// was correctly aligned.
|
||||||
|
let data: &[f16] =
|
||||||
|
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
|
||||||
|
Tensor::from_slice(data, view.shape(), device)
|
||||||
|
} else {
|
||||||
|
let mut c = Vec::with_capacity(v.len() / 2);
|
||||||
|
let mut i = 0;
|
||||||
|
while i < v.len() {
|
||||||
|
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
|
||||||
|
i += 2;
|
||||||
|
}
|
||||||
|
Tensor::from_slice(&c, view.shape(), device)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
dt => todo!("Unhandled dtype {dt:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct VarBuilder<'a>{
|
||||||
|
routing: HashMap<String, usize>,
|
||||||
|
safetensors: Vec<SafeTensors<'a>>,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl<'a> VarBuilder<'a>{
|
||||||
|
pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self{
|
||||||
|
let mut routing = HashMap::new();
|
||||||
|
for (index, sf) in safetensors.iter().enumerate(){
|
||||||
|
for k in sf.names(){
|
||||||
|
routing.insert(k.to_string(), index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Self{
|
||||||
|
safetensors,
|
||||||
|
device,
|
||||||
|
routing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, tensor_name: &str) -> Result<Tensor>{
|
||||||
|
// Unwrap or 0 just to let the proper error flow.
|
||||||
|
let index = self.routing.get(tensor_name).unwrap_or(&0);
|
||||||
|
let view = self.safetensors[*index].tensor(tensor_name).unwrap();
|
||||||
|
let tensor = convert(view, &self.device)?;
|
||||||
|
Ok(tensor)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Linear{
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
|
||||||
|
let weight = vb.get(&format!("{prefix}.weight"))?;
|
||||||
|
Ok(Self::new(weight))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self>{
|
||||||
|
let weights: Vec<_> = prefixes.iter().map(|p| vb.get(&format!("{p}.weight")).unwrap()).collect();
|
||||||
|
println!("shapes {:?}", weights.iter().map(|w| w.shape()).collect::<Vec<_>>());
|
||||||
|
let weight = Tensor::cat(&weights, 0)?;
|
||||||
|
Ok(Self::new(weight))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RmsNorm{
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
|
||||||
|
let scale = vb.get(&format!("{prefix}.weight"))?;
|
||||||
|
Ok(Self::new(scale))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CausalSelfAttention{
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
|
||||||
|
let c_attn = Linear::load_multi(&[&format!("{prefix}.q_proj"), &format!("{prefix}.k_proj"), &format!("{prefix}.v_proj")], vb)?;
|
||||||
|
let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
|
||||||
|
Ok(Self::new(c_attn,o_proj, config.n_head, cache))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Mlp{
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder, config: &Config) -> Result<Self>{
|
||||||
|
let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
|
||||||
|
let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
|
||||||
|
let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
|
||||||
|
Ok(Self::new(c_fc1, c_fc2, c_proj))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Block{
|
||||||
|
fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
|
||||||
|
let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
|
||||||
|
let mlp = Mlp::load(&format!("{prefix}.mlp"), vb, config)?;
|
||||||
|
let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
|
||||||
|
let post_attention_layernorm = RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
|
||||||
|
Ok(Self::new(input_layernorm, attn, post_attention_layernorm, mlp))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Llama{
|
||||||
|
pub fn load(device: &Device, filenames: &[PathBuf], cache: &Cache, config: &Config) -> Result<Self>{
|
||||||
|
let handles: Vec<_> = filenames.iter().map(|f| {
|
||||||
|
let file = File::open(f).unwrap();
|
||||||
|
let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
|
||||||
|
buffer
|
||||||
|
}).collect();
|
||||||
|
let tensors: Vec<_> = handles.iter().map(|h| {
|
||||||
|
let tensors = SafeTensors::deserialize(h).unwrap();
|
||||||
|
tensors
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
let vb = VarBuilder::new(tensors, device.clone());
|
||||||
|
|
||||||
|
let embedding = vb.get("model.embed_tokens.weight")?;
|
||||||
|
let wte = Embedding::new(embedding);
|
||||||
|
let lm_head = Linear::load("lm_head", &vb)?;
|
||||||
|
let norm = RmsNorm::load("model.norm", &vb)?;
|
||||||
|
let blocks: Vec<_> = (0..config.n_layer).map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap()).collect();
|
||||||
|
|
||||||
|
Ok(Self::new(
|
||||||
|
wte,
|
||||||
|
blocks,
|
||||||
|
norm,
|
||||||
|
lm_head
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -378,6 +378,12 @@ impl Api {
|
|||||||
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
|
let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
|
||||||
let filename = temp_filename();
|
let filename = temp_filename();
|
||||||
|
|
||||||
|
// Create the file and set everything properly
|
||||||
|
tokio::fs::File::create(&filename)
|
||||||
|
.await?
|
||||||
|
.set_len(length as u64)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let chunk_size = self.chunk_size;
|
let chunk_size = self.chunk_size;
|
||||||
for start in (0..length).step_by(chunk_size) {
|
for start in (0..length).step_by(chunk_size) {
|
||||||
let url = url.to_string();
|
let url = url.to_string();
|
||||||
@ -391,6 +397,7 @@ impl Api {
|
|||||||
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
|
let parallel_failures_semaphore = parallel_failures_semaphore.clone();
|
||||||
let progress = progressbar.clone();
|
let progress = progressbar.clone();
|
||||||
handles.push(tokio::spawn(async move {
|
handles.push(tokio::spawn(async move {
|
||||||
|
println!("Start {start:?} - {stop:?}");
|
||||||
let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
|
let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
if parallel_failures > 0 {
|
if parallel_failures > 0 {
|
||||||
@ -440,7 +447,6 @@ impl Api {
|
|||||||
let range = format!("bytes={start}-{stop}");
|
let range = format!("bytes={start}-{stop}");
|
||||||
let mut file = tokio::fs::OpenOptions::new()
|
let mut file = tokio::fs::OpenOptions::new()
|
||||||
.write(true)
|
.write(true)
|
||||||
.create(true)
|
|
||||||
.open(filename)
|
.open(filename)
|
||||||
.await?;
|
.await?;
|
||||||
file.seek(SeekFrom::Start(start as u64)).await?;
|
file.seek(SeekFrom::Start(start as u64)).await?;
|
||||||
|
@ -53,7 +53,11 @@ impl Cache {
|
|||||||
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
|
let commit_hash = std::fs::read_to_string(commit_path).ok()?;
|
||||||
let mut pointer_path = self.pointer_path(repo, &commit_hash);
|
let mut pointer_path = self.pointer_path(repo, &commit_hash);
|
||||||
pointer_path.push(filename);
|
pointer_path.push(filename);
|
||||||
|
if pointer_path.exists(){
|
||||||
Some(pointer_path)
|
Some(pointer_path)
|
||||||
|
}else{
|
||||||
|
None
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a reference in the cache directory that points branches to the correct
|
/// Creates a reference in the cache directory that points branches to the correct
|
||||||
@ -146,7 +150,12 @@ impl Repo {
|
|||||||
|
|
||||||
/// The normalized folder nameof the repo within the cache directory
|
/// The normalized folder nameof the repo within the cache directory
|
||||||
pub fn folder_name(&self) -> String {
|
pub fn folder_name(&self) -> String {
|
||||||
self.repo_id.replace('/', "--")
|
let prefix = match self.repo_type{
|
||||||
|
RepoType::Model => "models",
|
||||||
|
RepoType::Dataset => "datasets",
|
||||||
|
RepoType::Space => "spaces",
|
||||||
|
};
|
||||||
|
format!("{prefix}--{}", self.repo_id).replace('/', "--")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The revision
|
/// The revision
|
||||||
|
Reference in New Issue
Block a user