Early conversion for the llama weights.

This commit is contained in:
laurent
2023-06-30 16:42:53 +01:00
parent dbd7d5b3fd
commit 679b6987b6
2 changed files with 19 additions and 45 deletions

View File

@ -18,7 +18,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
// was correctly aligned.
let data: &[f16] =
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const f16, v.len() / 2) };
Tensor::from_slice(data, view.shape(), device)
Tensor::from_slice(data, view.shape(), device)?.to_dtype(DTYPE)
} else {
let mut c = Vec::with_capacity(v.len() / 2);
let mut i = 0;
@ -26,7 +26,7 @@ fn convert(view: TensorView<'_>, device: &Device) -> Result<Tensor> {
c.push(f16::from_le_bytes([v[i], v[i + 1]]));
i += 2;
}
Tensor::from_slice(&c, view.shape(), device)
Tensor::from_slice(&c, view.shape(), device)?.to_dtype(DTYPE)
}
}
dt => todo!("Unhandled dtype {dt:?}"),