mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Tensor -> QTensor conversion (#496)
* Sketch some qmatmul test. * Add the quantization function. * More testing. * Make the test smaller and faster. * Add some shape checking.
This commit is contained in:
@ -125,7 +125,7 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
let raw_data_ptr = raw_data.as_ptr();
|
||||
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
Ok(super::QTensor::new(data.to_vec(), dims))
|
||||
super::QTensor::new(data.to_vec(), dims)
|
||||
}
|
||||
|
||||
/// Creates a [Tensor] from a raw GGML tensor.
|
||||
|
@ -117,15 +117,52 @@ impl std::fmt::Debug for QTensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
|
||||
let dims = shape.dims();
|
||||
if dims.is_empty() {
|
||||
crate::bail!("scalar tensor cannot be quantized {shape:?}")
|
||||
}
|
||||
if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
|
||||
crate::bail!(
|
||||
"quantized tensor must have their last dim divisible by block size {shape:?} {}",
|
||||
T::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl QTensor {
|
||||
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
|
||||
data: Vec<T>,
|
||||
shape: S,
|
||||
) -> Self {
|
||||
Self {
|
||||
) -> Result<Self> {
|
||||
let shape = shape.into();
|
||||
check_shape::<T>(&shape)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape: shape.into(),
|
||||
shape,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
|
||||
let shape = src.shape();
|
||||
check_shape::<T>(shape)?;
|
||||
let src = src
|
||||
.to_dtype(crate::DType::F32)?
|
||||
.flatten_all()?
|
||||
.to_vec1::<f32>()?;
|
||||
if src.len() % T::BLCK_SIZE != 0 {
|
||||
crate::bail!(
|
||||
"tensor size ({shape:?}) is not divisible by block size {}",
|
||||
T::BLCK_SIZE
|
||||
)
|
||||
}
|
||||
let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
|
||||
T::from_float(&src, &mut data)?;
|
||||
Ok(Self {
|
||||
data: Box::new(data),
|
||||
shape: shape.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
|
Reference in New Issue
Block a user