mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Early conversion for the llama weights.
This commit is contained in:
@@ -18,7 +18,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
// was correctly aligned.
|
||||
let data: &[f16] =
|
||||
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
|
||||
- Tensor::from_slice(data, view.shape(), device)
|
||||
+ Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
|
||||
} else {
|
||||
let mut c = Vec::with_capacity(v.len() / 2);
|
||||
let mut i = 0;
|
||||
@@ -26,7 +26,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
|
||||
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
|
||||
i += 2;
|
||||
}
|
||||
- Tensor::from_slice(&c, view.shape(), device)
|
||||
+ Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
|
||||
}
|
||||
}
|
||||
dt => todo!("Unhandled dtype {dt:?}"),
|
||||
|
Reference in New Issue
Block a user