Adding a bit more docs around safety.

This commit is contained in:
Nicolas Patry
2023-07-03 11:55:54 +02:00
parent 48089005f6
commit 81cec86e75
2 changed files with 13 additions and 3 deletions

View File

@ -105,7 +105,7 @@ impl Llama {
) -> Result<Self> { ) -> Result<Self> {
let handles: Vec<_> = filenames let handles: Vec<_> = filenames
.iter() .iter()
.map(candle::safetensors::MmapedFile::new) .map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles let tensors: Vec<_> = handles
.iter() .iter()

View File

@ -12,6 +12,10 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T
Tensor::from_slice(data, view.shape(), device) Tensor::from_slice(data, view.shape(), device)
} else { } else {
let mut c = Vec::with_capacity(elem_count); let mut c = Vec::with_capacity(elem_count);
// SAFETY: We just created `c`, so the allocated memory is necessarily
// contiguous and non-overlapping with the view's data.
// We're casting the `c` pointer from `*mut T` to `*mut u8`, which removes
// alignment constraints on the destination of the copy.
unsafe { unsafe {
std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len()); std::ptr::copy_nonoverlapping(v.as_ptr(), c.as_mut_ptr() as *mut u8, v.len());
c.set_len(elem_count) c.set_len(elem_count)
@ -42,9 +46,15 @@ pub struct SafeTensors<'a>(st::SafeTensors<'a>);
pub struct MmapedFile(memmap2::Mmap); pub struct MmapedFile(memmap2::Mmap);
impl MmapedFile { impl MmapedFile {
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> { /// Creates a wrapper around a memory mapped file from which you can retrieve
/// tensors using [`MmapedFile::deserialize`]
///
/// # Safety
///
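/// This function is unsafe for the same reason as [`memmap2::MmapOptions::map`],
/// which it calls: the caller must uphold that method's safety contract
/// (notably, the underlying file must not be modified while it is mapped).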
pub unsafe fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?; let file = std::fs::File::open(p)?;
let mmap = unsafe { memmap2::MmapOptions::new().map(&file)? }; let mmap = memmap2::MmapOptions::new().map(&file)?;
Ok(Self(mmap)) Ok(Self(mmap))
} }