diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 0e2e8093..171415cf 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -15,6 +15,7 @@ candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" } candle-datasets = { path = "../candle-datasets", version = "0.2.3" } candle-nn = { path = "../candle-nn", version = "0.2.3" } candle-transformers = { path = "../candle-transformers", version = "0.2.3" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true } cudarc = { workspace = true, optional = true } half = { workspace = true, optional = true } image = { workspace = true } @@ -50,7 +51,7 @@ default = [] accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"] cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"] cudnn = ["candle/cudnn"] -flash-attn = ["cuda", "candle-transformers/flash-attn"] +flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"] mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"] nccl = ["cuda", "cudarc/nccl", "dep:half"] diff --git a/candle-examples/examples/llama_multiprocess/main.rs b/candle-examples/examples/llama_multiprocess/main.rs index 8a13ce6c..87f91e2c 100644 --- a/candle-examples/examples/llama_multiprocess/main.rs +++ b/candle-examples/examples/llama_multiprocess/main.rs @@ -205,16 +205,9 @@ fn main() -> Result<()> { let cache = model::Cache::new(dtype, &config, &device)?; println!("building the model"); - let handles = filenames - .iter() - .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? })) - .collect::>>()?; - let tensors: Vec<_> = handles - .iter() - .map(|h| Ok(h.deserialize()?)) - .collect::>>()?; - - let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device); + let vb = unsafe { + candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)? + }; let llama = Llama::load(vb, &cache, &config, comm)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 220bae1b..27cbb636 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -456,27 +456,24 @@ impl<'a> VarBuilder<'a> { } } -pub struct ShardedSafeTensors<'a>(SafeTensorWithRouting<'a>); -pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors<'a>>; +pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors); +pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>; -impl<'a> ShardedSafeTensors<'a> { - pub fn var_builder( - safetensors: Vec>, +impl ShardedSafeTensors { + /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors + /// files and make them usable in a sharded way. + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn var_builder>( + paths: &[P], dtype: DType, dev: &Device, - ) -> ShardedVarBuilder<'a> { - let mut routing = HashMap::new(); - for (index, sf) in safetensors.iter().enumerate() { - for k in sf.names() { - routing.insert(k.to_string(), index); - } - } - let tensors = SafeTensorWithRouting { - routing, - safetensors, - }; + ) -> Result> { + let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?; let backend = ShardedSafeTensors(tensors); - VarBuilderArgs::new_with_args(backend, dtype, dev) + Ok(VarBuilderArgs::new_with_args(backend, dtype, dev)) } } @@ -508,7 +505,7 @@ impl Default for Shard { /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))` /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))` /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))` -impl<'a> Backend for ShardedSafeTensors<'a> { +impl Backend for ShardedSafeTensors { type Hints = Shard; fn get( @@ -524,18 +521,7 @@ impl<'a> Backend for ShardedSafeTensors<'a> { rank, world_size, } = h; - let SafeTensorWithRouting { - routing, - safetensors, - } = &self.0; - let index = routing.get(path).ok_or_else(|| { - Error::CannotFindTensor { - path: path.to_string(), - } - .bt() - })?; - - let view = safetensors[*index].tensor(path)?; + let view = self.0.get(path)?; let view_dtype = view.dtype(); let mut shape = view.shape().to_vec(); let size = shape[dim]; @@ -578,6 +564,6 @@ impl<'a> Backend for ShardedSafeTensors<'a> { } fn contains_tensor(&self, name: &str) -> bool { - self.0.routing.contains_key(name) + self.0.get(name).is_ok() } }