Self-contained safetensors for the multiprocess llama example. (#950)

--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -15,6 +15,7 @@ candle = { path = "../candle-core", version = "0.2.3", package = "candle-core" }
 candle-datasets = { path = "../candle-datasets", version = "0.2.3" }
 candle-nn = { path = "../candle-nn", version = "0.2.3" }
 candle-transformers = { path = "../candle-transformers", version = "0.2.3" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.3", optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
@@ -50,7 +51,7 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
 cudnn = ["candle/cudnn"]
-flash-attn = ["cuda", "candle-transformers/flash-attn"]
+flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]

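The "dep:candle-flash-attn" entry ties the optional crate to the flash-attn cargo feature, so the kernels are only compiled when that feature is requested. A minimal sketch of how code typically consumes such a gate (the function below is hypothetical, not part of this commit):

    // Hypothetical sketch: pick an attention implementation at compile
    // time based on whether the `flash-attn` cargo feature is enabled.
    #[cfg(feature = "flash-attn")]
    fn attention_backend() -> &'static str {
        "flash-attn (candle-flash-attn CUDA kernels)"
    }

    #[cfg(not(feature = "flash-attn"))]
    fn attention_backend() -> &'static str {
        "default attention (no extra dependency)"
    }

    fn main() {
        println!("using: {}", attention_backend());
    }

Built with `cargo run --features flash-attn`, the first branch is compiled in; without the flag, candle-flash-attn is not built at all.
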
--- a/candle-examples/examples/llama_multiprocess/main.rs
+++ b/candle-examples/examples/llama_multiprocess/main.rs
@@ -205,16 +205,9 @@ fn main() -> Result<()> {
     let cache = model::Cache::new(dtype, &config, &device)?;

     println!("building the model");
-    let handles = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
-        .collect::<Result<Vec<_>>>()?;
-    let tensors: Vec<_> = handles
-        .iter()
-        .map(|h| Ok(h.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;
-
-    let vb = candle_nn::var_builder::ShardedSafeTensors::var_builder(tensors, dtype, &device);
+    let vb = unsafe {
+        candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, dtype, &device)?
+    };
     let llama = Llama::load(vb, &cache, &config, comm)?;
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

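The whole mmap-then-deserialize pipeline collapses into a single call. A self-contained sketch of the new entry point as a caller sees it, assuming candle and candle-nn are in scope as in the example crate (file names are placeholders):

    use candle::{DType, Device};

    fn main() -> candle::Result<()> {
        let filenames = vec![
            std::path::PathBuf::from("model-00001-of-00002.safetensors"),
            std::path::PathBuf::from("model-00002-of-00002.safetensors"),
        ];
        let device = Device::Cpu;
        // SAFETY: inherited from memmap2::MmapOptions -- the mapped files
        // must not be truncated or mutated while the VarBuilder is alive.
        let vb = unsafe {
            candle_nn::var_builder::ShardedSafeTensors::var_builder(&filenames, DType::F16, &device)?
        };
        let _vb = vb; // handed off to Llama::load(...) in the real example
        Ok(())
    }
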
--- a/candle-nn/src/var_builder.rs
+++ b/candle-nn/src/var_builder.rs
@@ -456,27 +456,24 @@ impl<'a> VarBuilder<'a> {
     }
 }

-pub struct ShardedSafeTensors<'a>(SafeTensorWithRouting<'a>);
-pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors<'a>>;
+pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
+pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;

-impl<'a> ShardedSafeTensors<'a> {
-    pub fn var_builder(
-        safetensors: Vec<SafeTensors<'a>>,
+impl ShardedSafeTensors {
+    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
+    /// files and make them usable in a sharded way.
+    ///
+    /// # Safety
+    ///
+    /// The unsafe is inherited from [`memmap2::MmapOptions`].
+    pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
+        paths: &[P],
         dtype: DType,
         dev: &Device,
-    ) -> ShardedVarBuilder<'a> {
-        let mut routing = HashMap::new();
-        for (index, sf) in safetensors.iter().enumerate() {
-            for k in sf.names() {
-                routing.insert(k.to_string(), index);
-            }
-        }
-        let tensors = SafeTensorWithRouting {
-            routing,
-            safetensors,
-        };
+    ) -> Result<ShardedVarBuilder<'static>> {
+        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
         let backend = ShardedSafeTensors(tensors);
-        VarBuilderArgs::new_with_args(backend, dtype, dev)
+        Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
     }
 }

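What gets deleted here is the hand-rolled routing table: a map from tensor name to the index of the safetensors file containing it, which MmapedSafetensors::multi now maintains internally. For reference, the removed logic amounted to the following standalone sketch, where names_per_file stands in for the per-file name lists the old code read from each parsed SafeTensors:

    use std::collections::HashMap;

    // Sketch of the removed routing step: record, for every tensor name,
    // which file (by index) holds it, so lookups can pick the right mmap.
    fn build_routing(names_per_file: &[Vec<String>]) -> HashMap<String, usize> {
        let mut routing = HashMap::new();
        for (index, names) in names_per_file.iter().enumerate() {
            for name in names {
                routing.insert(name.clone(), index);
            }
        }
        routing
    }

    fn main() {
        let files = vec![vec!["a.weight".to_string()], vec!["b.weight".to_string()]];
        let routing = build_routing(&files);
        assert_eq!(routing["b.weight"], 1);
    }
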
@@ -508,7 +505,7 @@ impl Default for Shard {
 /// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
 /// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
 /// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
-impl<'a> Backend for ShardedSafeTensors<'a> {
+impl Backend for ShardedSafeTensors {
     type Hints = Shard;

     fn get(
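The doc comments above pin down the sharding contract for a 1024-wide tensor split across two ranks. The arithmetic behind those examples is plain equal-block partitioning along the chosen dimension; a sketch, assuming the size divides evenly by the world size as in the examples:

    // Half-open [start, end) range along the sharded dimension of length
    // `size` owned by `rank` out of `world_size` equal shards.
    fn shard_range(size: usize, rank: usize, world_size: usize) -> (usize, usize) {
        let block = size / world_size;
        (rank * block, (rank + 1) * block)
    }

    fn main() {
        assert_eq!(shard_range(1024, 0, 2), (0, 512)); // tensor.i((..512))
        assert_eq!(shard_range(1024, 1, 2), (512, 1024)); // tensor.i((512..))
    }
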
@@ -524,18 +521,7 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
             rank,
             world_size,
         } = h;
-        let SafeTensorWithRouting {
-            routing,
-            safetensors,
-        } = &self.0;
-        let index = routing.get(path).ok_or_else(|| {
-            Error::CannotFindTensor {
-                path: path.to_string(),
-            }
-            .bt()
-        })?;
-
-        let view = safetensors[*index].tensor(path)?;
+        let view = self.0.get(path)?;
         let view_dtype = view.dtype();
         let mut shape = view.shape().to_vec();
         let size = shape[dim];
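After self.0.get(path)? returns the full view, the code below this hunk slices out the current rank's block along the sharded dimension. Conceptually the operation is a narrow, as in this sketch (it assumes the size divides evenly by world_size; the real backend also converts the raw safetensors view into a tensor first):

    use candle::{Result, Tensor};

    // Keep only this rank's block of `full` along dimension `dim`.
    fn narrow_to_shard(full: &Tensor, dim: usize, rank: usize, world_size: usize) -> Result<Tensor> {
        let size = full.dims()[dim];
        let block = size / world_size;
        full.narrow(dim, rank * block, block)
    }

    fn main() -> Result<()> {
        let t = Tensor::zeros((1024, 8), candle::DType::F32, &candle::Device::Cpu)?;
        let shard = narrow_to_shard(&t, 0, 1, 2)?;
        assert_eq!(shard.dims(), &[512, 8]);
        Ok(())
    }
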
@@ -578,6 +564,6 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
     }

     fn contains_tensor(&self, name: &str) -> bool {
-        self.0.routing.contains_key(name)
+        self.0.get(name).is_ok()
     }
 }
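With the routing table gone, presence is probed by attempting the lookup itself, and any error is treated as "not present". A standalone sketch using the same calls as the diff above (the file name is a placeholder):

    fn main() -> candle::Result<()> {
        // SAFETY: inherited from memmap2 -- the file must stay unmodified
        // while it is mapped.
        let st = unsafe { candle::safetensors::MmapedSafetensors::multi(&["model.safetensors"])? };
        // Mirrors contains_tensor: Ok(..) means the name resolves to a tensor.
        let present = st.get("lm_head.weight").is_ok();
        println!("lm_head.weight present: {present}");
        Ok(())
    }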