diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs
index f1b2c3d4..066025b1 100644
--- a/candle-core/examples/llama/main.rs
+++ b/candle-core/examples/llama/main.rs
@@ -15,7 +15,7 @@ use anyhow::{Error as E, Result};
 use clap::Parser;

 use candle::{DType, Device, Tensor};
-use candle_hub::{Repo, api::Api, RepoType};
+use candle_hub::{api::Api, Repo, RepoType};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

@@ -439,7 +439,10 @@ async fn main() -> Result<()> {
         .to_vec();

     let mut filenames = vec![];
-    for rfilename in ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]{
+    for rfilename in [
+        "model-00001-of-00002.safetensors",
+        "model-00002-of-00002.safetensors",
+    ] {
         let filename = api.get(&repo, rfilename).await?;
         filenames.push(filename);
     }
diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs
index 2a350f65..73609d51 100644
--- a/candle-core/examples/llama/weights.rs
+++ b/candle-core/examples/llama/weights.rs
@@ -1,13 +1,16 @@
+use super::*;
+use candle::{Device, Result, Tensor};
+use half::f16;
 use memmap2::MmapOptions;
-use candle::{Device, Result, Shape, Tensor, WithDType};
+use safetensors::{
+    tensor::{Dtype, TensorView},
+    SafeTensors,
+};
 use std::fs::File;
 use std::path::PathBuf;
-use super::*;
-use safetensors::{SafeTensors, tensor::{Dtype, TensorView}};
-use half::f16;

-fn convert<'a>(view: TensorView<'a>, device: &Device) -> Result<Tensor>{
-    match view.dtype(){
+fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
+    match view.dtype() {
         Dtype::F16 => {
             let v = view.data();
             if (v.as_ptr() as usize) % 2 == 0 {
@@ -25,76 +28,82 @@ fn convert<'a>(view: TensorView<'a>, device: &Device) -> Result<Tensor>{
                 }
                 Tensor::from_slice(&c, view.shape(), device)
             }
-
         }
-        dt => todo!("Unhandled dtype {dt:?}")
+        dt => todo!("Unhandled dtype {dt:?}"),
     }
 }

-pub struct VarBuilder<'a>{
-    routing: HashMap<String, usize>,
-    safetensors: Vec<SafeTensors<'a>>,
-    device: Device,
-}
-
-impl<'a> VarBuilder<'a>{
-    pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self{
+pub struct VarBuilder<'a> {
+    routing: HashMap<String, usize>,
+    safetensors: Vec<SafeTensors<'a>>,
+    device: Device,
+}
+
+impl<'a> VarBuilder<'a> {
+    pub fn new(safetensors: Vec<SafeTensors<'a>>, device: Device) -> Self {
         let mut routing = HashMap::new();
-        for (index, sf) in safetensors.iter().enumerate(){
-            for k in sf.names(){
+        for (index, sf) in safetensors.iter().enumerate() {
+            for k in sf.names() {
                 routing.insert(k.to_string(), index);
             }
         }
-        Self{
+        Self {
             safetensors,
             device,
-            routing
+            routing,
         }
     }

-    pub fn get(&self, tensor_name: &str) -> Result<Tensor>{
+    pub fn get(&self, tensor_name: &str) -> Result<Tensor> {
         // Unwrap or 0 just to let the proper error flow.
         let index = self.routing.get(tensor_name).unwrap_or(&0);
         let view = self.safetensors[*index].tensor(tensor_name).unwrap();
         let tensor = convert(view, &self.device)?;
         Ok(tensor)
-
     }
 }

-impl Linear{
-    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
+impl Linear {
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(&format!("{prefix}.weight"))?;
         Ok(Self::new(weight))
     }

-    fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self>{
-        let weights: Vec<_> = prefixes.iter().map(|p| vb.get(&format!("{p}.weight")).unwrap()).collect();
-        println!("shapes {:?}", weights.iter().map(|w| w.shape()).collect::<Vec<_>>());
+    fn load_multi(prefixes: &[&str], vb: &VarBuilder) -> Result<Self> {
+        let weights: Vec<_> = prefixes
+            .iter()
+            .map(|p| vb.get(&format!("{p}.weight")).unwrap())
+            .collect();
         let weight = Tensor::cat(&weights, 0)?;
         Ok(Self::new(weight))
     }
 }

-impl RmsNorm{
-    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self>{
+impl RmsNorm {
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
         let scale = vb.get(&format!("{prefix}.weight"))?;
         Ok(Self::new(scale))
     }
 }

-impl CausalSelfAttention{
-    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
-        let c_attn = Linear::load_multi(&[&format!("{prefix}.q_proj"), &format!("{prefix}.k_proj"), &format!("{prefix}.v_proj")], vb)?;
+impl CausalSelfAttention {
+    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
+        let c_attn = Linear::load_multi(
+            &[
+                &format!("{prefix}.q_proj"),
+                &format!("{prefix}.k_proj"),
+                &format!("{prefix}.v_proj"),
+            ],
+            vb,
+        )?;
         let o_proj = Linear::load(&format!("{prefix}.o_proj"), vb)?;
-        Ok(Self::new(c_attn,o_proj, config.n_head, cache))
+        Ok(Self::new(c_attn, o_proj, config.n_head, cache))
     }
 }

-impl Mlp{
-    fn load(prefix: &str, vb: &VarBuilder, config: &Config) -> Result<Self>{
+impl Mlp {
+    fn load(prefix: &str, vb: &VarBuilder) -> Result<Self> {
         let c_fc1 = Linear::load(&format!("{prefix}.gate_proj"), vb)?;
         let c_fc2 = Linear::load(&format!("{prefix}.up_proj"), vb)?;
         let c_proj = Linear::load(&format!("{prefix}.down_proj"), vb)?;
@@ -102,27 +111,43 @@ impl Mlp{
     }
 }

-impl Block{
-    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self>{
+impl Block {
+    fn load(prefix: &str, vb: &VarBuilder, cache: &Cache, config: &Config) -> Result<Self> {
         let attn = CausalSelfAttention::load(&format!("{prefix}.self_attn"), vb, cache, config)?;
-        let mlp = Mlp::load(&format!("{prefix}.mlp"), vb, config)?;
+        let mlp = Mlp::load(&format!("{prefix}.mlp"), vb)?;
         let input_layernorm = RmsNorm::load(&format!("{prefix}.input_layernorm"), vb)?;
-        let post_attention_layernorm = RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
-        Ok(Self::new(input_layernorm, attn, post_attention_layernorm, mlp))
+        let post_attention_layernorm =
+            RmsNorm::load(&format!("{prefix}.post_attention_layernorm"), vb)?;
+        Ok(Self::new(
+            input_layernorm,
+            attn,
+            post_attention_layernorm,
+            mlp,
+        ))
     }
 }

-impl Llama{
-    pub fn load(device: &Device, filenames: &[PathBuf], cache: &Cache, config: &Config) -> Result<Self>{
-        let handles: Vec<_> = filenames.iter().map(|f| {
-            let file = File::open(f).unwrap();
-            let buffer = unsafe { MmapOptions::new().map(&file).unwrap() };
-            buffer
-        }).collect();
-        let tensors: Vec<_> = handles.iter().map(|h| {
-            let tensors = SafeTensors::deserialize(h).unwrap();
-            tensors
-        }).collect();
+impl Llama {
+    pub fn load(
+        device: &Device,
+        filenames: &[PathBuf],
+        cache: &Cache,
+        config: &Config,
+    ) -> Result<Self> {
+        let handles: Vec<_> = filenames
+            .iter()
+            .map(|f| {
+                let file = File::open(f).unwrap();
+                unsafe { MmapOptions::new().map(&file).unwrap() }
+            })
+            .collect();
+        let tensors: Vec<_> = handles
+            .iter()
+            .map(|h| {
+                let tensors = SafeTensors::deserialize(h).unwrap();
+                tensors
+            })
+            .collect();

         let vb = VarBuilder::new(tensors, device.clone());
@@ -130,15 +155,10 @@ impl Llama{
         let wte = Embedding::new(embedding);
         let lm_head = Linear::load("lm_head", &vb)?;
         let norm = RmsNorm::load("model.norm", &vb)?;
-        let blocks: Vec<_> = (0..config.n_layer).map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap()).collect();
+        let blocks: Vec<_> = (0..config.n_layer)
+            .map(|i| Block::load(&format!("model.layers.{i}"), &vb, cache, config).unwrap())
+            .collect();

-        Ok(Self::new(
-            wte,
-            blocks,
-            norm,
-            lm_head
-        ))
+        Ok(Self::new(wte, blocks, norm, lm_head))
     }
 }
-
-
diff --git a/candle-hub/src/api.rs b/candle-hub/src/api.rs
index bdf39cde..fa677f24 100644
--- a/candle-hub/src/api.rs
+++ b/candle-hub/src/api.rs
@@ -397,7 +397,6 @@ impl Api {
             let parallel_failures_semaphore = parallel_failures_semaphore.clone();
             let progress = progressbar.clone();
             handles.push(tokio::spawn(async move {
-                println!("Start {start:?} - {stop:?}");
                 let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await;
                 let mut i = 0;
                 if parallel_failures > 0 {
diff --git a/candle-hub/src/lib.rs b/candle-hub/src/lib.rs
index c5965d1f..5d186039 100644
--- a/candle-hub/src/lib.rs
+++ b/candle-hub/src/lib.rs
@@ -53,9 +53,9 @@ impl Cache {
         let commit_hash = std::fs::read_to_string(commit_path).ok()?;
         let mut pointer_path = self.pointer_path(repo, &commit_hash);
         pointer_path.push(filename);
-        if pointer_path.exists(){
+        if pointer_path.exists() {
             Some(pointer_path)
-        }else{
+        } else {
             None
         }
     }
@@ -150,7 +150,7 @@ impl Repo {
     /// The normalized folder nameof the repo within the cache directory
     pub fn folder_name(&self) -> String {
-        let prefix = match self.repo_type{
+        let prefix = match self.repo_type {
             RepoType::Model => "models",
             RepoType::Dataset => "datasets",
             RepoType::Space => "spaces",
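
For reference, a minimal usage sketch of the VarBuilder API that this patch reformats; it is not part of the patch itself. It assumes `VarBuilder` is in scope from candle-core/examples/llama/weights.rs, and the shard path and tensor name below are placeholders. Errors are funneled through `anyhow`, as in the example's main.rs.

    use anyhow::Result;
    use candle::Device;
    use memmap2::MmapOptions;
    use safetensors::SafeTensors;
    use std::fs::File;

    fn main() -> Result<()> {
        // Map one shard read-only. SAFETY: the file must not be mutated
        // while the mapping is alive.
        let file = File::open("model-00001-of-00002.safetensors")?;
        let buffer = unsafe { MmapOptions::new().map(&file)? };
        // Deserialization borrows from the mmap, so `buffer` must outlive it.
        let tensors = SafeTensors::deserialize(&buffer)?;
        // Route tensor names to their shard (a single one here) and fetch a
        // weight by its fully qualified name.
        let vb = VarBuilder::new(vec![tensors], Device::Cpu);
        let norm = vb.get("model.norm.weight")?;
        println!("model.norm.weight: {:?}", norm.shape());
        Ok(())
    }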