Cuda acceleration for quantized model. (#1754)

* Boilerplate for the quantized cuda support.

* More basic cuda support.

* More cuda quantization (quantize on cpu for now).

* Add the dequantization bit.

* Start adding some dedicated cuda kernels from llama.cpp.

* Move the kernel code.

* Start interfacing with the kernel.

* Tweak the kernel launch params.

* Bugfix for quantized metal.

* Fix some clippy lints.

* Tweak the launch parameters.

* Tweak cuda basics to perform a quantized matmul.

* Perform the dequantization on the cpu + use cublas for matmul.

* Add the dequantization kernel.

* Test the qmatmul.

* More kernels.

* Matmul-vec kernel.

* Add a couple kernels.

* More dequantization kernels.
This commit is contained in:
Laurent Mazare
2024-02-25 18:11:47 +01:00
committed by GitHub
parent 8d04f70f4d
commit 2f22afd80e
11 changed files with 1996 additions and 69 deletions

View File

@ -827,9 +827,9 @@ impl BackendStorage for MetalStorage {
layout.start_offset() * self.dtype.size_in_bytes(),
),
&t.buffer,
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
&f.buffer,
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
&buffer,
)
.map_err(MetalError::from)?;
@ -1264,7 +1264,7 @@ impl BackendStorage for MetalStorage {
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
blit.end_encoding();
} else {
let src_shape = src_l.shape();
@ -1636,7 +1636,7 @@ impl BackendDevice for MetalDevice {
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
@ -1667,7 +1667,7 @@ impl BackendDevice for MetalDevice {
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;