mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Quantized support for f16 and f32 (#457)
* Add f32 as a quantized type. * Add f16 as a quantized type too.
This commit is contained in:
@ -69,9 +69,6 @@ impl Cpu<ARR> for CurrentCpu {
|
||||
for i in 0..ARR / 4 {
|
||||
x[4 * i] = vaddq_f32(x[4 * i], x[4 * i + 2]);
|
||||
}
|
||||
for i in 0..ARR / 8 {
|
||||
x[8 * i] = vaddq_f32(x[8 * i], x[8 * i + 4]);
|
||||
}
|
||||
*y = Self::reduce_one(x[0]);
|
||||
}
|
||||
}
|
||||
|
@ -726,3 +726,77 @@ pub fn matmul<T: GgmlType>(
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl GgmlType for f32 {
|
||||
const DTYPE: GgmlDType = GgmlDType::F32;
|
||||
const BLCK_SIZE: usize = 1;
|
||||
type VecDotType = f32;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
if ys.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", ys.len())
|
||||
}
|
||||
let mut res = 0f32;
|
||||
unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
ys.copy_from_slice(xs);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
ys.copy_from_slice(xs);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl GgmlType for f16 {
|
||||
const DTYPE: GgmlDType = GgmlDType::F16;
|
||||
const BLCK_SIZE: usize = 1;
|
||||
type VecDotType = f16;
|
||||
|
||||
fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
|
||||
if xs.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", xs.len())
|
||||
}
|
||||
if ys.len() < n {
|
||||
crate::bail!("size mismatch {} < {n}", ys.len())
|
||||
}
|
||||
let mut res = 0f32;
|
||||
unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
// TODO: vectorize
|
||||
for (x, y) in xs.iter().zip(ys.iter_mut()) {
|
||||
*y = f16::from_f32(*x)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
|
||||
if xs.len() != ys.len() {
|
||||
crate::bail!("size mismatch {} {}", xs.len(), ys.len());
|
||||
}
|
||||
// TODO: vectorize
|
||||
for (x, y) in xs.iter().zip(ys.iter_mut()) {
|
||||
*y = x.to_f32()
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ use candle::{DType, Device};
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// GGML file to load.
|
||||
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
|
||||
#[arg(long)]
|
||||
model: String,
|
||||
}
|
||||
|
Reference in New Issue
Block a user