mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Quantized support for f16 and f32 (#457)
* Add f32 as a quantized type. * Add f16 as a quantized type too.
This commit is contained in:
@ -726,3 +726,77 @@ pub fn matmul<T: GgmlType>(
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl GgmlType for f32 {
|
||||
const DTYPE: GgmlDType = GgmlDType::F32;
|
||||
const BLCK_SIZE: usize = 1;
|
||||
type VecDotType = f32;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
if ys.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", ys.len())
|
||||
}
|
||||
let mut res = 0f32;
|
||||
unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
ys.copy_from_slice(xs);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
ys.copy_from_slice(xs);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for f16 {
|
||||
const DTYPE: GgmlDType = GgmlDType::F16;
|
||||
const BLCK_SIZE: usize = 1;
|
||||
type VecDotType = f16;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
if ys.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", ys.len())
|
||||
}
|
||||
let mut res = 0f32;
|
||||
unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
// TODO: vectorize
|
||||
for (x, y) in xs.iter().zip(ys.iter_mut()) {
|
||||
*y = f16::from_f32(*x)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
// TODO: vectorize
|
||||
for (x, y) in xs.iter().zip(ys.iter_mut()) {
|
||||
*y = x.to_f32()
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user