From 81cec86e758390e5f025a6e93888673b003fb4c8 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 3 Jul 2023 11:55:54 +0200 Subject: [PATCH] Adding a bit more docs around safety. --- candle-core/examples/llama/weights.rs | 2 +- candle-core/src/safetensors.rs | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs index cc3fccd4..c3364cef 100644 --- a/candle-core/examples/llama/weights.rs +++ b/candle-core/examples/llama/weights.rs @@ -105,7 +105,7 @@ impl Llama { ) -> Result { let handles: Vec<_> = filenames .iter() - .map(candle::safetensors::MmapedFile::new) + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) .collect::>>()?; let tensors: Vec<_> = handles .iter() diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index b80a756a..99e11c60 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -12,6 +12,10 @@ fn convert_(view: st::TensorView<'_>, device: &Device) -> Result(st::SafeTensors<'a>); pub struct MmapedFile(memmap2::Mmap); impl MmapedFile { - pub fn new>(p: P) -> Result { + /// Creates a wrapper around a memory mapped file from which you can retrieve + /// tensors using [`MmapedFile::deserialize`] + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn new>(p: P) -> Result { let file = std::fs::File::open(p)?; - let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; + let mmap = memmap2::MmapOptions::new().map(&file)?; Ok(Self(mmap)) }