Speedup ShardedSafeTensors to load Tensors with default hints (#1384)

* Speedup ShardedSafeTensors to load Tensors with default hints

* Tweaks.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
YiiSh
2023-12-14 22:08:56 +08:00
committed by GitHub
parent 7be982f6f7
commit e60f9b5dfc

View File

@ -535,12 +535,18 @@ impl Backend for ShardedSafeTensors {
fn get(
&self,
_target_shape: Shape, // The size is not checked for ShardedTensors
target_shape: Shape, // The size is only checked when the world size is 1.
path: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
if h.world_size == 1 {
// There is no sharding to be applied here so we use the default backend to speed
// things up.
return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
}
let Shard {
dim,
rank,