diff --git a/candle-core/examples/llama/weights.rs b/candle-core/examples/llama/weights.rs index cc3fccd4..c3364cef 100644 --- a/candle-core/examples/llama/weights.rs +++ b/candle-core/examples/llama/weights.rs @@ -105,7 +105,7 @@ impl Llama { ) -> Result { let handles: Vec<_> = filenames .iter() - .map(candle::safetensors::MmapedFile::new) + .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) }) .collect::>>()?; let tensors: Vec<_> = handles .iter() diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index b80a756a..99e11c60 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -12,6 +12,10 @@ fn convert_(view: st::TensorView<'_>, device: &Device) -> Result(st::SafeTensors<'a>); pub struct MmapedFile(memmap2::Mmap); impl MmapedFile { - pub fn new>(p: P) -> Result { + /// Creates a wrapper around a memory mapped file from which you can retrieve + /// tensors using [`MmapedFile::deserialize`] + /// + /// # Safety + /// + /// The unsafe is inherited from [`memmap2::MmapOptions`]. + pub unsafe fn new>(p: P) -> Result { let file = std::fs::File::open(p)?; - let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; + let mmap = memmap2::MmapOptions::new().map(&file)?; Ok(Self(mmap)) }