Self-contained safetensors for the multiprocess llama example. (#950)

This commit is contained in:
Laurent Mazare
2023-09-24 06:54:49 +01:00
committed by GitHub
parent 7edd755756
commit bcb0ed8f1c
3 changed files with 22 additions and 42 deletions

View File

@ -456,27 +456,24 @@ impl<'a> VarBuilder<'a> {
}
}
pub struct ShardedSafeTensors<'a>(SafeTensorWithRouting<'a>);
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors<'a>>;
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
impl<'a> ShardedSafeTensors<'a> {
pub fn var_builder(
safetensors: Vec<SafeTensors<'a>>,
impl ShardedSafeTensors {
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
/// files and makes them usable in a sharded way.
///
/// # Safety
///
/// The unsafe is inherited from [`memmap2::MmapOptions`].
pub unsafe fn var_builder<P: AsRef<std::path::Path>>(
paths: &[P],
dtype: DType,
dev: &Device,
) -> ShardedVarBuilder<'a> {
let mut routing = HashMap::new();
for (index, sf) in safetensors.iter().enumerate() {
for k in sf.names() {
routing.insert(k.to_string(), index);
}
}
let tensors = SafeTensorWithRouting {
routing,
safetensors,
};
) -> Result<ShardedVarBuilder<'static>> {
let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
let backend = ShardedSafeTensors(tensors);
VarBuilderArgs::new_with_args(backend, dtype, dev)
Ok(VarBuilderArgs::new_with_args(backend, dtype, dev))
}
}
@ -508,7 +505,7 @@ impl Default for Shard {
/// `get_sharded("tensor", 0, 0, 2)` means `tensor.i((..512))`
/// `get_sharded("tensor", 0, 1, 2)` means `tensor.i((512..))`
/// `get_sharded("tensor", 1, 0, 2)` means `tensor.i((.., ..512))`
impl<'a> Backend for ShardedSafeTensors<'a> {
impl Backend for ShardedSafeTensors {
type Hints = Shard;
fn get(
@ -524,18 +521,7 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
rank,
world_size,
} = h;
let SafeTensorWithRouting {
routing,
safetensors,
} = &self.0;
let index = routing.get(path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
}
.bt()
})?;
let view = safetensors[*index].tensor(path)?;
let view = self.0.get(path)?;
let view_dtype = view.dtype();
let mut shape = view.shape().to_vec();
let size = shape[dim];
@ -578,6 +564,6 @@ impl<'a> Backend for ShardedSafeTensors<'a> {
}
fn contains_tensor(&self, name: &str) -> bool {
self.0.routing.contains_key(name)
self.0.get(name).is_ok()
}
}