mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Merge pull request #88 from LaurentMazare/fix_unsafe_loads
Fixing unsafe slow load (memcpy).
This commit is contained in:
@ -71,7 +71,9 @@ fn convert_<T: WithDType>(view: st::TensorView<'_>, device: &Device) -> Result<T
|
|||||||
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
|
let data: &[T] = unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, elem_count) };
|
||||||
Tensor::from_slice(data, view.shape(), device)
|
Tensor::from_slice(data, view.shape(), device)
|
||||||
} else {
|
} else {
|
||||||
let mut c = Vec::with_capacity(elem_count);
|
// XXX: We need to specify `T` here, otherwise the compiler will infer u8 because of the following cast
|
||||||
|
// Making this vector too small to fit a full f16/f32/f64 weights, resulting in out-of-bounds access
|
||||||
|
let mut c: Vec<T> = Vec::with_capacity(elem_count);
|
||||||
// SAFETY: We just created c, so the allocated memory is necessarily
|
// SAFETY: We just created c, so the allocated memory is necessarily
|
||||||
// contiguous and non overlapping with the view's data.
|
// contiguous and non overlapping with the view's data.
|
||||||
// We're downgrading the `c` pointer from T to u8, which removes alignment
|
// We're downgrading the `c` pointer from T to u8, which removes alignment
|
||||||
|
Reference in New Issue
Block a user