mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Cuda acceleration for quantized model. (#1754)
* Boilerplate for the quantized cuda support.
* More basic cuda support.
* More cuda quantization (quantize on cpu for now).
* Add the dequantization bit.
* Start adding some dedicated cuda kernels from llama.cpp.
* Move the kernel code.
* Start interfacing with the kernel.
* Tweak the kernel launch params.
* Bugfix for quantized metal.
* Fix some clippy lints.
* Tweak the launch parameters.
* Tweak cuda basics to perform a quantized matmul.
* Perform the dequantization on the cpu + use cublas for matmul.
* Add the dequantization kernel.
* Test the qmatmul.
* More kernels.
* Matmul-vec kernel.
* Add a couple kernels.
* More dequantization kernels.
This commit is contained in:
@ -827,9 +827,9 @@ impl BackendStorage for MetalStorage {
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
&t.buffer,
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1264,7 +1264,7 @@ impl BackendStorage for MetalStorage {
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let src_shape = src_l.shape();
|
||||
@ -1636,7 +1636,7 @@ impl BackendDevice for MetalDevice {
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -1667,7 +1667,7 @@ impl BackendDevice for MetalDevice {
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
Reference in New Issue
Block a user