Compare commits


11 Commits

Author SHA1 Message Date
cdbdb4af9c Update yew-agent requirement from 0.2.0 to 0.3.0
Updates the requirements on [yew-agent](https://github.com/yewstack/yew) to permit the latest version.
- [Release notes](https://github.com/yewstack/yew/releases)
- [Changelog](https://github.com/yewstack/yew/blob/master/CHANGELOG.md)
- [Commits](https://github.com/yewstack/yew/compare/yew-agent-v0.2.0...yew-agent-v0.3.0)

---
updated-dependencies:
- dependency-name: yew-agent
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-01-10 14:14:03 +00:00
edf3fcd1c4 fix: deprecated option field (open-pull-requests-limit-per-dependency) (#1554) 2024-01-10 15:12:46 +01:00
53e4755015 feat: add dependabot to the project (#1553)
* feat: add dependabot to the project

* feat: add let's accept patches/fix from other libs

* Revert "feat: add let's accept patches/fix from other libs"

This reverts commit d31a956f81.
2024-01-10 14:57:20 +01:00
12b2a337f3 Handle start-offset when loading a tensor from a pickle file. (#1546) 2024-01-08 09:20:48 +01:00
0eb90ed783 Simpler repro for the neon optimization issue + bugfix (#1544)
* Simpler repro for the neon optimization issue.

* Bugfix for q4k.

* Improve the fix, share the dot-prod bit.

* Clippy fixes.

* Fix for q6k.

* Also fix for q2k.

* Use the new shared dotprod.

* Add more testing.
2024-01-07 20:21:49 +01:00
89b5a06858 Use bindgen-cuda for the custom-kernel example. (#1536)
* Use bindgen-cuda for the custom-kernel example.

* Only depend on the kernels when cuda is enabled.

* Skip rustfmt.
2024-01-07 17:18:46 +01:00
30313c3081 Moving to a proper build crate bindgen_cuda. (#1531)
* Moving to a proper build crate `bindgen_cuda`.

* Fmt.
2024-01-07 12:29:24 +01:00
e72d52b1a2 Unpin more of the workplace relative dependencies. (#1535) 2024-01-07 12:26:20 +01:00
b4cb982e49 Simplifying our internal cargo dependencies. (#1529) 2024-01-07 12:04:14 +01:00
84250bf52f fix index_pos bug when kv cache is disabled. (#1517)
* fix index_pos bug when kv cache is disabled

* Tweak the fix.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
2024-01-06 11:43:01 +01:00
8d1a57c9a0 chore: update flash attention kernels (#1518)
* chore: update flash attention kernels

* fmt

* remove unused kernels

* force f32

* correct stride
2024-01-05 18:28:55 +01:00
90 changed files with 1818 additions and 7868 deletions

.github/dependabot.yml (new file)
View File

@@ -0,0 +1,7 @@
+version: 2
+updates:
+  - package-ecosystem: "cargo"
+    directory: "/"
+    schedule:
+      interval: "weekly"
+    open-pull-requests-limit: 5

View File

@@ -31,6 +31,14 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
+candle = { path = "./candle-core", package = "candle-core" }
+candle-datasets = { path = "./candle-datasets" }
+candle-flash-attn = { path = "./candle-flash-attn" }
+candle-kernels = { path = "./candle-kernels" }
+candle-metal-kernels = { path = "./candle-metal-kernels" }
+candle-nn = { path = "./candle-nn" }
+candle-onnx = { path = "./candle-onnx" }
+candle-transformers = { path = "./candle-transformers" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.9.14", features = ["f16"] }

View File

@@ -11,11 +11,11 @@ readme = "README.md"
 
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }

View File

@@ -12,8 +12,8 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
-candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
+candle-kernels = { workspace = true, optional = true }
+candle-metal-kernels = { workspace = true, optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }

View File

@@ -1,5 +1,5 @@
-use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
-use candle_core::{Device, Result};
+use candle_core::quantized::{gguf_file, k_quants, QTensor};
+use candle_core::{Device, Result, Tensor};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;
@@ -11,7 +11,12 @@ enum QuantizationMode {
 }
 
 impl QuantizationMode {
-    fn quantize(&self, name: &str, tensor: QTensor, dtype: GgmlDType) -> Result<QTensor> {
+    fn quantize(
+        &self,
+        name: &str,
+        tensor: QTensor,
+        default: fn(&Tensor) -> Result<QTensor>,
+    ) -> Result<QTensor> {
         match self {
             Self::Llama => {
                 // Same behavior as the llama.cpp quantization.
@@ -19,9 +24,9 @@ impl QuantizationMode {
                 if should_quantize {
                     let tensor = tensor.dequantize(&Device::Cpu)?;
                     if name == "output.weight" {
-                        QTensor::quantize(&tensor, GgmlDType::Q6K)
+                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
                     } else {
-                        QTensor::quantize(&tensor, dtype)
+                        default(&tensor)
                     }
                 } else {
                     Ok(tensor)
@@ -55,27 +60,6 @@ enum Quantization {
     F32,
 }
 
-impl Quantization {
-    fn dtype(&self) -> GgmlDType {
-        match self {
-            Quantization::Q4_0 => GgmlDType::Q4_0,
-            Quantization::Q4_1 => GgmlDType::Q4_1,
-            Quantization::Q5_0 => GgmlDType::Q5_0,
-            Quantization::Q5_1 => GgmlDType::Q5_1,
-            Quantization::Q8_0 => GgmlDType::Q8_0,
-            Quantization::Q8_1 => GgmlDType::Q8_1,
-            Quantization::Q2k => GgmlDType::Q2K,
-            Quantization::Q3k => GgmlDType::Q3K,
-            Quantization::Q4k => GgmlDType::Q4K,
-            Quantization::Q5k => GgmlDType::Q5K,
-            Quantization::Q6k => GgmlDType::Q6K,
-            Quantization::Q8k => GgmlDType::Q8K,
-            Quantization::F16 => GgmlDType::F16,
-            Quantization::F32 => GgmlDType::F32,
-        }
-    }
-}
-
 #[derive(ValueEnum, Debug, Clone)]
 enum Format {
     Safetensors,
@@ -141,12 +125,7 @@ struct Args {
     command: Command,
 }
 
-fn run_ls(
-    file: &std::path::PathBuf,
-    format: Option<Format>,
-    verbose: bool,
-    device: &Device,
-) -> Result<()> {
+fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
     let format = match format {
         Some(format) => format,
         None => match Format::infer(file) {
@@ -212,7 +191,7 @@ fn run_ls(
         }
         Format::Ggml => {
             let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
             let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, qtensor) in tensors.iter() {
@@ -253,8 +232,37 @@ fn run_quantize_safetensors(
     }
     println!("tensors: {}", tensors.len());
 
-    let dtype = q.dtype();
-    let block_size = dtype.block_size();
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+    let block_size = match q {
+        Quantization::Q4_0 => k_quants::QK4_0,
+        Quantization::Q4_1 => k_quants::QK4_1,
+        Quantization::Q5_0 => k_quants::QK5_0,
+        Quantization::Q5_1 => k_quants::QK5_1,
+        Quantization::Q8_0 => k_quants::QK8_0,
+        Quantization::Q8_1 => k_quants::QK8_1,
+        Quantization::Q2k
+        | Quantization::Q3k
+        | Quantization::Q4k
+        | Quantization::Q5k
+        | Quantization::Q6k
+        | Quantization::Q8k => k_quants::QK_K,
+        Quantization::F16 | Quantization::F32 => 1,
+    };
     let qtensors = tensors
         .into_par_iter()
@@ -262,9 +270,9 @@ fn run_quantize_safetensors(
             let should_quantize = tensor.rank() == 2 && tensor.dim(1)? % block_size == 0;
             println!("  quantizing {name} {tensor:?} {should_quantize}");
             let tensor = if should_quantize {
-                QTensor::quantize(&tensor, dtype)?
+                quantize_fn(&tensor)?
             } else {
-                QTensor::quantize(&tensor, GgmlDType::F32)?
+                QTensor::quantize::<f32>(&tensor)?
             };
             Ok((name, tensor))
         })
@@ -282,7 +290,6 @@ fn run_quantize(
     out_file: std::path::PathBuf,
     q: Quantization,
     qmode: QuantizationMode,
-    device: &Device,
 ) -> Result<()> {
     if in_files.is_empty() {
         candle_core::bail!("no specified input files")
@@ -308,15 +315,31 @@ fn run_quantize(
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());
 
-    let dtype = q.dtype();
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
     let qtensors = content
         .tensor_infos
         .par_iter()
         .map(|(name, _)| {
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_files[0])?;
-            let tensor = content.tensor(&mut in_file, name, device)?;
-            let tensor = qmode.quantize(name, tensor, dtype)?;
+            let tensor = content.tensor(&mut in_file, name)?;
+            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -336,7 +359,6 @@ fn run_quantize(
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
 
-    let device = Device::Cpu;
     match args.command {
         Command::Ls {
             files,
@@ -348,7 +370,7 @@ fn main() -> anyhow::Result<()> {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), verbose, &device)?
+                run_ls(file, format.clone(), verbose)?
             }
         }
         Command::Quantize {
@@ -356,7 +378,7 @@ fn main() -> anyhow::Result<()> {
             out_file,
             quantization,
             mode,
-        } => run_quantize(&in_file, out_file, quantization, mode, &device)?,
+        } => run_quantize(&in_file, out_file, quantization, mode)?,
     }
     Ok(())
 }
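The hunks above replace a dtype-driven entry point (the removed QTensor::quantize(&tensor, dtype)) with a compile-time generic one (QTensor::quantize::<T>(&tensor)), so the CLI has to map its runtime Quantization flag onto a monomorphized function once, then pass that function around. A self-contained sketch of that dispatch pattern, with stand-in types rather than candle's own:

// Sketch only: `Quant`, `Block4`, `Block8`, and `quantize` are hypothetical
// stand-ins for the candle types used in the diff above.
#[derive(Clone, Copy)]
enum Quant {
    Q4,
    Q8,
}

struct Block4;
struct Block8;

trait GgmlType {
    const BLCK_SIZE: usize;
}
impl GgmlType for Block4 {
    const BLCK_SIZE: usize = 32;
}
impl GgmlType for Block8 {
    const BLCK_SIZE: usize = 32;
}

// Stand-in for the real per-block quantization; returns the block count.
fn quantize<T: GgmlType>(xs: &[f32]) -> usize {
    xs.len() / T::BLCK_SIZE
}

fn main() {
    let q = if std::env::args().len() > 1 { Quant::Q8 } else { Quant::Q4 };
    // One match selects the monomorphized instance as a plain fn pointer...
    let quantize_fn: fn(&[f32]) -> usize = match q {
        Quant::Q4 => quantize::<Block4>,
        Quant::Q8 => quantize::<Block8>,
    };
    // ...which can then be threaded through helpers such as
    // QuantizationMode::quantize's `default` parameter in the diff.
    assert_eq!(quantize_fn(&vec![0f32; 64]), 2);
}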

View File

@@ -88,7 +88,7 @@ pub struct MetalDevice {
     /// execution order to be linear.
     /// It could be relaxed in some circumstances, by managing ourselves the dependencies in the
     /// compute graph.
-    // fence: metal::Fence,
+    fence: metal::Fence,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
     /// Heavily used by [`candle_metal_kernels`], both fences need to match
     kernels: Arc<candle_metal_kernels::Kernels>,
@@ -131,10 +131,6 @@ impl MetalDevice {
         &self.device
     }
 
-    // pub(crate) fn fence(&self) -> &metal::Fence {
-    //     &self.fence
-    // }
-
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
@@ -225,10 +221,10 @@ impl MetalDevice {
         let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
-        // blit.wait_for_fence(&self.fence);
+        blit.wait_for_fence(&self.fence);
         blit.set_label("with_data_blit");
         blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
-        // blit.update_fence(&self.fence);
+        blit.update_fence(&self.fence);
         blit.end_encoding();
 
         // This is necessary, for mmaped safetensors
@@ -242,29 +238,6 @@ impl MetalDevice {
         Ok(real)
     }
 
-    pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
-        let buffer = self.allocate_buffer(
-            size_in_bytes as NSUInteger,
-            MTLResourceOptions::StorageModePrivate,
-            "allocate_zeros",
-        )?;
-        let command_buffer = self.command_buffer()?;
-        command_buffer.set_label("zeros");
-        let blit = command_buffer.new_blit_command_encoder();
-        // blit.wait_for_fence(&self.fence);
-        blit.fill_buffer(
-            &buffer,
-            metal::NSRange {
-                location: 0,
-                length: buffer.length(),
-            },
-            0,
-        );
-        // blit.update_fence(&self.fence);
-        blit.end_encoding();
-        Ok(buffer)
-    }
-
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
@@ -335,14 +308,35 @@ impl BackendStorage for MetalStorage {
     }
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
+        let length = self.buffer.length() as usize;
+        let size = self.dtype.size_in_bytes();
+        if length % size != 0 {
+            crate::bail!(
+                "The Metal buffer length is not aligned with dtype {:?}",
+                self.dtype
+            );
+        }
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
+        {
+            let command_buffer = self.device.command_buffer()?;
+            command_buffer.set_label("to_cpu");
+            let blit = command_buffer.new_blit_command_encoder();
+            blit.set_label("blit_to_cpu");
+            blit.wait_for_fence(&self.device.fence);
+            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
+            blit.update_fence(&self.device.fence);
+            blit.end_encoding();
+        }
+        self.device.wait_until_completed()?;
         match self.dtype {
-            DType::U8 => Ok(CpuStorage::U8(self.to_cpu()?)),
-            DType::U32 => Ok(CpuStorage::U32(self.to_cpu()?)),
-            DType::I64 => Ok(CpuStorage::I64(self.to_cpu()?)),
-            DType::F16 => Ok(CpuStorage::F16(self.to_cpu()?)),
-            DType::BF16 => Ok(CpuStorage::BF16(self.to_cpu()?)),
-            DType::F32 => Ok(CpuStorage::F32(self.to_cpu()?)),
-            DType::F64 => Ok(CpuStorage::F64(self.to_cpu()?)),
+            DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
+            DType::U32 => Ok(CpuStorage::U32(read_to_vec(&buffer, length / size))),
+            DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
+            DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
+            DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
+            DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
+            DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
         }
     }
@@ -1247,7 +1241,7 @@ impl BackendStorage for MetalStorage {
             let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
             let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
             let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
+            blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
             blit.end_encoding();
         } else {
             let src_shape = src_l.shape();
@@ -1470,30 +1464,6 @@ impl MetalStorage {
         command_buffer.set_label("binary");
         Ok(Self::new(buffer, device.clone(), dtype))
     }
-
-    pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
-        let length = self.buffer.length() as usize;
-        let size = self.dtype.size_in_bytes();
-        if length % size != 0 {
-            crate::bail!(
-                "The Metal buffer length is not aligned with dtype {:?}",
-                self.dtype
-            );
-        }
-        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
-        {
-            let command_buffer = self.device.command_buffer()?;
-            command_buffer.set_label("to_cpu");
-            let blit = command_buffer.new_blit_command_encoder();
-            blit.set_label("blit_to_cpu");
-            // blit.wait_for_fence(&self.device.fence);
-            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-            // blit.update_fence(&self.device.fence);
-            blit.end_encoding();
-        }
-        self.device.wait_until_completed()?;
-        Ok(read_to_vec(&buffer, length / size))
-    }
 }
 
 impl BackendDevice for MetalDevice {
@@ -1506,16 +1476,16 @@ impl BackendDevice for MetalDevice {
         command_buffer.enqueue();
         let command_buffer = Arc::new(RwLock::new(command_buffer));
         let command_buffer_index = Arc::new(RwLock::new(0));
-        // let fence = device.new_fence();
-        let kernels = Arc::new(Kernels::new());
+        let fence = device.new_fence();
+        let kernels = Arc::new(Kernels::new(fence.clone()));
         let buffers = Arc::new(RwLock::new(HashMap::new()));
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
            Ok(val) => val.parse()?,
-            _ => 10,
+            _ => 20,
        };
         Ok(Self {
             device,
-            // fence,
+            fence,
             command_queue,
             command_buffer,
             command_buffer_index,
@@ -1540,8 +1510,21 @@ impl BackendDevice for MetalDevice {
     }
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let size = shape.elem_count() * dtype.size_in_bytes();
-        let buffer = self.allocate_zeros(size)?;
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+        let command_buffer = self.command_buffer()?;
+        command_buffer.set_label("zeros");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.wait_for_fence(&self.fence);
+        blit.fill_buffer(
+            &buffer,
+            metal::NSRange {
+                location: 0,
+                length: buffer.length(),
+            },
+            0,
+        );
+        blit.update_fence(&self.fence);
+        blit.end_encoding();
         Ok(MetalStorage::new(buffer, self.clone(), dtype))
     }
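The pattern these hunks restore is the Metal fence protocol: every blit waits on a shared fence before touching a buffer and signals it afterwards, which serializes otherwise-unordered GPU passes. A minimal sketch of that wait/copy/update sequence, assuming macOS and the Rust `metal` crate (the same calls the diff itself uses), not candle code:

use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device");
    let queue = device.new_command_queue();
    let fence = device.new_fence();
    let src = device.new_buffer(1024, MTLResourceOptions::StorageModeManaged);
    let dst = device.new_buffer(1024, MTLResourceOptions::StorageModeManaged);

    let command_buffer = queue.new_command_buffer();
    let blit = command_buffer.new_blit_command_encoder();
    blit.wait_for_fence(&fence); // block until earlier passes have signalled
    blit.copy_from_buffer(&src, 0, &dst, 0, 1024);
    blit.update_fence(&fence); // signal so later passes may proceed
    blit.end_encoding();
    command_buffer.commit();
    command_buffer.wait_until_completed();
}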

View File

@@ -703,6 +703,7 @@ impl PthTensors {
     }
 
     pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+        use std::io::Read;
         let tensor_info = match self.tensor_infos.get(name) {
             None => return Ok(None),
             Some(tensor_info) => tensor_info,
@@ -712,14 +713,21 @@ impl PthTensors {
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;
 
-        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
-        // For now only support the basic case.
-        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
+        // Reading the data is a bit tricky as it can be strided, for now only support the basic
+        // case.
+        if !tensor_info.layout.is_contiguous() {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
             )
         }
+        let start_offset = tensor_info.layout.start_offset();
+        if start_offset > 0 {
+            std::io::copy(
+                &mut reader.by_ref().take(start_offset as u64),
+                &mut std::io::sink(),
+            )?;
+        }
         let tensor = Tensor::from_reader(
             tensor_info.layout.shape().clone(),
             tensor_info.dtype,
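The trick in the added lines: a zip entry reader generally cannot seek, so the start offset is consumed by copying that many bytes into a sink before the reader is handed to Tensor::from_reader. A self-contained, runnable sketch of the same std-only idiom (a Cursor stands in for the zip entry):

use std::io::{copy, sink, Cursor, Read};

fn main() -> std::io::Result<()> {
    let data = b"HEADERpayload";
    let mut reader = Cursor::new(&data[..]);
    let start_offset = 6u64; // length of "HEADER"

    // Discard the prefix without requiring Seek.
    copy(&mut reader.by_ref().take(start_offset), &mut sink())?;

    let mut rest = String::new();
    reader.read_to_string(&mut rest)?;
    assert_eq!(rest, "payload");
    Ok(())
}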

View File

@@ -1,9 +1,7 @@
 //! Support for the GGML file format.
 
-#[cfg(feature = "metal")]
-use super::metal::load_quantized_metal;
-use super::{k_quants, GgmlDType, QStorage};
-use crate::{Device, Result};
+use super::{k_quants, GgmlDType};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
@@ -123,22 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
     raw_data: &[u8],
     size_in_bytes: usize,
     dims: Vec<usize>,
-    device: &Device,
 ) -> Result<super::QTensor> {
     let raw_data_ptr = raw_data.as_ptr();
     let n_blocks = size_in_bytes / std::mem::size_of::<T>();
     let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
-    let data: QStorage = match device {
-        Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
-        #[cfg(feature = "metal")]
-        Device::Metal(metal) => load_quantized_metal(metal, data)?,
-        #[cfg(not(feature = "metal"))]
-        Device::Metal(_metal) => {
-            crate::bail!("Metal backend requires `metal` feature")
-        }
-        device => unimplemented!("Implement quantized tensor for device {device:?}"),
-    };
-    super::QTensor::new(data, dims)
+    super::QTensor::new(data.to_vec(), dims)
 }
 
 /// Creates a [Tensor] from a raw GGML tensor.
@@ -146,50 +133,29 @@ pub fn qtensor_from_ggml(
     ggml_dtype: GgmlDType,
     raw_data: &[u8],
     dims: Vec<usize>,
-    device: &Device,
 ) -> Result<super::QTensor> {
     let tensor_elems = dims.iter().product::<usize>();
-    let block_size = ggml_dtype.block_size();
-    if tensor_elems % block_size != 0 {
+    let blck_size = ggml_dtype.blck_size();
+    if tensor_elems % blck_size != 0 {
         crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
         )
     }
-    let size_in_bytes = tensor_elems / block_size * ggml_dtype.type_size();
+    let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
 
     match ggml_dtype {
-        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
-        GgmlDType::Q4_0 => {
-            from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q4_1 => {
-            from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q5_0 => {
-            from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q5_1 => {
-            from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q8_0 => {
-            from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q2K => {
-            from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q3K => {
-            from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q4K => {
-            from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q5K => {
-            from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
-        }
-        GgmlDType::Q6K => {
-            from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
-        }
+        GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
+        GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
+        GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
        _ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
     }
 }
@@ -197,7 +163,6 @@ pub fn qtensor_from_ggml(
 fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     reader: &mut R,
     magic: VersionedMagic,
-    device: &Device,
 ) -> Result<(String, super::QTensor)> {
     let n_dims = reader.read_u32::<LittleEndian>()?;
     let name_len = reader.read_u32::<LittleEndian>()?;
@@ -218,11 +183,11 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
     }
     let dims = dims.iter().map(|&u| u as usize).collect::<Vec<_>>();
     let tensor_elems = dims.iter().product::<usize>();
-    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.block_size();
+    let size_in_bytes = tensor_elems * ggml_dtype.type_size() / ggml_dtype.blck_size();
     // TODO: Mmap version to avoid copying the data around?
     let mut raw_data = vec![0u8; size_in_bytes];
     reader.read_exact(&mut raw_data)?;
-    match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
+    match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
         Ok(tensor) => Ok((name, tensor)),
         Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
     }
@@ -236,10 +201,7 @@ pub struct Content {
 }
 
 impl Content {
-    pub fn read<R: std::io::Seek + std::io::Read>(
-        reader: &mut R,
-        device: &Device,
-    ) -> Result<Content> {
+    pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
         // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
         let last_position = reader.seek(std::io::SeekFrom::End(0))?;
         reader.seek(std::io::SeekFrom::Start(0))?;
@@ -249,7 +211,7 @@ impl Content {
         let mut tensors = HashMap::new();
         while reader.stream_position()? != last_position {
-            let (name, tensor) = read_one_tensor(reader, magic, device)?;
+            let (name, tensor) = read_one_tensor(reader, magic)?;
             tensors.insert(name, tensor);
         }
         Ok(Self {
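In both versions, from_raw_data reinterprets the raw GGML payload as a slice of fixed-size block structs and copies it out with to_vec. A self-contained sketch of that reinterpretation with a hypothetical stand-in block (all-u8 fields so the alignment requirement is trivially met; the real k_quants blocks rely on the source allocation's alignment):

#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct Block {
    d: [u8; 2],  // stand-in for an f16 scale
    qs: [u8; 6], // stand-in for packed quantized values
}

fn blocks_from_raw(raw: &[u8]) -> Vec<Block> {
    let n_blocks = raw.len() / std::mem::size_of::<Block>();
    // Safety: sound here because Block is repr(C) with alignment 1 and every
    // bit pattern is valid; the length is truncated to whole blocks.
    let blocks =
        unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const Block, n_blocks) };
    blocks.to_vec()
}

fn main() {
    let raw = vec![0u8; 24]; // three 8-byte blocks
    assert_eq!(blocks_from_raw(&raw).len(), 3);
}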

View File

@@ -3,7 +3,7 @@
 //! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
 
 use super::{GgmlDType, QTensor};
-use crate::{Device, Result};
+use crate::Result;
 use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
 use std::collections::HashMap;
@@ -59,25 +59,19 @@ impl TensorInfo {
         &self,
         reader: &mut R,
         tensor_data_offset: u64,
-        device: &Device,
     ) -> Result<QTensor> {
         let tensor_elems = self.shape.elem_count();
-        let block_size = self.ggml_dtype.block_size();
-        if tensor_elems % block_size != 0 {
+        let blck_size = self.ggml_dtype.blck_size();
+        if tensor_elems % blck_size != 0 {
             crate::bail!(
-            "the number of elements {tensor_elems} is not divisible by the block size {block_size}"
+            "the number of elements {tensor_elems} is not divisible by the block size {blck_size}"
             )
         }
-        let size_in_bytes = tensor_elems / block_size * self.ggml_dtype.type_size();
+        let size_in_bytes = tensor_elems / blck_size * self.ggml_dtype.type_size();
         let mut raw_data = vec![0u8; size_in_bytes];
         reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
         reader.read_exact(&mut raw_data)?;
-        super::ggml_file::qtensor_from_ggml(
-            self.ggml_dtype,
-            &raw_data,
-            self.shape.dims().to_vec(),
-            device,
-        )
+        super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
     }
 }
@@ -466,13 +460,12 @@ impl Content {
         &self,
         reader: &mut R,
         name: &str,
-        device: &Device,
     ) -> Result<QTensor> {
         let tensor_info = match self.tensor_infos.get(name) {
             Some(tensor_info) => tensor_info,
             None => crate::bail!("cannot find tensor info for {name}"),
         };
-        tensor_info.read(reader, self.tensor_data_offset, device)
+        tensor_info.read(reader, self.tensor_data_offset)
     }
 }
@@ -524,9 +517,10 @@ pub fn write<W: std::io::Seek + std::io::Write>(
                 "internal error, unexpected current position {tensor_start_pos} {offset} {pos}"
             )
         }
-        let data = tensor.data()?;
-        let size_in_bytes = data.len();
-        w.write_all(&data)?;
+        let data_ptr = tensor.as_ptr();
+        let size_in_bytes = tensor.storage_size_in_bytes();
+        let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
+        w.write_all(data)?;
         let padding = 31 - (31 + size_in_bytes) % 32;
         w.write_all(&vec![0u8; padding])?;
     }
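The padding expression in the last hunk, 31 - (31 + size_in_bytes) % 32, rounds each tensor payload up to the next 32-byte boundary and yields 0 when the size is already aligned. A quick self-contained check of that arithmetic:

fn main() {
    for size_in_bytes in [0usize, 1, 31, 32, 33, 64] {
        let padding = 31 - (31 + size_in_bytes) % 32;
        assert_eq!((size_in_bytes + padding) % 32, 0);
        println!("{size_in_bytes:>3} -> pad {padding}");
    }
}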

View File

@@ -1,155 +0,0 @@
-use super::{GgmlDType, QStorage};
-use crate::{DType, MetalDevice, MetalStorage, Result};
-use metal::Buffer;
-use std::sync::Arc;
-
-pub struct QMetalStorage {
-    dtype: GgmlDType,
-    device: MetalDevice,
-    buffer: Arc<Buffer>,
-}
-
-impl QMetalStorage {
-    pub fn dtype(&self) -> GgmlDType {
-        self.dtype
-    }
-
-    pub fn buffer(&self) -> &Buffer {
-        &self.buffer
-    }
-
-    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
-        Self {
-            device,
-            buffer,
-            dtype,
-        }
-    }
-
-    pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
-        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
-        let command_buffer = self.device.command_buffer()?;
-        command_buffer.set_label("to_cpu");
-        let blit = command_buffer.new_blit_command_encoder();
-        blit.set_label("blit_to_cpu");
-        // blit.wait_for_fence(&self.device.fence());
-        blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
-        // blit.update_fence(&self.device.fence());
-        blit.end_encoding();
-        self.device.wait_until_completed()?;
-        let mut out = vec![0.0; elem_count];
-        match self.dtype {
-            GgmlDType::F32 => {
-                let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                f32::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::F16 => {
-                let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                half::f16::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4_0 => {
-                let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4_1 => {
-                let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5_0 => {
-                let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5_1 => {
-                let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8_0 => {
-                let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8_1 => {
-                let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q2K => {
-                let vec: Vec<crate::quantized::BlockQ2K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q3K => {
-                let vec: Vec<crate::quantized::BlockQ3K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q4K => {
-                let vec: Vec<crate::quantized::BlockQ4K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q5K => {
-                let vec: Vec<crate::quantized::BlockQ5K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q6K => {
-                let vec: Vec<crate::quantized::BlockQ6K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
-            }
-            GgmlDType::Q8K => {
-                let vec: Vec<crate::quantized::BlockQ8K> =
-                    read_to_vec(&buffer, elem_count / self.dtype.block_size());
-                use crate::quantized::k_quants::GgmlType;
-                crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
-            }
-        }
-        let buffer = self.device.new_buffer_with_data(&out)?;
-        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
-    }
-
-    pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
-        // Quantization only happens on CPU for now.
-        let src = src.to_cpu::<f32>()?;
-        let elem_count = src.len();
-        let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
-        let mut qcpu_storage = crate::Device::Cpu.qzeros(elem_count, self.dtype)?;
-        qcpu_storage.quantize(&src)?;
-        let buffer = self.device.new_buffer_with_data(&qcpu_storage.data()?)?;
-        self.buffer = buffer;
-        Ok(())
-    }
-}
-
-pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
-    device: &MetalDevice,
-    data: &[T],
-) -> Result<QStorage> {
-    let buffer = device.new_buffer_with_data(data)?;
-    let device = device.clone();
-    Ok(QStorage::Metal(QMetalStorage {
-        dtype: T::DTYPE,
-        device,
-        buffer,
-    }))
-}
-
-fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
-    let ptr = buffer.contents() as *const T;
-    assert!(!ptr.is_null());
-    let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
-    slice.to_vec()
-}

View File

@@ -1,125 +1,23 @@
-#[cfg(feature = "metal")]
-use crate::{backend::BackendStorage, DType};
-use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
-use k_quants::*;
-use std::borrow::Cow;
+use crate::{Device, Result, Shape, Tensor};
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
-#[cfg(feature = "metal")]
-pub mod metal;
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
 pub mod simd128;
 pub mod utils;
-use half::f16;
 
 pub use k_quants::GgmlType;
 
 pub struct QTensor {
-    storage: QStorage,
+    data: Box<dyn QuantizedType>,
     shape: Shape,
 }
 
-impl Device {
-    fn qzeros(&self, elem_count: usize, dtype: GgmlDType) -> Result<QStorage> {
-        match self {
-            Device::Cpu => {
-                let storage = dtype.cpu_zeros(elem_count);
-                Ok(QStorage::Cpu(storage))
-            }
-            #[cfg(feature = "metal")]
-            Device::Metal(metal) => {
-                let size = elem_count * dtype.type_size() / dtype.block_size();
-                let buffer = metal.allocate_zeros(size)?;
-                Ok(QStorage::Metal(metal::QMetalStorage::new(
-                    buffer,
-                    metal.clone(),
-                    dtype,
-                )))
-            }
-            #[cfg(not(feature = "metal"))]
-            Device::Metal(_metal) => {
-                crate::bail!("Metal feature not activated");
-            }
-            Device::Cuda(_cuda) => {
-                crate::bail!("Cuda ggml quantization not supported");
-            }
-        }
-    }
-}
-
-pub enum QStorage {
-    Cpu(Box<dyn QuantizedType>),
-    #[cfg(feature = "metal")]
-    Metal(metal::QMetalStorage),
-}
-
-impl QStorage {
-    fn block_size(&self) -> usize {
-        match self {
-            QStorage::Cpu(storage) => storage.block_size(),
-            #[cfg(feature = "metal")]
-            QStorage::Metal(storage) => storage.dtype().block_size(),
-        }
-    }
-
-    fn dtype(&self) -> GgmlDType {
-        match self {
-            QStorage::Cpu(storage) => storage.dtype(),
-            #[cfg(feature = "metal")]
-            QStorage::Metal(storage) => storage.dtype(),
-        }
-    }
-
-    fn size_in_bytes(&self) -> usize {
-        match self {
-            QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
-            #[cfg(feature = "metal")]
-            QStorage::Metal(storage) => storage.buffer().length() as usize,
-        }
-    }
-
-    fn quantize(&mut self, src: &Storage) -> Result<()> {
-        match (self, src) {
-            (QStorage::Cpu(storage), Storage::Cpu(src)) => {
-                storage.from_float(src.as_slice::<f32>()?)?;
-            }
-            #[cfg(feature = "metal")]
-            (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
-            _ => crate::bail!("Invalid dequantize storage locations do not match"),
-        }
-        Ok(())
-    }
-
-    fn dequantize(&self, elem_count: usize) -> Result<Storage> {
-        match self {
-            QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
-            #[cfg(feature = "metal")]
-            QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
-        }
-    }
-
-    fn data(&self) -> Result<Cow<[u8]>> {
-        match self {
-            QStorage::Cpu(storage) => {
-                let data_ptr = storage.as_ptr();
-                let size_in_bytes = storage.storage_size_in_bytes();
-                let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
-                Ok(Cow::from(data))
-            }
-            #[cfg(feature = "metal")]
-            QStorage::Metal(_storage) => {
-                crate::bail!("not implemented");
-            }
-        }
-    }
-}
-
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
 pub enum GgmlDType {
     F32,
@@ -179,25 +77,6 @@ impl GgmlDType {
         }
     }
 
-    /// The block dtype
-    pub fn cpu_zeros(&self, elem_count: usize) -> Box<dyn QuantizedType> {
-        match self {
-            Self::F32 => Box::new(vec![f32::zeros(); elem_count]),
-            Self::F16 => Box::new(vec![f16::zeros(); elem_count]),
-            Self::Q4_0 => Box::new(vec![BlockQ4_0::zeros(); elem_count / BlockQ4_0::BLCK_SIZE]),
-            Self::Q4_1 => Box::new(vec![BlockQ4_1::zeros(); elem_count / BlockQ4_1::BLCK_SIZE]),
-            Self::Q5_0 => Box::new(vec![BlockQ5_0::zeros(); elem_count / BlockQ5_0::BLCK_SIZE]),
-            Self::Q5_1 => Box::new(vec![BlockQ5_1::zeros(); elem_count / BlockQ5_1::BLCK_SIZE]),
-            Self::Q8_0 => Box::new(vec![BlockQ8_0::zeros(); elem_count / BlockQ8_0::BLCK_SIZE]),
-            Self::Q8_1 => Box::new(vec![BlockQ8_1::zeros(); elem_count / BlockQ8_1::BLCK_SIZE]),
-            Self::Q2K => Box::new(vec![BlockQ2K::zeros(); elem_count / BlockQ2K::BLCK_SIZE]),
-            Self::Q3K => Box::new(vec![BlockQ3K::zeros(); elem_count / BlockQ3K::BLCK_SIZE]),
-            Self::Q4K => Box::new(vec![BlockQ4K::zeros(); elem_count / BlockQ4K::BLCK_SIZE]),
-            Self::Q5K => Box::new(vec![BlockQ5K::zeros(); elem_count / BlockQ5K::BLCK_SIZE]),
-            Self::Q6K => Box::new(vec![BlockQ6K::zeros(); elem_count / BlockQ6K::BLCK_SIZE]),
-            Self::Q8K => Box::new(vec![BlockQ8K::zeros(); elem_count / BlockQ8K::BLCK_SIZE]),
-        }
-    }
-
     /// The type size for blocks in bytes.
     pub fn type_size(&self) -> usize {
         use k_quants::*;
@@ -221,7 +100,7 @@ impl GgmlDType {
     }
 
     /// The block size, i.e. the number of elements stored in each block.
-    pub fn block_size(&self) -> usize {
+    pub fn blck_size(&self) -> usize {
         match self {
             Self::F32 => 1,
             Self::F16 => 1,
@@ -240,13 +119,9 @@ impl GgmlDType {
 pub trait QuantizedType: Send + Sync {
     fn dtype(&self) -> GgmlDType;
     fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()>;
-    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage>;
+    fn to_float(&self, ys: &mut [f32]) -> Result<()>;
     fn storage_size_in_bytes(&self) -> usize;
     fn as_ptr(&self) -> *const u8;
-    fn block_size(&self) -> usize;
-    #[allow(clippy::wrong_self_convention)]
-    fn from_float(&mut self, xs: &[f32]) -> Result<()>;
-    fn size(&self) -> usize;
 }
 
 impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
@@ -254,26 +129,12 @@ impl<T: k_quants::GgmlType + Send + Sync> QuantizedType for Vec<T> {
         k_quants::matmul(mkn, lhs, self.as_slice(), dst)
     }
 
-    fn size(&self) -> usize {
-        self.len() * core::mem::size_of::<T>()
-    }
-
-    fn from_float(&mut self, xs: &[f32]) -> Result<()> {
-        T::from_float(xs, self)
-    }
-
     fn dtype(&self) -> GgmlDType {
         T::DTYPE
     }
 
-    fn block_size(&self) -> usize {
-        T::BLCK_SIZE
-    }
-
-    fn dequantize(&self, elem_count: usize) -> Result<CpuStorage> {
-        let mut ys = vec![0.0f32; elem_count];
-        T::to_float(self.as_slice(), &mut ys)?;
-        Ok(CpuStorage::F32(ys))
+    fn to_float(&self, ys: &mut [f32]) -> Result<()> {
+        T::to_float(self.as_slice(), ys)
     }
 
     fn storage_size_in_bytes(&self) -> usize {
@@ -291,49 +152,56 @@ impl std::fmt::Debug for QTensor {
     }
 }
 
-fn check_shape(shape: &Shape, block_size: usize) -> Result<()> {
+fn check_shape<T: k_quants::GgmlType>(shape: &Shape) -> Result<()> {
     let dims = shape.dims();
     if dims.is_empty() {
         crate::bail!("scalar tensor cannot be quantized {shape:?}")
     }
-    if dims[dims.len() - 1] % block_size != 0 {
+    if dims[dims.len() - 1] % T::BLCK_SIZE != 0 {
         crate::bail!(
             "quantized tensor must have their last dim divisible by block size {shape:?} {}",
-            block_size
+            T::BLCK_SIZE
        )
     }
     Ok(())
 }
 
 impl QTensor {
-    pub fn new<S: Into<Shape>>(storage: QStorage, shape: S) -> Result<Self> {
+    pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
+        data: Vec<T>,
+        shape: S,
+    ) -> Result<Self> {
         let shape = shape.into();
-        check_shape(&shape, storage.block_size())?;
-        Ok(Self { storage, shape })
+        check_shape::<T>(&shape)?;
+        Ok(Self {
+            data: Box::new(data),
+            shape,
+        })
     }
 
-    pub fn quantize(src: &Tensor, dtype: GgmlDType) -> Result<Self> {
+    pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
         let shape = src.shape();
-        let block_size = dtype.block_size();
-        check_shape(shape, block_size)?;
-        let src = src.to_dtype(crate::DType::F32)?.flatten_all()?;
-        let elem_count = shape.elem_count();
-        if elem_count % block_size != 0 {
+        check_shape::<T>(shape)?;
+        let src = src
            .to_dtype(crate::DType::F32)?
            .flatten_all()?
            .to_vec1::<f32>()?;
+        if src.len() % T::BLCK_SIZE != 0 {
             crate::bail!(
                 "tensor size ({shape:?}) is not divisible by block size {}",
-                block_size
+                T::BLCK_SIZE
             )
         }
-        let mut storage = src.device().qzeros(elem_count, dtype)?;
-        storage.quantize(&src.storage())?;
+        let mut data = vec![T::zeros(); src.len() / T::BLCK_SIZE];
+        T::from_float(&src, &mut data)?;
         Ok(Self {
-            storage,
+            data: Box::new(data),
             shape: shape.clone(),
         })
     }
 
     pub fn dtype(&self) -> GgmlDType {
-        self.storage.dtype()
+        self.data.dtype()
     }
 
     pub fn rank(&self) -> usize {
@@ -345,19 +213,21 @@ impl QTensor {
     }
 
     pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
-        let storage = self.storage.dequantize(self.shape.elem_count())?;
-        let none = crate::op::BackpropOp::none();
-        let is_variable = false;
-        crate::tensor::from_storage(storage, self.shape.clone(), none, is_variable)
-            .to_device(device)
+        let mut f32_data = vec![0f32; self.shape.elem_count()];
+        self.data.to_float(&mut f32_data)?;
+        Tensor::from_vec(f32_data, &self.shape, device)
+    }
+
+    pub fn matmul_t(&self, mkn: (usize, usize, usize), lhs: &[f32], dst: &mut [f32]) -> Result<()> {
+        self.data.matmul_t(mkn, lhs, dst)
     }
 
     pub fn storage_size_in_bytes(&self) -> usize {
-        self.storage.size_in_bytes()
+        self.data.storage_size_in_bytes()
    }
 
-    pub fn data(&self) -> Result<Cow<'_, [u8]>> {
-        self.storage.data()
+    pub fn as_ptr(&self) -> *const u8 {
+        self.data.as_ptr()
     }
 }
@@ -424,93 +294,17 @@ impl crate::CustomOp1 for QTensor {
         }
         dst_shape.push(n);
         let dst_shape = Shape::from(dst_shape);
-        #[allow(clippy::infallible_destructuring_match)]
-        let self_storage = match &self.storage {
-            QStorage::Cpu(storage) => storage,
-            #[cfg(feature = "metal")]
-            _ => crate::bail!("Invalid storage"),
-        };
-        let slice = storage.as_slice::<f32>()?;
-        let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let storage = storage.as_slice::<f32>()?;
+        let storage =
+            &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
         let mut dst_storage = vec![0f32; dst_shape.elem_count()];
-        self_storage.matmul_t((dst_shape.elem_count() / n, k, n), slice, &mut dst_storage)?;
+        self.matmul_t(
+            (dst_shape.elem_count() / n, k, n),
+            storage,
+            &mut dst_storage,
+        )?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
-
-    #[cfg(feature = "metal")]
-    fn metal_fwd(
-        &self,
-        storage: &crate::MetalStorage,
-        layout: &crate::Layout,
-    ) -> Result<(crate::MetalStorage, Shape)> {
-        use crate::MetalError;
-        if !layout.is_contiguous() {
-            crate::bail!("input tensor is not contiguous {layout:?}")
-        }
-        let src_shape = layout.shape();
-        // self is transposed so n is first then k.
-        if src_shape.rank() < 2 {
-            crate::bail!("input tensor has only one dimension {layout:?}")
-        }
-        let (n, k) = self.shape.dims2()?;
-        let mut dst_shape = src_shape.dims().to_vec();
-
-        let (b, m) = match dst_shape.len() {
-            3 => (dst_shape[0], dst_shape[1]),
-            2 => (1, dst_shape[0]),
-            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
-        };
-        let last_k = dst_shape.pop().unwrap();
-        if last_k != k {
-            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
-        }
-        dst_shape.push(n);
-        let dst_shape = Shape::from(dst_shape);
-        let device = storage.device().clone();
-        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
-        let (buffer, dtype) = match &self.storage {
-            QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
-            _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
-        };
-        let command_buffer = device.command_buffer()?;
-        candle_metal_kernels::call_quantized_matmul_t(
-            device.device(),
-            &command_buffer,
-            device.kernels(),
-            dtype.into(),
-            (b, m, n, k),
-            storage.buffer(),
-            layout.start_offset() * storage.dtype().size_in_bytes(),
-            buffer,
-            &dst,
-        )
-        .map_err(MetalError::from)?;
-        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
-        Ok((dst_storage, dst_shape))
-    }
-}
-
-#[cfg(feature = "metal")]
-impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
-    fn from(value: GgmlDType) -> Self {
-        match value {
-            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
-            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
-            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
-            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
-            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
-            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
-            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
-            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
-            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
-            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
-            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
-            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
-            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
-            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
-        }
-    }
-}
 
 impl crate::Module for QMatMul {
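On the added side of this file, QTensor drops the QStorage enum and owns its blocks directly as a Box<dyn QuantizedType>, with a blanket impl turning any Vec of GGML blocks into that trait object. A self-contained sketch of this type-erasure scheme, with a hypothetical BlockQ8 standing in for the k_quants block types:

trait GgmlType: Clone + Send + Sync {
    const BLCK_SIZE: usize;
    fn zeros() -> Self;
}

trait QuantizedType: Send + Sync {
    fn block_count(&self) -> usize;
    fn storage_size_in_bytes(&self) -> usize;
}

// Blanket impl: every Vec of blocks is usable behind the trait object.
impl<T: GgmlType> QuantizedType for Vec<T> {
    fn block_count(&self) -> usize {
        self.len()
    }
    fn storage_size_in_bytes(&self) -> usize {
        self.len() * std::mem::size_of::<T>()
    }
}

#[derive(Clone)]
struct BlockQ8([u8; 34]); // 32 quantized values plus a 2-byte scale
impl GgmlType for BlockQ8 {
    const BLCK_SIZE: usize = 32;
    fn zeros() -> Self {
        BlockQ8([0u8; 34])
    }
}

fn main() {
    let elem_count = 64;
    let data: Box<dyn QuantizedType> =
        Box::new(vec![BlockQ8::zeros(); elem_count / BlockQ8::BLCK_SIZE]);
    assert_eq!(data.block_count(), 2);
    assert_eq!(data.storage_size_in_bytes(), 68);
}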

View File

@ -12,6 +12,14 @@ use core::arch::arm::*;
#[cfg(target_arch = "aarch64")] #[cfg(target_arch = "aarch64")]
use core::arch::aarch64::*; use core::arch::aarch64::*;
#[inline(always)]
unsafe fn vdotq_s32(a: int8x16_t, b: int8x16_t) -> int32x4_t {
// TODO: dotprod
let p0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
let p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1))
}
#[inline(always)] #[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> { pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0; let qk = QK8_0;
@ -43,15 +51,8 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let v1_0l = vld1q_s8(y0.qs.as_ptr()); let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16)); let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly. let pl0 = vdotq_s32(v0_0ls, v1_0l);
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l)); let ph0 = vdotq_s32(v0_0hs, v1_0h);
let pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
sumv0 = vmlaq_n_f32( sumv0 = vmlaq_n_f32(
sumv0, sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)), vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
@ -82,14 +83,8 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let y0_0 = vld1q_s8(y0.qs.as_ptr()); let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16)); let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
// TODO dotprod once this is the intrinsics are. let p0 = vdotq_s32(x0_0, y0_0);
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0)); let p1 = vdotq_s32(x0_1, y0_1);
let p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0));
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
sumv0 = vmlaq_n_f32( sumv0 = vmlaq_n_f32(
sumv0, sumv0,
@ -118,10 +113,7 @@ pub(crate) fn vec_dot_q8k_q8k(n: usize, xs: &[BlockQ8K], ys: &[BlockQ8K]) -> Res
for i in (0..QK_K).step_by(16) { for i in (0..QK_K).step_by(16) {
let xs = vld1q_s8(xs.add(i)); let xs = vld1q_s8(xs.add(i));
let ys = vld1q_s8(ys.add(i)); let ys = vld1q_s8(ys.add(i));
let xy_lo = vmull_s8(vget_low_s8(xs), vget_low_s8(ys)); let xy = vdotq_s32(xs, ys);
let xy_up = vmull_s8(vget_high_s8(xs), vget_high_s8(ys));
let xy = vaddq_s32(vpaddlq_s16(xy_lo), vpaddlq_s16(xy_up));
sum_i = vaddq_s32(sum_i, xy) sum_i = vaddq_s32(sum_i, xy)
} }
sumf += vaddvq_s32(sum_i) as f32 * scale sumf += vaddvq_s32(sum_i) as f32 * scale
@@ -191,30 +183,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
            let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.2, m4b), q6h_2));
            let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.3, m4b), q6h_3));

-            // TODO: dotprod
-
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-            );
+            let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+            let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+            isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
            scale = scale.add(2);

-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-            );
+            let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+            let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+            isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
            scale = scale.add(2);

            let q8bytes = vld1q_s8_x4(q8);
@@ -234,29 +212,16 @@ pub(crate) fn vec_dot_q6k_q8k(n: usize, xs: &[BlockQ6K], ys: &[BlockQ8K]) -> Res
            let q6bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.2, 4), q6h_2));
            let q6bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.3, 4), q6h_3));

-            // TODO: dotprod case.
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q6bytes_0), vget_high_s8(q8bytes.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q6bytes_1), vget_high_s8(q8bytes.1)),
-            );
+            let p0 = vdotq_s32(q6bytes_0, q8bytes.0);
+            let p1 = vdotq_s32(q6bytes_1, q8bytes.1);
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s16(p0) as i32 * scale0 + vaddvq_s16(p1) as i32 * scale1;
+            isum += vaddvq_s32(p0) * scale0 + vaddvq_s32(p1) * scale1;
            scale = scale.add(2);

-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_2), vget_low_s8(q8bytes.2)),
-                vmull_s8(vget_high_s8(q6bytes_2), vget_high_s8(q8bytes.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q6bytes_3), vget_low_s8(q8bytes.3)),
-                vmull_s8(vget_high_s8(q6bytes_3), vget_high_s8(q8bytes.3)),
-            );
+            let p2 = vdotq_s32(q6bytes_2, q8bytes.2);
+            let p3 = vdotq_s32(q6bytes_3, q8bytes.3);
            let (scale0, scale1) = (*scale as i32, *scale.add(1) as i32);
-            isum += vaddvq_s16(p2) as i32 * scale0 + vaddvq_s16(p3) as i32 * scale1;
+            isum += vaddvq_s32(p2) * scale0 + vaddvq_s32(p3) * scale1;
            scale = scale.add(2);
        }
        sum += d_all * y.d * ((isum - 32 * isum_mins) as f32);
@@ -333,28 +298,14 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
            let q5bytes_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.0, 4), q5h_2));
            let q5bytes_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.1, 4), q5h_3));

-            // TODO: dotprod
-
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q5bytes_0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q5bytes_0), vget_high_s8(q8bytes.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q5bytes_1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q5bytes_1), vget_high_s8(q8bytes.1)),
-            );
-            sumi += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * *scales as i32;
+            let p0 = vdotq_s32(q5bytes_0, q8bytes.0);
+            let p1 = vdotq_s32(q5bytes_1, q8bytes.1);
+            sumi += vaddvq_s32(vaddq_s32(p0, p1)) * *scales as i32;
            scales = scales.add(1);

-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q5bytes_2), vget_low_s8(q8bytes.2)),
-                vmull_s8(vget_high_s8(q5bytes_2), vget_high_s8(q8bytes.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q5bytes_3), vget_low_s8(q8bytes.3)),
-                vmull_s8(vget_high_s8(q5bytes_3), vget_high_s8(q8bytes.3)),
-            );
-            sumi += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * *scales as i32;
+            let p2 = vdotq_s32(q5bytes_2, q8bytes.2);
+            let p3 = vdotq_s32(q5bytes_3, q8bytes.3);
+            sumi += vaddvq_s32(vaddq_s32(p2, p3)) * *scales as i32;
            scales = scales.add(1);
        }
        sumf += d * sumi as f32 - dmin * sumi_mins as f32;
@@ -417,22 +368,15 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
        for j in 0..QK_K / 64 {
            let q4bits = vld1q_u8_x2(q4);
            q4 = q4.add(32);
-            // TODO: dotprod
            let q8bytes = vld1q_s8_x2(q8);
            q8 = q8.add(32);
            let q4bytes = int8x16x2_t(
                vreinterpretq_s8_u8(vandq_u8(q4bits.0, m4b)),
                vreinterpretq_s8_u8(vandq_u8(q4bits.1, m4b)),
            );
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-            );
-            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) as i32 * scales[2 * j] as i32;
+            let p0 = vdotq_s32(q4bytes.0, q8bytes.0);
+            let p1 = vdotq_s32(q4bytes.1, q8bytes.1);
+            sumi1 += vaddvq_s32(vaddq_s32(p0, p1)) * scales[2 * j] as i32;

            let q8bytes = vld1q_s8_x2(q8);
            q8 = q8.add(32);
@@ -440,15 +384,9 @@ pub(crate) fn vec_dot_q4k_q8k(n: usize, xs: &[BlockQ4K], ys: &[BlockQ8K]) -> Res
                vreinterpretq_s8_u8(vshrq_n_u8(q4bits.0, 4)),
                vreinterpretq_s8_u8(vshrq_n_u8(q4bits.1, 4)),
            );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.0), vget_low_s8(q8bytes.0)),
-                vmull_s8(vget_high_s8(q4bytes.0), vget_high_s8(q8bytes.0)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q4bytes.1), vget_low_s8(q8bytes.1)),
-                vmull_s8(vget_high_s8(q4bytes.1), vget_high_s8(q8bytes.1)),
-            );
-            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) as i32 * scales[2 * j + 1] as i32;
+            let p2 = vdotq_s32(q4bytes.0, q8bytes.0);
+            let p3 = vdotq_s32(q4bytes.1, q8bytes.1);
+            sumi2 += vaddvq_s32(vaddq_s32(p2, p3)) * scales[2 * j + 1] as i32;
        }
        sumf += d * (sumi1 + sumi2) as f32;
    }
@@ -526,27 +464,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                vreinterpretq_s8_u8(q3h_3),
            );

-            // TODO: dotprod
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_1.0)),
-                vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_1.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_1.1)),
-                vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_1.1)),
-            );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_1.2)),
-                vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_1.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_1.3)),
-                vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_1.3)),
-            );
-            isum += vaddvq_s16(p0) as i32 * *scale as i32
-                + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+            let p0 = vdotq_s32(q3bytes_0, q8bytes_1.0);
+            let p1 = vdotq_s32(q3bytes_1, q8bytes_1.1);
+            let p2 = vdotq_s32(q3bytes_2, q8bytes_1.2);
+            let p3 = vdotq_s32(q3bytes_3, q8bytes_1.3);
+            isum += vaddvq_s32(p0) * *scale as i32
+                + vaddvq_s32(p1) * *scale.add(1) as i32
+                + vaddvq_s32(p2) * *scale.add(2) as i32
+                + vaddvq_s32(p3) * *scale.add(3) as i32;
            scale = scale.add(4);

            let q3h_0 = vbicq_u8(m2, qhbits.0);
@@ -571,27 +496,14 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
                vreinterpretq_s8_u8(q3h_3),
            );

-            // TODO: dotprod
-            let p0 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_0), vget_low_s8(q8bytes_2.0)),
-                vmull_s8(vget_high_s8(q3bytes_0), vget_high_s8(q8bytes_2.0)),
-            );
-            let p1 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_1), vget_low_s8(q8bytes_2.1)),
-                vmull_s8(vget_high_s8(q3bytes_1), vget_high_s8(q8bytes_2.1)),
-            );
-            let p2 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_2), vget_low_s8(q8bytes_2.2)),
-                vmull_s8(vget_high_s8(q3bytes_2), vget_high_s8(q8bytes_2.2)),
-            );
-            let p3 = vaddq_s16(
-                vmull_s8(vget_low_s8(q3bytes_3), vget_low_s8(q8bytes_2.3)),
-                vmull_s8(vget_high_s8(q3bytes_3), vget_high_s8(q8bytes_2.3)),
-            );
-            isum += vaddvq_s16(p0) as i32 * *scale as i32
-                + vaddvq_s16(p1) as i32 * *scale.add(1) as i32
-                + vaddvq_s16(p2) as i32 * *scale.add(2) as i32
-                + vaddvq_s16(p3) as i32 * *scale.add(3) as i32;
+            let p0 = vdotq_s32(q3bytes_0, q8bytes_2.0);
+            let p1 = vdotq_s32(q3bytes_1, q8bytes_2.1);
+            let p2 = vdotq_s32(q3bytes_2, q8bytes_2.2);
+            let p3 = vdotq_s32(q3bytes_3, q8bytes_2.3);
+            isum += vaddvq_s32(p0) * *scale as i32
+                + vaddvq_s32(p1) * *scale.add(1) as i32
+                + vaddvq_s32(p2) * *scale.add(2) as i32
+                + vaddvq_s32(p3) * *scale.add(3) as i32;
            scale = scale.add(4);

            if j == 0 {
@@ -649,7 +561,6 @@ pub(crate) fn vec_dot_q2k_q8k(n: usize, xs: &[BlockQ2K], ys: &[BlockQ8K]) -> Res
        let mut is = 0usize;
-
        // TODO: dotprod
        for _j in 0..QK_K / 128 {
            let q2bits = vld1q_u8_x2(q2);
            q2 = q2.add(32);
@@ -696,14 +607,7 @@ unsafe fn multiply_accum_with_scale(
    q2bytes: int8x16x2_t,
    q8bytes: int8x16x2_t,
) -> i32 {
-    let p1 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.0), vget_low_s8(q8bytes.0)),
-        vmull_s8(vget_high_s8(q2bytes.0), vget_high_s8(q8bytes.0)),
-    );
-    let p2 = vaddq_s16(
-        vmull_s8(vget_low_s8(q2bytes.1), vget_low_s8(q8bytes.1)),
-        vmull_s8(vget_high_s8(q2bytes.1), vget_high_s8(q8bytes.1)),
-    );
-    vaddvq_s16(p1) as i32 * aux[is + index] as i32
-        + vaddvq_s16(p2) as i32 * aux[is + 1 + index] as i32
+    let p1 = vdotq_s32(q2bytes.0, q8bytes.0);
+    let p2 = vdotq_s32(q2bytes.1, q8bytes.1);
+    vaddvq_s32(p1) * aux[is + index] as i32 + vaddvq_s32(p2) * aux[is + 1 + index] as i32
}
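Both sides of the hunks above compute the same quantity: a signed 8-bit dot product accumulated into 32-bit lanes. The removed sequences widen with vmull_s8 and fold pairs via vpaddlq_s16 or vaddq_s16; the replacement funnels everything through the shared vdotq_s32 helper, which also eliminates the i16 accumulators. A scalar Rust model of the equivalence (a sketch only, not the NEON intrinsics; the function names are illustrative):

// Scalar stand-in for the removed vmull_s8 + vpaddlq_s16 chain: products are
// formed in i16 (safe, since |i8 * i8| <= 16129) and then folded into i32.
fn dot_via_widening(a: &[i8; 16], b: &[i8; 16]) -> i32 {
    (0..16).map(|i| (a[i] as i16 * b[i] as i16) as i32).sum()
}

// Scalar stand-in for vdotq_s32: multiply-accumulate directly in i32.
fn dot_via_dotprod(a: &[i8; 16], b: &[i8; 16]) -> i32 {
    a.iter().zip(b.iter()).map(|(&x, &y)| x as i32 * y as i32).sum()
}

fn main() {
    let a = [7i8; 16];
    let b = [-5i8; 16];
    assert_eq!(dot_via_widening(&a, &b), dot_via_dotprod(&a, &b)); // both -560
}

The change is a bugfix as much as an optimization: the old q4k/q6k/q2k paths summed those i16 products with vaddq_s16 before widening, which can wrap for adversarial inputs, while i32 accumulation cannot for these block sizes.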


@@ -426,9 +426,7 @@ impl Tensor {
        if buffer_size != shape.elem_count() {
            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
        }
-        // println!("from vec {buffer_size}");
        let storage = device.storage_owned(data)?;
-        // println!("Created storage");
        let none = BackpropOp::none();
        Ok(from_storage(storage, shape, none, is_variable))
    }
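Aside from dropping the debug printlns, the retained check is the behavior worth remembering here: Tensor::from_vec fails with ShapeMismatch whenever the buffer length disagrees with the element count of the requested shape. A minimal usage sketch against the public candle_core API:

use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Six elements fit a (2, 3) shape...
    let ok = Tensor::from_vec(vec![0f32; 6], (2, 3), &Device::Cpu)?;
    assert_eq!(ok.dims(), &[2, 3]);
    // ...but not a (2, 2) one: elem_count() is 4, so this is a ShapeMismatch error.
    assert!(Tensor::from_vec(vec![0f32; 6], (2, 2), &Device::Cpu).is_err());
    Ok(())
}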


@@ -1,6 +1,6 @@
 use candle_core::{
+    bail,
     quantized::{self, GgmlDType},
-    test_device,
     test_utils::to_vec2_round,
     Device, Module, Result, Tensor,
 };
@@ -14,44 +14,16 @@ const GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS: f32 = 0.0075;
 const GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS: f32 = 0.0040;
 const GGML_MAX_DOT_PRODUCT_ERROR: f32 = 0.02;

-fn test_matmul(
-    device: &Device,
-    (b, m, n, k): (usize, usize, usize, usize),
-    dtype: GgmlDType,
-) -> Result<()> {
-    let lhs = (0..(m * k))
-        .map(|v| v as f32 / (m * k) as f32)
-        .collect::<Vec<_>>();
-    let rhs = (0..(k * n))
-        .map(|v| v as f32 / (n * k) as f32)
-        .collect::<Vec<_>>();
-    let lhs = Tensor::from_slice(&lhs, (m, k), device)?;
-    let rhs = Tensor::from_slice(&rhs, (k, n), device)?;
-    let mm = lhs.matmul(&rhs)?;
-    let qtensor = quantized::QTensor::quantize(&rhs.t()?, dtype)?;
-    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
-    let res = matmul.forward(&lhs)?;
-    let error: f32 = ((&mm - &res)?.abs()? / &mm.abs()?)?
-        .sum_all()?
-        .to_scalar()?;
-    let error = error / (b * m * n) as f32;
-    assert!(
-        error <= 0.02,
-        "Error {error} is too big. \nExpected:\n {mm} \nFound:\n {res}\n for {dtype:?}"
-    );
-    Ok(())
-}
-
-fn quantized_matmul(device: &Device) -> Result<()> {
+#[test]
+fn quantized_matmul() -> Result<()> {
+    let cpu = &Device::Cpu;
     let (m, k, n) = (3, 64, 4);
     let lhs = (0..(m * k)).map(|v| v as f32).collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
     let mut dst = vec![42.; 3 * 4];
     let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
     let rhs = (0..(k * n)).map(|v| v as f32).collect::<Vec<_>>();
+    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
     k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
     k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
     assert_eq!(
@@ -61,7 +33,6 @@ fn quantized_matmul(device: &Device) -> Result<()> {
            341876.0, 994283.0, 1655709.0, 2301518.0
        ]
    );
-    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
    let mm = tensor_lhs.matmul(&tensor_rhs)?;
    assert_eq!(
        mm.to_vec2::<f32>()?,
@@ -72,45 +43,35 @@ fn quantized_matmul(device: &Device) -> Result<()> {
        ]
    );

-    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let res = matmul.forward(&tensor_lhs)?;
-    match device {
-        Device::Metal(_) => assert_eq!(
-            to_vec2_round(&res, 0)?,
-            &[
-                [84946.0, 214126.0, 344757.0, 473798.0],
-                [213458.0, 604350.0, 1000469.0, 1387990.0],
-                [341970.0, 994574.0, 1656181.0, 2302182.0]
-            ]
-        ),
-        _ => assert_eq!(
-            to_vec2_round(&res, 0)?,
-            &[
-                [85120.0, 214562.0, 345455.0, 474748.0],
-                [213475.0, 604465.0, 1000686.0, 1388317.0],
-                [341876.0, 994283.0, 1655709.0, 2301518.0]
-            ]
-        ),
-    }
-    test_matmul(device, (1, 3, 4, 256), GgmlDType::Q4_0)?;
+    assert_eq!(
+        to_vec2_round(&res, 0)?,
+        &[
+            [85120.0, 214562.0, 345455.0, 474748.0],
+            [213475.0, 604465.0, 1000686.0, 1388317.0],
+            [341876.0, 994283.0, 1655709.0, 2301518.0]
+        ]
+    );

    Ok(())
}

-fn quantized_matmul_neg(device: &Device) -> Result<()> {
+#[test]
+fn quantized_matmul_neg() -> Result<()> {
+    let cpu = &Device::Cpu;
    let (m, k, n) = (3, 64, 4);
    let lhs = (0..(m * k))
        .map(|v| v as f32 - (m * k) as f32 / 2.0)
        .collect::<Vec<_>>();
-    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), device)?;
+    let tensor_lhs = Tensor::from_slice(&lhs, (m, k), cpu)?;
    let mut dst = vec![42.; 3 * 4];
    let mut rhs_t = vec![k_quants::BlockQ4_0::zeros(); 8];
    let rhs = (0..k * n)
        .map(|v| v as f32 - (k * n) as f32 / 3.0)
        .collect::<Vec<_>>();
-    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), device)?.t()?;
+    let tensor_rhs = Tensor::from_slice(&rhs, (n, k), cpu)?.t()?;
    k_quants::BlockQ4_0::from_float(&rhs, &mut rhs_t)?;
    k_quants::matmul((m, k, n), &lhs, &rhs_t, &mut dst)?;
    assert_eq!(
@@ -130,52 +91,32 @@ fn quantized_matmul_neg(device: &Device) -> Result<()> {
        ]
    );

-    let qtensor = quantized::QTensor::quantize(&tensor_rhs.t()?, GgmlDType::Q4_0)?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
    let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
    let res = matmul.forward(&tensor_lhs)?;
-    match device {
-        Device::Metal(_) => assert_eq!(
-            to_vec2_round(&res, 0)?,
-            &[
-                [243666.0, -19714.0, -285433.0, -550453.0],
-                [23782.0, 21654.0, 19400.0, 18369.0],
-                [-196102.0, 63022.0, 324233.0, 587191.0]
-            ]
-        ),
-        _ => assert_eq!(
-            to_vec2_round(&res, 0)?,
-            &[
-                [243524.0, -19596.0, -285051.0, -549815.0],
-                [23777.0, 21651.0, 19398.0, 18367.0],
-                [-196472.0, 63012.0, 324585.0, 587902.0]
-            ]
-        ),
-    }
+    assert_eq!(
+        to_vec2_round(&res, 0)?,
+        &[
+            [243524.0, -19596.0, -285051.0, -549815.0],
+            [23777.0, 21651.0, 19398.0, 18367.0],
+            [-196472.0, 63012.0, 324585.0, 587902.0]
+        ]
+    );

    Ok(())
}

-test_device!(
-    quantized_matmul,
-    quantized_matmul_cpu,
-    quantized_matmul_cuda,
-    quantized_matmul_metal
-);
-test_device!(
-    quantized_matmul_neg,
-    quantized_matmul_neg_cpu,
-    quantized_matmul_neg_cuda,
-    quantized_matmul_neg_metal
-);
-
-fn quantize_q4_0(device: &Device) -> Result<()> {
+#[test]
+fn quantize_q4_0() -> Result<()> {
+    use k_quants::BlockQ4_0;
+
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
-    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_0)?;
-    let dst = quant.dequantize(device)?;
+    let mut dst = vec![0f32; 32 * 4];
+    let mut quant = vec![BlockQ4_0::zeros(); 4];
+    BlockQ4_0::from_float(&src, &mut quant)?;
+    BlockQ4_0::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
-        dst.to_vec1::<f32>()?,
+        dst,
        &[
            -0.0, -0.0, 3.875, 3.875, 3.875, 3.875, 7.75, 7.75, 7.75, 7.75, 11.625, 11.625, 11.625,
            11.625, 15.5, 15.5, 15.5, 15.5, 19.375, 19.375, 19.375, 19.375, 23.25, 23.25, 23.25,
@@ -191,17 +132,21 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
            127.0, 127.0
        ]
    );
-    ggml_quantization_error_test(GgmlDType::Q4_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ4_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
-fn quantize_q4_1(device: &Device) -> Result<()> {
+#[test]
+fn quantize_q4_1() -> Result<()> {
+    use k_quants::BlockQ4_1;
+
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
-    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
-    let dst = quant.dequantize(device)?;
+    let mut dst = vec![0f32; 32 * 4];
+    let mut quant = vec![BlockQ4_1::zeros(); 4];
+    BlockQ4_1::from_float(&src, &mut quant)?;
+    BlockQ4_1::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
-        round_vector(&dst.to_vec1::<f32>()?),
+        round_vector(&dst),
        &[
            0.0, 0.0, 2.066, 2.066, 4.133, 4.133, 6.199, 6.199, 8.266, 8.266, 10.332, 10.332,
            12.398, 12.398, 14.465, 14.465, 16.531, 16.531, 18.598, 18.598, 20.664, 20.664, 22.73,
@@ -217,17 +162,21 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
            118.73, 118.73, 120.797, 120.797, 122.863, 122.863, 124.93, 124.93, 126.996, 126.996
        ]
    );
-    ggml_quantization_error_test(GgmlDType::Q4_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ4_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}

-fn quantize_q5_0(device: &Device) -> Result<()> {
+#[test]
+fn quantize_q5_0() -> Result<()> {
+    use k_quants::BlockQ5_0;
+
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
-    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
-    let dst = quant.dequantize(device)?;
+    let mut dst = vec![0f32; 32 * 4];
+    let mut quant = vec![BlockQ5_0::zeros(); 4];
+    BlockQ5_0::from_float(&src, &mut quant)?;
+    BlockQ5_0::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
-        round_vector(&dst.to_vec1::<f32>()?),
+        round_vector(&dst),
        &[
            -0.0, 1.938, 1.938, 3.875, 3.875, 5.813, 5.813, 7.75, 7.75, 9.688, 9.688, 11.625,
            11.625, 13.563, 13.563, 15.5, 15.5, 17.438, 17.438, 19.375, 19.375, 21.313, 21.313,
@@ -243,17 +192,21 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
            119.063, 119.063, 119.063, 119.063, 127.0, 127.0, 127.0, 127.0
        ]
    );
-    ggml_quantization_error_test(GgmlDType::Q5_0, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ5_0>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
-fn quantize_q5_1(device: &Device) -> Result<()> {
+#[test]
+fn quantize_q5_1() -> Result<()> {
+    use k_quants::BlockQ5_1;
+
    let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
-    let src = Tensor::from_slice(&src, (32 * 4,), device)?;
-    let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
-    let dst = quant.dequantize(device)?;
+    let mut dst = vec![0f32; 32 * 4];
+    let mut quant = vec![BlockQ5_1::zeros(); 4];
+    BlockQ5_1::from_float(&src, &mut quant)?;
+    BlockQ5_1::to_float(&quant, dst.as_mut_slice())?;
    assert_eq!(
-        round_vector(&dst.to_vec1::<f32>()?),
+        dst,
        &[
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0,
@@ -267,11 +220,13 @@ fn quantize_q5_1(device: &Device) -> Result<()> {
            124.0, 125.0, 126.0, 127.0
        ]
    );
-    ggml_quantization_error_test(GgmlDType::Q5_1, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ5_1>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}

-fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor> {
+/// Generates a small test vector ranging from -`bound` to `bound` with `size` steps
+fn get_test_vector(bound: f32, size: usize) -> (Vec<f32>, Vec<f32>) {
    assert!(
        size % crate::quantized::k_quants::QK_K == 0,
        "size must be a multiple of {}",
@@ -281,8 +236,10 @@ fn get_test_vector2(bound: f32, size: usize, device: &Device) -> Result<Tensor>
    let src = (0..size)
        .map(|v| (v as f32 - size as f32 / 2.) * bound / (size as f32 / 2.))
        .collect::<Vec<_>>();
+    let dst = vec![0f32; size];
    assert_eq!([src[0], src[size / 2]], [-bound, 0.0]);
-    Tensor::from_vec(src, (size,), device)
+    (src, dst)
}

 /// Round a vector
@@ -309,7 +266,8 @@ fn compare_with_error(values: &[f32], expected: &[f32], tolerance: f32) {
        }
    }

-/// Creates a vector simillarly to the one used in GGML unit tests: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
+/// Creates a vector similar to the ones used in GGML unit tests:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L26-L30
fn create_ggml_like_vector(offset: f32) -> Vec<f32> {
    (0..GGML_TEST_SIZE)
        .map(|i| 0.1 + 2.0 * (i as f32 + offset).cos())
@@ -328,15 +286,15 @@ fn calculate_rmse(a: &[f32], b: &[f32]) -> f32 {
    sum / a.len() as f32
}

-/// Mirrores the GGML quanitzation unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
-fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f32) -> Result<()> {
+/// Similar to the GGML quantization unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L43-L50
+fn ggml_quantization_error_test<T: GgmlType>(max_error: f32) -> Result<()> {
    let src = create_ggml_like_vector(0.0);
-    let src = Tensor::from_slice(&src, (GGML_TEST_SIZE,), device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
-    let error = calculate_rmse(&src.to_vec1::<f32>()?, &dst.to_vec1::<f32>()?);
+    let mut dst = vec![0.0; GGML_TEST_SIZE];
+    let _quant = quantize_roundtrip::<T>(src.as_slice(), dst.as_mut_slice())?;
+    let error = calculate_rmse(src.as_slice(), dst.as_slice());
    if error > max_error {
-        candle_core::bail!(
+        bail!(
            "Quantization error {} exceeds max error {}",
            error,
            max_error
@@ -345,15 +303,19 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
    Ok(())
}
-fn quantize_q2k(device: &Device) -> Result<()> {
-    let dtype = GgmlDType::Q2K;
+fn quantize_roundtrip<T: GgmlType>(src: &[f32], dst: &mut [f32]) -> Result<Vec<T>> {
+    let mut quant = vec![T::zeros(); src.len() / T::BLCK_SIZE];
+    T::from_float(src, &mut quant)?;
+    T::to_float(&quant, dst)?;
+    Ok(quant)
+}

-    let src = get_test_vector2(0.5, 1024, device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
+#[test]
+fn quantize_q2k() -> Result<()> {
+    use k_quants::BlockQ2K;

-    let src = src.to_vec1::<f32>()?;
-    let dst = dst.to_vec1::<f32>()?;
+    let (src, mut dst) = get_test_vector(0.5, 1024);
+    let _quant = quantize_roundtrip::<BlockQ2K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.1);

    // Test some specific values
@@ -367,26 +329,20 @@ fn quantize_q2k(device: &Device) -> Result<()> {
        [-0.499, -0.366, -0.249, 0.0, 0.295, 0.492]
    );

-    let src_big = get_test_vector2(128.0, 1024, device)?;
-    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
-    let dst_big = quant_big.dequantize(device)?;
-
-    let src_big = src_big.to_vec1::<f32>()?;
-    let dst_big = dst_big.to_vec1::<f32>()?;
+    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
+    let _quant_big = quantize_roundtrip::<BlockQ2K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 6.0);

-    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
+    ggml_quantization_error_test::<BlockQ2K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_2BITS)?;
    Ok(())
}
-fn quantize_q3k(device: &Device) -> Result<()> {
-    let dtype = GgmlDType::Q3K;
-    let src = get_test_vector2(0.5, 1024, device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
+#[test]
+fn quantize_q3k() -> Result<()> {
+    use k_quants::BlockQ3K;

-    let src = src.to_vec1::<f32>()?;
-    let dst = dst.to_vec1::<f32>()?;
+    let (src, mut dst) = get_test_vector(0.5, 1024);
+    let _quant = quantize_roundtrip::<BlockQ3K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.03);

    // Test some specific values
@@ -400,26 +356,20 @@ fn quantize_q3k(device: &Device) -> Result<()> {
        [-0.493, -0.37, -0.243, -0.0, 0.292, 0.492]
    );

-    let src_big = get_test_vector2(128.0, 1024, device)?;
-    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
-    let dst_big = quant_big.dequantize(device)?;
-
-    let src_big = src_big.to_vec1::<f32>()?;
-    let dst_big = dst_big.to_vec1::<f32>()?;
+    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
+    let _quant_big = quantize_roundtrip::<BlockQ3K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 3.5);

-    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
+    ggml_quantization_error_test::<BlockQ3K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR_3BITS)?;
    Ok(())
}

-fn quantize_q4k(device: &Device) -> Result<()> {
-    let dtype = GgmlDType::Q4K;
-    let src = get_test_vector2(0.5, 1024, device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
+#[test]
+fn quantize_q4k() -> Result<()> {
+    use k_quants::BlockQ4K;

-    let src = src.to_vec1::<f32>()?;
-    let dst = dst.to_vec1::<f32>()?;
+    let (src, mut dst) = get_test_vector(0.5, 1024);
+    let _quant = quantize_roundtrip::<BlockQ4K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.017);

    // Test some specific values
@@ -433,27 +383,21 @@ fn quantize_q4k(device: &Device) -> Result<()> {
        [-0.5, -0.373, -0.25, 0.0, 0.288, 0.498]
    );

-    let src_big = get_test_vector2(128.0, 1024, device)?;
-    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
-    let dst_big = quant_big.dequantize(device)?;
-
-    let src_big = src_big.to_vec1::<f32>()?;
-    let dst_big = dst_big.to_vec1::<f32>()?;
+    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
+    let _quant_big = quantize_roundtrip::<BlockQ4K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 4.5);

-    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ4K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
-fn quantize_q5k(device: &Device) -> Result<()> {
-    let dtype = GgmlDType::Q5K;
-    let src = get_test_vector2(0.5, 1024, device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
+#[test]
+fn quantize_q5k() -> Result<()> {
+    use k_quants::BlockQ5K;

-    let src = src.to_vec1::<f32>()?;
-    let dst = dst.to_vec1::<f32>()?;
-    compare_with_error(dst.as_slice(), src.as_slice(), 0.009);
+    let (src, mut dst) = get_test_vector(0.5, 1024);
+    let _quant = quantize_roundtrip::<BlockQ5K>(src.as_slice(), dst.as_mut_slice())?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);

    // Test some specific values
    assert_eq!(
@@ -466,26 +410,21 @@ fn quantize_q5k(device: &Device) -> Result<()> {
        [-0.499, -0.372, -0.249, 0.001, 0.279, 0.499]
    );

-    let src_big = get_test_vector2(128.0, 1024, device)?;
-    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
-    let dst_big = quant_big.dequantize(device)?;
-
-    let src_big = src_big.to_vec1::<f32>()?;
-    let dst_big = dst_big.to_vec1::<f32>()?;
+    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
+    let _quant_big = quantize_roundtrip::<BlockQ5K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.5);

-    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ5K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}

-fn quantize_q6k(device: &Device) -> Result<()> {
-    let dtype = GgmlDType::Q6K;
-    let src = get_test_vector2(0.5, 1024, device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
+#[test]
+fn quantize_q6k() -> Result<()> {
+    use k_quants::BlockQ6K;

-    let src = src.to_vec1::<f32>()?;
-    let dst = dst.to_vec1::<f32>()?;
+    let (src, mut dst) = get_test_vector(0.5, 1024);
+    let _quant = quantize_roundtrip::<BlockQ6K>(src.as_slice(), dst.as_mut_slice())?;
    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);

    // Test some specific values
@@ -499,27 +438,22 @@ fn quantize_q6k(device: &Device) -> Result<()> {
        [-0.497, -0.372, -0.25, -0.0, 0.284, 0.5]
    );

-    let src_big = get_test_vector2(128.0, 1024, device)?;
-    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
-    let dst_big = quant_big.dequantize(device)?;
-
-    let src_big = src_big.to_vec1::<f32>()?;
-    let dst_big = dst_big.to_vec1::<f32>()?;
+    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
+    let _quant_big = quantize_roundtrip::<BlockQ6K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 2.0);

-    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ6K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
-fn quantize_q8k(device: &Device) -> Result<()> {
-    let dtype = GgmlDType::Q8K;
-    let src = get_test_vector2(0.5, 1024, device)?;
-    let quant = quantized::QTensor::quantize(&src, dtype)?;
-    let dst = quant.dequantize(device)?;
+#[test]
+fn quantize_q8k() -> Result<()> {
+    use k_quants::BlockQ8K;

-    let src = src.to_vec1::<f32>()?;
-    let dst = dst.to_vec1::<f32>()?;
-    compare_with_error(dst.as_slice(), src.as_slice(), 0.008);
+    let (src, mut dst) = get_test_vector(0.5, 1024);
+    let _quant = quantize_roundtrip::<BlockQ8K>(src.as_slice(), dst.as_mut_slice())?;
+    compare_with_error(dst.as_slice(), src.as_slice(), 0.003);

    // Test some specific values
    assert_eq!(
@@ -532,79 +466,15 @@ fn quantize_q8k(device: &Device) -> Result<()> {
        [-0.5, -0.375, -0.25, -0.0, 0.281, 0.499]
    );

-    let src_big = get_test_vector2(128.0, 1024, device)?;
-    let quant_big = quantized::QTensor::quantize(&src_big, dtype)?;
-    let dst_big = quant_big.dequantize(device)?;
-
-    let src_big = src_big.to_vec1::<f32>()?;
-    let dst_big = dst_big.to_vec1::<f32>()?;
+    let (src_big, mut dst_big) = get_test_vector(128.0, 1024);
+    let _quant_big = quantize_roundtrip::<BlockQ8K>(src_big.as_slice(), dst_big.as_mut_slice())?;
    compare_with_error(dst_big.as_slice(), src_big.as_slice(), 0.6);

-    ggml_quantization_error_test(dtype, device, GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
+    ggml_quantization_error_test::<BlockQ8K>(GGML_MAX_QUANTIZATION_TOTAL_ERROR)?;
    Ok(())
}
-test_device!(
-    quantize_q4_0,
-    quantize_q4_0_cpu,
-    quantize_q4_0_cuda,
-    quantize_q4_0_metal
-);
-test_device!(
-    quantize_q4_1,
-    quantize_q4_1_cpu,
-    quantize_q4_1_cuda,
-    quantize_q4_1_metal
-);
-test_device!(
-    quantize_q5_0,
-    quantize_q5_0_cpu,
-    quantize_q5_0_cuda,
-    quantize_q5_0_metal
-);
-test_device!(
-    quantize_q5_1,
-    quantize_q5_1_cpu,
-    quantize_q5_1_cuda,
-    quantize_q5_1_metal
-);
-test_device!(
-    quantize_q2k,
-    quantize_q2k_cpu,
-    quantize_q2k_cuda,
-    quantize_q2k_metal
-);
-test_device!(
-    quantize_q3k,
-    quantize_q3k_cpu,
-    quantize_q3k_cuda,
-    quantize_q3k_metal
-);
-test_device!(
-    quantize_q4k,
-    quantize_q4k_cpu,
-    quantize_q4k_cuda,
-    quantize_q4k_metal
-);
-test_device!(
-    quantize_q5k,
-    quantize_q5k_cpu,
-    quantize_q5k_cuda,
-    quantize_q5k_metal
-);
-test_device!(
-    quantize_q6k,
-    quantize_q6k_cpu,
-    quantize_q6k_cuda,
-    quantize_q6k_metal
-);
-test_device!(
-    quantize_q8k,
-    quantize_q8k_cpu,
-    quantize_q8k_cuda,
-    quantize_q8k_metal
-);
 /// Very simple dot product implementation
 fn vec_dot_reference(a: &[f32], b: &[f32]) -> f32 {
     a.iter().zip(b).map(|(a, b)| a * b).sum()
@@ -620,54 +490,66 @@ fn ggml_reference_matmul_error(dtype: GgmlDType) -> Result<f32> {
        GgmlDType::Q5K => 0.000740,
        GgmlDType::Q6K => 0.000952,
        GgmlDType::Q4_0 => 0.001143,
-        GgmlDType::Q4_1 => 0.007784,
+        GgmlDType::Q4_1 => 0.008,
        GgmlDType::Q5_0 => 0.001353,
-        GgmlDType::Q5_1 => 0.001363,
+        GgmlDType::Q5_1 => 0.00149,
        GgmlDType::Q8_0 => 0.000092,
        // Not from the ggml repo.
        GgmlDType::Q8K => 0.00065,
-        _ => candle_core::bail!("No GGML results for quantization type {dtype:?}",),
+        _ => bail!("No GGML results for quantization type {dtype:?}",),
    };
    Ok(err)
}
-/// Mirrores the GGML matmul unit test: https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
+/// Similar to the GGML matmul unit test:
+/// https://github.com/ggerganov/llama.cpp/blob/master/tests/test-quantize-fns.cpp#L76-L91
fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    let a = create_ggml_like_vector(0.0);
    let b = create_ggml_like_vector(1.0);
+    ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 1.0)?;
+    // Another example that is more likely to trigger the overflow reported in #1526
+    let a = (0..GGML_TEST_SIZE)
+        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+        .collect::<Vec<_>>();
+    let b = (0..GGML_TEST_SIZE)
+        .map(|i| i as f32 / GGML_TEST_SIZE as f32)
+        .collect::<Vec<_>>();
+    ggml_matmul_error_test_::<T>(a.as_slice(), b.as_slice(), 2.0)?;
+    Ok(())
+}
+
+fn ggml_matmul_error_test_<T: GgmlType>(a: &[f32], b: &[f32], err_m: f32) -> Result<()> {
    let length = a.len();
    let mut a_quant = vec![T::zeros(); length / T::BLCK_SIZE];
    let mut b_quant = vec![T::VecDotType::zeros(); length / T::VecDotType::BLCK_SIZE];
-    T::from_float(&a, &mut a_quant)?;
-    T::VecDotType::from_float(&b, &mut b_quant)?;
+    T::from_float(a, &mut a_quant)?;
+    T::VecDotType::from_float(b, &mut b_quant)?;

    let result = T::vec_dot(length, &a_quant, &b_quant)?;
    let result_unopt = T::vec_dot_unopt(length, &a_quant, &b_quant)?;
-    let reference_result = vec_dot_reference(&a, &b);
+    let reference_result = vec_dot_reference(a, b);

    if (result - result_unopt).abs() / length as f32 > 1e-6 {
-        candle_core::bail!(
+        bail!(
            "the opt and unopt vec-dot returned different values, opt {result}, unopt {result_unopt}"
        )
    }

    let error = (result - reference_result).abs() / length as f32;

-    let ggml_error = ggml_reference_matmul_error(T::DTYPE)?;
+    let ggml_error = ggml_reference_matmul_error(T::DTYPE)? * err_m;

    if !error.is_finite() || error > GGML_MAX_DOT_PRODUCT_ERROR {
-        candle_core::bail!(
-            "Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",
-        );
+        bail!("Dot product error {error} exceeds max error {GGML_MAX_DOT_PRODUCT_ERROR}",);
    }

    // We diverge slightly due to different rounding behavior / f16 to f32 conversions in GGML
    // => we use a slightly higher error threshold
    const ERROR_LENIENCY: f32 = 0.00001;
    if error - ERROR_LENIENCY > ggml_error {
-        candle_core::bail!(
+        bail!(
            "Dot product error {} exceeds ggml reference error {}",
            error,
            ggml_error
@@ -676,6 +558,16 @@ fn ggml_matmul_error_test<T: GgmlType>() -> Result<()> {
    Ok(())
}

+#[test]
+fn quantized_mm() -> Result<()> {
+    ggml_matmul_error_test::<k_quants::BlockQ4_0>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ4_1>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ5_0>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ5_1>()?;
+    ggml_matmul_error_test::<k_quants::BlockQ8_0>()?;
+    Ok(())
+}
+
 /// generates random tensors of size `m x k` and `n x k` and calculates their expected matrix multiplication result.
 fn get_random_tensors(
     m: usize,
@@ -699,108 +591,6 @@ fn get_random_tensors(
    Ok((lhs, rhs, mm))
}
-#[macro_export]
-macro_rules! quantized_matmul {
-    // TODO: Switch to generating the two last arguments automatically once concat_idents is
-    // stable. https://github.com/rust-lang/rust/issues/29599
-    ($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
-        fn $fn_name(device: &Device) -> Result<()> {
-            test_matmul(device, (1, 3, 4, 256), $dtype)?;
-            Ok(())
-        }
-        test_device!($fn_name, $fn_name_cpu, $fn_name_cuda, $fn_name_metal);
-    };
-}
-
-quantized_matmul!(
-    quantized_matmul_q4_0_bis,
-    quantized_matmul_q4_0_cpu,
-    quantized_matmul_q4_0_cuda,
-    quantized_matmul_q4_0_metal,
-    GgmlDType::Q4_0
-);
-quantized_matmul!(
-    quantized_matmul_q4_1_bis,
-    quantized_matmul_q4_1_cpu,
-    quantized_matmul_q4_1_cuda,
-    quantized_matmul_q4_1_metal,
-    GgmlDType::Q4_1
-);
-quantized_matmul!(
-    quantized_matmul_q5_0_bis,
-    quantized_matmul_q5_0_cpu,
-    quantized_matmul_q5_0_cuda,
-    quantized_matmul_q5_0_metal,
-    GgmlDType::Q5_0
-);
-quantized_matmul!(
-    quantized_matmul_q5_1_bis,
-    quantized_matmul_q5_1_cpu,
-    quantized_matmul_q5_1_cuda,
-    quantized_matmul_q5_1_metal,
-    GgmlDType::Q5_1
-);
-quantized_matmul!(
-    quantized_matmul_q8_0_bis,
-    quantized_matmul_q8_0_cpu,
-    quantized_matmul_q8_0_cuda,
-    quantized_matmul_q8_0_metal,
-    GgmlDType::Q8_0
-);
-// Not implemented in Ggml
-// quantized_matmul!(
-//     quantized_matmul_q8_1_bis,
-//     quantized_matmul_q8_1_cpu,
-//     quantized_matmul_q8_1_cuda,
-//     quantized_matmul_q8_1_metal,
-//     GgmlDType::Q8_1
-// );
-// TODO This is bugged (also bugged in GGML
-quantized_matmul!(
-    quantized_matmul_q2k_bis,
-    quantized_matmul_q2k_cpu,
-    quantized_matmul_q2k_cuda,
-    quantized_matmul_q2k_metal,
-    GgmlDType::Q2K
-);
-quantized_matmul!(
-    quantized_matmul_q3k_bis,
-    quantized_matmul_q3k_cpu,
-    quantized_matmul_q3k_cuda,
-    quantized_matmul_q3k_metal,
-    GgmlDType::Q3K
-);
-quantized_matmul!(
-    quantized_matmul_q4k_bis,
-    quantized_matmul_q4k_cpu,
-    quantized_matmul_q4k_cuda,
-    quantized_matmul_q4k_metal,
-    GgmlDType::Q4K
-);
-quantized_matmul!(
-    quantized_matmul_q5k_bis,
-    quantized_matmul_q5k_cpu,
-    quantized_matmul_q5k_cuda,
-    quantized_matmul_q5k_metal,
-    GgmlDType::Q5K
-);
-quantized_matmul!(
-    quantized_matmul_q6k_bis,
-    quantized_matmul_q6k_cpu,
-    quantized_matmul_q6k_cuda,
-    quantized_matmul_q6k_metal,
-    GgmlDType::Q6K
-);
-// Not implemented on metal
-// quantized_matmul!(
-//     quantized_matmul_q8k_bis,
-//     quantized_matmul_q8k_cpu,
-//     quantized_matmul_q8k_cuda,
-//     quantized_matmul_q8k_metal,
-//     GgmlDType::Q8K
-// );
 #[test]
 fn quantized_matmul_q2k() -> Result<()> {
     use k_quants::BlockQ2K;
@@ -813,7 +603,7 @@ fn quantized_matmul_q2k() -> Result<()> {
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q2K)?;
+    let rhs = quantized::QTensor::quantize::<BlockQ2K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
@@ -839,7 +629,7 @@ fn quantized_matmul_q3k() -> Result<()> {
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q3K)?;
+    let rhs = quantized::QTensor::quantize::<BlockQ3K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
@@ -865,7 +655,7 @@ fn quantized_matmul_q4k() -> Result<()> {
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q4K)?;
+    let rhs = quantized::QTensor::quantize::<BlockQ4K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
@@ -891,7 +681,7 @@ fn quantized_matmul_q5k() -> Result<()> {
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q5K)?;
+    let rhs = quantized::QTensor::quantize::<BlockQ5K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
@@ -918,7 +708,7 @@ fn quantized_matmul_q6k() -> Result<()> {
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q6K)?;
+    let rhs = quantized::QTensor::quantize::<BlockQ6K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
@@ -943,7 +733,7 @@ fn quantized_matmul_q8k() -> Result<()> {
    let dst = round_vector(&[dst[0], dst[m * n / 3], dst[m * n * 2 / 3], dst[m * n - 1]]);
    assert_eq!(dst, [1.262, 1.513, -0.208, 1.702]);

-    let rhs = quantized::QTensor::quantize(&rhs, GgmlDType::Q8K)?;
+    let rhs = quantized::QTensor::quantize::<BlockQ8K>(&rhs)?;
    let rhs = quantized::QMatMul::from_qtensor(rhs)?;
    let mm = rhs.forward(&lhs)?;
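The ramp inputs added to ggml_matmul_error_test above ("more likely to trigger the overflow reported in #1526") pair with the neon.rs changes earlier in this compare: the pre-fix kernels summed i8 x i8 products in i16 lanes. A scalar illustration of that failure mode; this models the lane arithmetic only, not the exact NEON register layout:

fn main() {
    // Sixteen worst-case products of 127 * 127 = 16129; the true sum is 258064.
    let products: Vec<i16> = (0..16).map(|_| 127i16 * 127i16).collect();
    // Accumulating in i16 wraps, since i16::MAX is 32767...
    let wrapped = products.iter().fold(0i16, |acc, &p| acc.wrapping_add(p));
    // ...while widening each term to i32 first (what vdotq_s32 amounts to) is exact.
    let exact: i32 = products.iter().map(|&p| p as i32).sum();
    assert_eq!(exact, 258_064);
    assert_ne!(wrapped as i32, exact); // wrapped is -4080 here
}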


@@ -11,8 +11,8 @@ readme = "README.md"
 [dependencies]
 byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
 hf-hub = { workspace = true}
 intel-mkl-src = { workspace = true, optional = true }
 memmap2 = { workspace = true }


@@ -11,12 +11,12 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
-candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-datasets = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
+candle-onnx = { workspace = true, optional = true }
 csv = "1.3.0"
 cudarc = { workspace = true, optional = true }
@@ -49,11 +49,12 @@ tokio = "1.29.1"
 [build-dependencies]
 anyhow = { workspace = true }
+bindgen_cuda = { version = "0.1.1", optional = true }

 [features]
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
-cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
 cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
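The manifest hunks above are the mechanical half of the workspace-dependency consolidation: each member crate stops repeating a path and version and defers to the root manifest's [workspace.dependencies] table. The shape of that contract, as a minimal two-manifest TOML sketch (the crate name here is hypothetical, not one of candle's):

# Root Cargo.toml: the workspace defines each dependency exactly once...
[workspace.dependencies]
my-core = { path = "./my-core", version = "0.1.0" }

# Member crate Cargo.toml: ...and members opt in with `workspace = true`,
# keeping only per-crate flags such as `optional` locally.
[dependencies]
my-core = { workspace = true }

One practical consequence: bumping a version now touches a single line in the root manifest instead of every member crate.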


@@ -4,251 +4,34 @@ use std::io::Write;
 use std::path::PathBuf;

 struct KernelDirectories {
-    kernel_dir: &'static str,
+    kernel_glob: &'static str,
     rust_target: &'static str,
     include_dirs: &'static [&'static str],
 }

-const DIRS: [KernelDirectories; 1] = [KernelDirectories {
-    kernel_dir: "examples/custom-ops/kernels/",
+const KERNEL_DIRS: [KernelDirectories; 1] = [KernelDirectories {
+    kernel_glob: "examples/custom-ops/kernels/*.cu",
     rust_target: "examples/custom-ops/cuda_kernels.rs",
     include_dirs: &[],
 }];

-impl KernelDirectories {
-    fn maybe_build_ptx(
-        &self,
-        cu_file: &std::path::Path,
-        ptx_file: &std::path::Path,
-        compute_cap: usize,
-    ) -> Result<()> {
-        let should_compile = if ptx_file.exists() {
-            let ptx_modified = ptx_file.metadata()?.modified()?;
-            let cu_modified = cu_file.metadata()?.modified()?;
-            cu_modified.duration_since(ptx_modified).is_ok()
-        } else {
-            true
-        };
-        if should_compile {
-            #[cfg(feature = "cuda")]
-            {
-                let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
-                println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
-                let mut command = std::process::Command::new("nvcc");
-                let out_dir = ptx_file.parent().context("no parent for ptx file")?;
-                let include_dirs: Vec<String> =
-                    self.include_dirs.iter().map(|c| format!("-I{c}")).collect();
-                command
-                    .arg(format!("--gpu-architecture=sm_{compute_cap}"))
-                    .arg("--ptx")
-                    .args(["--default-stream", "per-thread"])
-                    .args(["--output-directory", out_dir.to_str().unwrap()])
-                    .arg(format!("-I/{}", self.kernel_dir))
-                    .args(include_dirs)
-                    .arg(cu_file);
-                if let Ok(ccbin_path) = &ccbin_env {
-                    command
-                        .arg("-allow-unsupported-compiler")
-                        .args(["-ccbin", ccbin_path]);
-                }
-                let output = command
-                    .spawn()
-                    .context("failed spawning nvcc")?
-                    .wait_with_output()?;
-                if !output.status.success() {
-                    anyhow::bail!(
-                        "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
-                        String::from_utf8_lossy(&output.stdout),
-                        String::from_utf8_lossy(&output.stderr)
-                    )
-                }
-            }
-            #[cfg(not(feature = "cuda"))]
-            std::fs::OpenOptions::new()
-                .create(true)
-                .write(true)
-                .open(ptx_file)?;
-        }
-        Ok(())
-    }
-    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
-        println!("cargo:rerun-if-changed={}", self.kernel_dir);
-        let kernel_dir = PathBuf::from(self.kernel_dir);
-        let out_dir = out_dir.join(self.kernel_dir);
-        if !out_dir.exists() {
-            std::fs::create_dir_all(&out_dir)?;
-        }
-        let mut cu_files = vec![];
-        let mut cuh_files = vec![];
-        for file in std::fs::read_dir(kernel_dir)?.flatten() {
-            let file = file.path();
-            match file.extension().and_then(|v| v.to_str()) {
-                Some("cu") => cu_files.push(file),
-                Some("cuh") => cuh_files.push(file),
-                _ => {}
-            }
-        }
-        let mut ptx_paths = vec![];
-        for cu_file in cu_files.iter() {
-            let file_stem = cu_file
-                .file_stem()
-                .with_context(|| format!("no stem {cu_file:?}"))?;
-            let file_stem = file_stem.to_string_lossy().into_owned();
-            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
-            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
-            ptx_paths.push(ptx_file);
-        }
-        let regenerate_rs_file = true;
-        if regenerate_rs_file {
-            let mut file = std::fs::File::create(self.rust_target)?;
-            for ptx_path in ptx_paths {
-                let name = ptx_path
-                    .file_stem()
-                    .context("empty stem")?
-                    .to_string_lossy();
-                file.write_all(b"#[rustfmt::skip]\n")?;
-                let const_definition = format!(
-                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
-                    name.to_uppercase().replace('.', "_"),
-                    self.kernel_dir,
-                );
-                file.write_all(const_definition.as_bytes())?;
-                file.write_all(b"\n")?;
-            }
-        }
-        Ok(())
-    }
-}
-
 fn main() -> Result<()> {
     println!("cargo:rerun-if-changed=build.rs");
-    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
-    let out_dir = PathBuf::from(out_dir);
     #[cfg(feature = "cuda")]
-    set_cuda_include_dir()?;
-    #[cfg(feature = "cuda")]
-    let compute_cap = compute_cap()?;
+    {
+        for kdir in KERNEL_DIRS.iter() {
+            let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
+            println!("cargo:info={builder:?}");
+            let bindings = builder.build_ptx().unwrap();
+            bindings.write(kdir.rust_target).unwrap()
+        }
+    }
     #[cfg(not(feature = "cuda"))]
-    let compute_cap = 0;
-    for d in DIRS {
-        d.process(&out_dir, compute_cap)?
+    {
+        for kdir in KERNEL_DIRS.iter() {
+            let _file = std::fs::File::create(kdir.rust_target)?;
+        }
     }
     Ok(())
 }
-fn set_cuda_include_dir() -> Result<()> {
-    // NOTE: copied from cudarc build.rs.
-    let env_vars = [
-        "CUDA_PATH",
-        "CUDA_ROOT",
-        "CUDA_TOOLKIT_ROOT_DIR",
-        "CUDNN_LIB",
-    ];
-    let env_vars = env_vars
-        .into_iter()
-        .map(std::env::var)
-        .filter_map(Result::ok)
-        .map(Into::<PathBuf>::into);
-
-    let roots = [
-        "/usr",
-        "/usr/local/cuda",
-        "/opt/cuda",
-        "/usr/lib/cuda",
-        "C:/Program Files/NVIDIA GPU Computing Toolkit",
-        "C:/CUDA",
-    ];
-    let roots = roots.into_iter().map(Into::<PathBuf>::into);
-
-    let root = env_vars
-        .chain(roots)
-        .find(|path| path.join("include").join("cuda.h").is_file())
-        .context("cannot find include/cuda.h")?;
-    println!(
-        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
-        root.join("include").display()
-    );
-    Ok(())
-}
-
-#[allow(unused)]
-fn compute_cap() -> Result<usize> {
-    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-
-    // Try to parse compute cap from env
-    let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
-        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
-        compute_cap_str
-            .parse::<usize>()
-            .context("Could not parse code")?
-    } else {
-        // Grab compute cap from nvidia-smi
-        let out = std::process::Command::new("nvidia-smi")
-            .arg("--query-gpu=compute_cap")
-            .arg("--format=csv")
-            .output()
-            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
-        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
-        let mut lines = out.lines();
-        assert_eq!(
-            lines.next().context("missing line in stdout")?,
-            "compute_cap"
-        );
-        let cap = lines
-            .next()
-            .context("missing line in stdout")?
-            .replace('.', "");
-        println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
-        cap.parse::<usize>()
-            .with_context(|| format!("cannot parse as int {cap}"))?
-    };
-
-    // Grab available GPU codes from nvcc and select the highest one
-    let max_nvcc_code = {
-        let out = std::process::Command::new("nvcc")
-            .arg("--list-gpu-code")
-            .output()
-            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
-        let out = std::str::from_utf8(&out.stdout).unwrap();
-        let out = out.lines().collect::<Vec<&str>>();
-        let mut codes = Vec::with_capacity(out.len());
-        for code in out {
-            let code = code.split('_').collect::<Vec<&str>>();
-            if !code.is_empty() && code.contains(&"sm") {
-                if let Ok(num) = code[1].parse::<usize>() {
-                    codes.push(num);
-                }
-            }
-        }
-        codes.sort();
-        if !codes.contains(&compute_cap) {
-            anyhow::bail!(
-                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
-            );
-        }
-        *codes.last().unwrap()
-    };
-
-    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
-    // then choose the highest gpu code in nvcc
-    if compute_cap > max_nvcc_code {
-        println!(
-            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
-        );
-        compute_cap = max_nvcc_code;
-    }
-
-    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
-    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
-        compute_cap = compute_cap_str
-            .parse::<usize>()
-            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
-        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
-    }
-
-    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
-    Ok(compute_cap)
-}
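Everything deleted here (PTX staleness checks, direct nvcc invocation, CUDA include discovery, compute-cap probing via nvidia-smi and nvcc) is what the bindgen_cuda crate now encapsulates behind the three calls kept in main. For a hypothetical crate with a single kernel directory, the whole build script reduces to the same shape; this sketch reuses only the bindgen_cuda calls that appear in this diff, and the paths are illustrative:

// build.rs
fn main() {
    println!("cargo:rerun-if-changed=build.rs");
    // Compile every .cu file matching the glob to PTX, then emit a Rust source
    // file of `pub const NAME: &str` items that include the generated PTX.
    let builder = bindgen_cuda::Builder::default().kernel_paths_glob("kernels/*.cu");
    let bindings = builder.build_ptx().unwrap();
    bindings.write("src/cuda_kernels.rs").unwrap();
}

The generated file is then pulled in with `mod cuda_kernels;`, as the custom-ops example below does, and each constant can be handed to the CUDA kernel loader at runtime.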


@ -106,17 +106,17 @@ pub fn main() -> anyhow::Result<()> {
     let config = blip::Config::image_captioning_large();
 
-    let device = candle_examples::device(args.cpu)?;
     let (image_embeds, device, mut model) = if args.quantized {
         let device = Device::Cpu;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");
-        let vb = quantized_blip::VarBuilder::from_gguf(model_file, &device)?;
+        let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
         let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
         let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
         (image_embeds, device, Model::Q(model))
     } else {
+        let device = candle_examples::device(args.cpu)?;
         let image = load_image(args.image)?.to_device(&device)?;
         println!("loaded image {image:?}");

View File

@ -1,2 +1 @@
-#[rustfmt::skip]
-pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/examples/custom-ops/kernels//layernorm_kernels.ptx"));
+pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));

View File

@ -6,7 +6,8 @@
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
#[allow(unused)] #[rustfmt::skip]
#[cfg(feature = "cuda")]
mod cuda_kernels; mod cuda_kernels;
use clap::Parser; use clap::Parser;

View File

@ -165,14 +165,14 @@ fn main() -> Result<()> {
     let mut index_pos = 0;
     let mut token_generated = 0;
     for index in 0..args.sample_len {
-        let context_size = if cache.use_kv_cache && index > 0 {
-            1
+        let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
+            (1, index_pos)
         } else {
-            tokens.len()
+            (tokens.len(), 0)
         };
         let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
         let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
-        let logits = llama.forward(&input, index_pos)?;
+        let logits = llama.forward(&input, context_index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
             logits
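The fix separates how many tokens are fed to the model from the position they are fed at: with the KV cache enabled only the newest token is passed, at its absolute offset, while without the cache the whole context is re-fed and positions must restart at 0 on every step. A runnable toy sketch of just that indexing (the actual forward call is elided):

// Returns (context_size, context_index) per step, mirroring the loop above.
fn context_schedule(use_kv_cache: bool, sample_len: usize, prompt_len: usize) -> Vec<(usize, usize)> {
    let mut tokens_len = prompt_len;
    let mut index_pos = 0;
    let mut schedule = Vec::new();
    for index in 0..sample_len {
        let (context_size, context_index) = if use_kv_cache && index > 0 {
            (1, index_pos) // cached: one new token at its absolute position
        } else {
            (tokens_len, 0) // no cache: re-feed everything from position 0
        };
        schedule.push((context_size, context_index));
        index_pos += context_size;
        tokens_len += 1; // one sampled token appended per step
    }
    schedule
}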

View File

@ -262,7 +262,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
         .extension()
         .map_or(false, |v| v == "safetensors");
     let (model, config) = if is_gguf {
-        let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
+        let vb = qmodel::VarBuilder::from_gguf(config_path)?;
         let (_vocab_size, dim) = vb
             .get_no_shape("model.embed_tokens.weight")?
             .shape()
@ -279,13 +279,13 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_real",
             )?
-            .dequantize(&device)?;
+            .dequantize(&candle::Device::Cpu)?;
         let freq_cis_imag = vb
             .get(
                 (config.seq_len, config.head_size() / 2),
                 "rot.freq_cis_imag",
             )?
-            .dequantize(&device)?;
+            .dequantize(&candle::Device::Cpu)?;
 
         let fake_vb = candle_nn::VarBuilder::from_tensors(
             [
@ -295,7 +295,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
             .into_iter()
             .collect(),
             candle::DType::F32,
-            &device,
+            &candle::Device::Cpu,
         );
         let cache = model::Cache::new(true, &config, fake_vb)?;
         let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);

View File

@ -244,14 +244,13 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::config_7b_v0_1(args.use_flash_attn);
-    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
         let model = QMistral::new(&config, vb)?;
-        (Model::Quantized(model), device)
+        (Model::Quantized(model), Device::Cpu)
     } else {
+        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {
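The else branch goes on to pick the compute dtype; as a standalone hedged sketch of that choice:

use candle::{DType, Device};

// BF16 only where the CUDA kernels support it, F32 everywhere else.
fn pick_dtype(device: &Device) -> DType {
    if device.is_cuda() {
        DType::BF16
    } else {
        DType::F32
    }
}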

View File

@ -299,26 +299,21 @@ fn main() -> Result<()> {
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
     };
-    let device = candle_examples::device(args.cpu)?;
-    let model = if args.quantized {
-        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
-            &filenames[0],
-            &device,
-        )?;
-        println!("Loaded vb");
+    let (model, device) = if args.quantized {
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
         let model = match args.model {
             WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
             _ => QMixFormer::new(&config, vb)?,
         };
-        println!("Loaded model");
-        Model::Quantized(model)
+        (Model::Quantized(model), Device::Cpu)
     } else {
+        let device = candle_examples::device(args.cpu)?;
         let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
         let model = match args.model {
             WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
             _ => MixFormer::new(&config, vb)?,
         };
-        Model::MixFormer(model)
+        (Model::MixFormer(model), device)
     };
     println!("loaded the model in {:?}", start.elapsed());

View File

@ -132,8 +132,7 @@ impl T5ModelBuilder {
     }
 
     pub fn build_model(&self) -> Result<t5::T5ForConditionalGeneration> {
-        let device = Device::Cpu;
-        let vb = t5::VarBuilder::from_gguf(&self.weights_filename, &device)?;
+        let vb = t5::VarBuilder::from_gguf(&self.weights_filename)?;
         Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
     }

View File

@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use candle::quantized::{ggml_file, gguf_file};
-use candle::Tensor;
+use candle::{Device, Tensor};
 use candle_transformers::generation::LogitsProcessor;
 use candle_examples::token_output_stream::TokenOutputStream;
@ -361,7 +361,6 @@ fn main() -> anyhow::Result<()> {
     let model_path = args.model()?;
     let mut file = std::fs::File::open(&model_path)?;
     let start = std::time::Instant::now();
-    let device = candle_examples::device(false)?;
 
     let mut model = match model_path.extension().and_then(|v| v.to_str()) {
         Some("gguf") => {
@ -370,7 +369,7 @@ fn main() -> anyhow::Result<()> {
             for (_, tensor) in model.tensor_infos.iter() {
                 let elem_count = tensor.shape.elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
+                    elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.blck_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@ -378,16 +377,15 @@ fn main() -> anyhow::Result<()> {
                 &format_size(total_size_in_bytes),
                 start.elapsed().as_secs_f32(),
             );
-            ModelWeights::from_gguf(model, &mut file, &device)?
+            ModelWeights::from_gguf(model, &mut file)?
         }
         Some("ggml" | "bin") | Some(_) | None => {
-            let model = ggml_file::Content::read(&mut file, &device)
-                .map_err(|e| e.with_path(model_path))?;
+            let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
             let mut total_size_in_bytes = 0;
             for (_, tensor) in model.tensors.iter() {
                 let elem_count = tensor.shape().elem_count();
                 total_size_in_bytes +=
-                    elem_count * tensor.dtype().type_size() / tensor.dtype().block_size();
+                    elem_count * tensor.dtype().type_size() / tensor.dtype().blck_size();
             }
             println!(
                 "loaded {:?} tensors ({}) in {:.2}s",
@ -488,7 +486,7 @@ fn main() -> anyhow::Result<()> {
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
+        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
         logits_processor.sample(&logits)?
@ -509,7 +507,7 @@ fn main() -> anyhow::Result<()> {
     let start_post_prompt = std::time::Instant::now();
     let mut sampled = 0;
     for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
+        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
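The blck_size/block_size rename above is cosmetic; the accounting itself divides the element count by the quantization block size. A hedged worked example, assuming the usual Q4_0 layout of 32 weights per 18-byte block (one f16 scale plus 16 bytes of packed 4-bit weights):

// bytes = elem_count * type_size / block_size
fn q4_0_size_in_bytes(elem_count: usize) -> usize {
    let type_size = 18; // bytes per block: 2 (f16 scale) + 16 (32 x 4 bits)
    let block_size = 32; // weights per block
    elem_count * type_size / block_size
}

// A 4096x4096 weight matrix: 16_777_216 elements -> 9_437_184 bytes (~9 MiB),
// versus 64 MiB in f32.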

View File

@ -236,11 +236,9 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let device = Device::Cpu;
     let config = Config::replit_code_v1_5_3b();
     let (model, device) = if args.quantized {
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
         let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
         (model, Device::Cpu)
     } else {

View File

@ -234,14 +234,13 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
-    let device = candle_examples::device(args.cpu)?;
     let (model, device) = if args.quantized {
         let filename = &filenames[0];
-        let vb =
-            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
+        let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename)?;
         let model = QStableLM::new(&config, vb)?;
         (Model::Quantized(model), Device::Cpu)
     } else {
+        let device = candle_examples::device(args.cpu)?;
         let dtype = if device.is_cuda() {
             DType::BF16
         } else {

View File

@ -557,10 +557,8 @@ fn main() -> Result<()> {
println!("loaded mel: {:?}", mel.dims()); println!("loaded mel: {:?}", mel.dims());
let mut model = if args.quantized { let mut model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf( let vb =
&weights_filename, candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&weights_filename)?;
&device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?) Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
} else { } else {
let vb = let vb =

View File

@ -11,14 +11,14 @@ license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
[dependencies] [dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" } candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] } half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies] [build-dependencies]
bindgen_cuda = "0.1.1"
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
num_cpus = "1.15.0"
rayon = "1.7.0"
[dev-dependencies] [dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] } candle-nn = { path = "../candle-nn", features = ["cuda"] }

View File

@ -2,44 +2,32 @@
 // The cuda build time is very long so one can set the CANDLE_FLASH_ATTN_BUILD_DIR environment
 // variable in order to cache the compiled artifacts and avoid recompiling too often.
 use anyhow::{Context, Result};
-use rayon::prelude::*;
 use std::path::PathBuf;
-use std::str::FromStr;
 
 const KERNEL_FILES: [&str; 17] = [
-    "flash_api.cu",
-    "flash_fwd_hdim128_fp16_sm80.cu",
-    "flash_fwd_hdim160_fp16_sm80.cu",
-    "flash_fwd_hdim192_fp16_sm80.cu",
-    "flash_fwd_hdim224_fp16_sm80.cu",
-    "flash_fwd_hdim256_fp16_sm80.cu",
-    "flash_fwd_hdim32_fp16_sm80.cu",
-    "flash_fwd_hdim64_fp16_sm80.cu",
-    "flash_fwd_hdim96_fp16_sm80.cu",
-    "flash_fwd_hdim128_bf16_sm80.cu",
-    "flash_fwd_hdim160_bf16_sm80.cu",
-    "flash_fwd_hdim192_bf16_sm80.cu",
-    "flash_fwd_hdim224_bf16_sm80.cu",
-    "flash_fwd_hdim256_bf16_sm80.cu",
-    "flash_fwd_hdim32_bf16_sm80.cu",
-    "flash_fwd_hdim64_bf16_sm80.cu",
-    "flash_fwd_hdim96_bf16_sm80.cu",
+    "kernels/flash_api.cu",
+    "kernels/flash_fwd_hdim128_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim160_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim192_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim224_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim256_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim32_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim64_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim96_fp16_sm80.cu",
+    "kernels/flash_fwd_hdim128_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim160_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim192_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim224_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim256_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim32_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim64_bf16_sm80.cu",
+    "kernels/flash_fwd_hdim96_bf16_sm80.cu",
 ];
 
 fn main() -> Result<()> {
-    let num_cpus = std::env::var("RAYON_NUM_THREADS").map_or_else(
-        |_| num_cpus::get_physical(),
-        |s| usize::from_str(&s).unwrap(),
-    );
-    rayon::ThreadPoolBuilder::new()
-        .num_threads(num_cpus)
-        .build_global()
-        .unwrap();
     println!("cargo:rerun-if-changed=build.rs");
     for kernel_file in KERNEL_FILES.iter() {
-        println!("cargo:rerun-if-changed=kernels/{kernel_file}");
+        println!("cargo:rerun-if-changed={kernel_file}");
     }
     println!("cargo:rerun-if-changed=kernels/flash_fwd_kernel.h");
     println!("cargo:rerun-if-changed=kernels/flash_fwd_launch_template.h");
@ -66,223 +54,30 @@ fn main() -> Result<()> {
         ))
         }
     };
-    set_cuda_include_dir()?;
-
-    let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
-    println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
-
-    let compute_cap = compute_cap()?;
-
+    let kernels = KERNEL_FILES.iter().collect();
+    let builder = bindgen_cuda::Builder::default()
+        .kernel_paths(kernels)
+        .out_dir(build_dir.clone())
+        .arg("-std=c++17")
+        .arg("-O3")
+        .arg("-U__CUDA_NO_HALF_OPERATORS__")
+        .arg("-U__CUDA_NO_HALF_CONVERSIONS__")
+        .arg("-U__CUDA_NO_HALF2_OPERATORS__")
+        .arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
+        .arg("-Icutlass/include")
+        .arg("--expt-relaxed-constexpr")
+        .arg("--expt-extended-lambda")
+        .arg("--use_fast_math")
+        .arg("--verbose");
     let out_file = build_dir.join("libflashattention.a");
+    builder.build_lib(out_file);
let kernel_dir = PathBuf::from("kernels");
let cu_files: Vec<_> = KERNEL_FILES
.iter()
.map(|f| {
let mut obj_file = out_dir.join(f);
obj_file.set_extension("o");
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
} else {
true
};
if should_compile {
cu_files
.par_iter()
.map(|(cu_file, obj_file)| {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(cu_file);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while executing compiling: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
Ok(())
})
.collect::<Result<()>>()?;
let obj_files = cu_files.iter().map(|c| c.1.clone()).collect::<Vec<_>>();
let mut command = std::process::Command::new("nvcc");
command
.arg("--lib")
.args(["-o", out_file.to_str().unwrap()])
.args(obj_files);
let output = command
.spawn()
.context("failed spawning nvcc")?
.wait_with_output()?;
if !output.status.success() {
anyhow::bail!(
"nvcc error while linking: {:?}\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
&command,
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
)
}
}
println!("cargo:rustc-link-search={}", build_dir.display()); println!("cargo:rustc-link-search={}", build_dir.display());
println!("cargo:rustc-link-lib=flashattention"); println!("cargo:rustc-link-lib=flashattention");
println!("cargo:rustc-link-lib=dylib=cudart"); println!("cargo:rustc-link-lib=dylib=cudart");
println!("cargo:rustc-link-lib=dylib=stdc++"); println!("cargo:rustc-link-lib=dylib=stdc++");
/* laurent: I tried using the cc cuda integration as below but this lead to ptaxs never
finishing to run for some reason. Calling nvcc manually worked fine.
cc::Build::new()
.cuda(true)
.include("cutlass/include")
.flag("--expt-relaxed-constexpr")
.flag("--default-stream")
.flag("per-thread")
.flag(&format!("--gpu-architecture=sm_{compute_cap}"))
.file("kernels/flash_fwd_hdim32_fp16_sm80.cu")
.compile("flashattn");
*/
     Ok(())
 }
fn set_cuda_include_dir() -> Result<()> {
// NOTE: copied from cudarc build.rs.
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
let roots = roots.into_iter().map(Into::<PathBuf>::into);
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.context("cannot find include/cuda.h")?;
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
Ok(())
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse compute cap")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
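Everything deleted above (CUDA include-dir discovery, compute-cap probing via nvidia-smi and nvcc, the parallel nvcc invocations and the final archive step) is now handled inside the bindgen_cuda crate. A hedged sketch of a minimal equivalent build script, restricted to the Builder calls that actually appear in this diff and trimmed down to a single kernel:

const KERNEL_FILES: [&str; 1] = ["kernels/flash_api.cu"];

fn main() -> anyhow::Result<()> {
    let build_dir = std::path::PathBuf::from(std::env::var("OUT_DIR")?);
    let kernels = KERNEL_FILES.iter().collect();
    // bindgen_cuda locates the CUDA toolchain and picks the gpu arch itself.
    let builder = bindgen_cuda::Builder::default()
        .kernel_paths(kernels)
        .out_dir(build_dir.clone())
        .arg("-O3");
    let out_file = build_dir.join("libflashattention.a");
    builder.build_lib(out_file);
    println!("cargo:rustc-link-search={}", build_dir.display());
    println!("cargo:rustc-link-lib=flashattention");
    Ok(())
}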

View File

@ -0,0 +1,62 @@
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_causal, typename Engine, typename Layout>
inline __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int max_seqlen_k,
const int row_idx_offset,
const int max_seqlen_q,
const int warp_row_stride,
const float alibi_slope) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
} // namespace flash
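For reference, the bias this kernel applies can be written out on the host side. A hedged Rust transcription of the two branches above, where row and col are absolute score coordinates:

// Per-head ALiBi bias added to the attention score at (row, col).
fn alibi_bias(row: i32, col: i32, seqlen_q: i32, seqlen_k: i32, slope: f32, causal: bool) -> f32 {
    if causal {
        // Causal: the same bias vector is added to every row.
        slope * col as f32
    } else {
        // Otherwise: a distance penalty, shifted so the last query row lines
        // up with the last key column.
        -slope * ((row + seqlen_k - seqlen_q - col).abs() as f32)
    }
}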

View File

@ -14,9 +14,12 @@ struct BlockInfo {
     template<typename Params>
     __device__ BlockInfo(const Params &params, const int bidb)
         : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
-        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
+        , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
         , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
-        , actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
+        // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+        // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+        , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
+        , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
     {
     }
@ -32,8 +35,10 @@ struct BlockInfo {
     const int sum_s_q;
     const int sum_s_k;
-    const uint32_t actual_seqlen_q;
-    const uint32_t actual_seqlen_k;
+    const int actual_seqlen_q;
+    // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
+    const int seqlen_k_cache;
+    const int actual_seqlen_k;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
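A hedged Rust transcription of the new seqlen-k bookkeeping, spelling out the three cases (the Varlen template switch is ignored here):

fn seqlen_k_cache(
    cu_seqlens_k: Option<&[i32]>, // cumulative offsets, or raw lengths
    is_cumulative: bool,
    seqlen_k: i32, // fixed length used when no table is given
    bidb: usize,   // batch index
) -> i32 {
    match cu_seqlens_k {
        None => seqlen_k,
        Some(cu) if is_cumulative => cu[bidb + 1] - cu[bidb],
        Some(cu) => cu[bidb],
    }
}

// actual_seqlen_k then becomes seqused_k[bidb] when provided, otherwise
// seqlen_k_cache plus seqlen_knew when new K/V blocks are being appended.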

View File

@ -7,15 +7,6 @@
 #include <cuda.h>
 #include <vector>
 
-// #ifdef OLD_GENERATOR_PATH
-// #include <ATen/CUDAGeneratorImpl.h>
-// #else
-// #include <ATen/cuda/CUDAGeneratorImpl.h>
-// #endif
-//
-// #include <ATen/cuda/CUDAGraphsUtils.cuh>
 
 constexpr int TOTAL_DIM = 0;
 constexpr int H_DIM = 1;
 constexpr int D_DIM = 2;
@ -53,6 +44,7 @@ struct Flash_fwd_params : public Qkv_params {
     // The O matrix (output).
     void * __restrict__ o_ptr;
+    void * __restrict__ oaccum_ptr;
 
     // The stride between rows of O.
     index_t o_batch_stride;
@ -64,9 +56,10 @@ struct Flash_fwd_params : public Qkv_params {
     // The pointer to the softmax sum.
     void * __restrict__ softmax_lse_ptr;
+    void * __restrict__ softmax_lseaccum_ptr;
 
     // The dimensions.
-    int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+    int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
 
     // The scaling factors for the kernel.
     float scale_softmax;
@ -76,8 +69,30 @@ struct Flash_fwd_params : public Qkv_params {
     int * __restrict__ cu_seqlens_q;
     int * __restrict__ cu_seqlens_k;
 
+    // If provided, the actual length of each k sequence.
+    int * __restrict__ seqused_k;
+
     int *__restrict__ blockmask;
 
+    // The K_new and V_new matrices.
+    void * __restrict__ knew_ptr;
+    void * __restrict__ vnew_ptr;
+
+    // The stride between rows of the Q, K and V matrices.
+    index_t knew_batch_stride;
+    index_t vnew_batch_stride;
+    index_t knew_row_stride;
+    index_t vnew_row_stride;
+    index_t knew_head_stride;
+    index_t vnew_head_stride;
+
+    // The cos and sin matrices for rotary embedding.
+    void * __restrict__ rotary_cos_ptr;
+    void * __restrict__ rotary_sin_ptr;
+
+    // The indices to index into the KV cache.
+    int *__restrict__ cache_batch_idx;
+
     // The dropout probability (probability of keeping an activation).
     float p_dropout;
     // uint32_t p_dropout_in_uint;
@ -88,11 +103,22 @@ struct Flash_fwd_params : public Qkv_params {
     float rp_dropout;
     float scale_softmax_rp_dropout;
 
-    // Random state.
-    // at::PhiloxCudaState philox_args;
+    // Local window size
+    int window_size_left, window_size_right;
 
     bool is_bf16;
     bool is_causal;
 
+    // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
+    // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
+    bool is_seqlens_k_cumulative;
+
+    bool is_rotary_interleaved;
+
+    int num_splits;  // For split-KV version
+
+    void * __restrict__ alibi_slopes_ptr;
+    index_t alibi_slopes_batch_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -132,10 +158,14 @@ struct Flash_bwd_params : public Flash_fwd_params {
     // The pointer to the softmax d sum.
     void *__restrict__ dsoftmax_sum;
 
+    bool deterministic;
+    index_t dq_accum_split_stride;
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
+template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);
 template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

View File

@ -1,17 +1,15 @@
#include "flash_fwd_launch_template.h" #include "flash_fwd_launch_template.h"
// void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
// FWD_HEADDIM_SWITCH(params.d, [&] { FP16_SWITCH(!params.is_bf16, [&] {
// run_mha_fwd_<cutlass::half_t, kHeadDim>(params, stream); FWD_HEADDIM_SWITCH(params.d, [&] {
// }); // if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
// } run_mha_fwd_<elem_type, kHeadDim>(params, stream);
// } else {
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) { // run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
FP16_SWITCH(!params.is_bf16, [&] { // }
FWD_HEADDIM_SWITCH(params.d, [&] { });
run_mha_fwd_<elem_type, kHeadDim>(params, stream); });
});
});
} }
extern "C" void run_mha( extern "C" void run_mha(
@ -20,6 +18,7 @@ extern "C" void run_mha(
     void *v_ptr,
     void *o_ptr,
     void *softmax_lse_ptr,
+    void *alibi_slopes_ptr,
 
     int32_t *cu_seqlens_q_ptr,
     int32_t *cu_seqlens_k_ptr,
@ -28,6 +27,7 @@ extern "C" void run_mha(
     uint32_t k_batch_stride,
     uint32_t v_batch_stride,
     uint32_t o_batch_stride,
+    uint32_t alibi_slopes_batch_stride,
 
     uint32_t q_row_stride,
     uint32_t k_row_stride,
@ -51,8 +51,11 @@ extern "C" void run_mha(
     uint32_t seqlen_q_rounded,
     uint32_t seqlen_k_rounded,
 
+    int is_bf16,
     int is_causal,
-    int is_bf16
+
+    int window_size_left,
+    int window_size_right
 ) {
     Flash_fwd_params params;
     // Reset the parameters
@ -65,12 +68,14 @@ extern "C" void run_mha(
     params.o_ptr = o_ptr;
 
     params.softmax_lse_ptr = softmax_lse_ptr;
+    params.alibi_slopes_ptr = alibi_slopes_ptr;
 
     // All stride are in elements, not bytes.
     params.q_batch_stride = q_batch_stride;
     params.k_batch_stride = k_batch_stride;
     params.v_batch_stride = v_batch_stride;
     params.o_batch_stride = o_batch_stride;
+    params.alibi_slopes_batch_stride = alibi_slopes_batch_stride;
 
     params.q_row_stride = q_row_stride;
     params.k_row_stride = k_row_stride;
@ -92,7 +97,6 @@ extern "C" void run_mha(
     params.seqlen_k_rounded = seqlen_k_rounded;
     params.d = d;
     params.d_rounded = d_rounded;
-    params.is_causal = is_causal;
 
     // Set the different scale values.
     params.scale_softmax = softmax_scale;
@ -106,6 +110,14 @@ extern "C" void run_mha(
     params.cu_seqlens_q = cu_seqlens_q_ptr;
     params.cu_seqlens_k = cu_seqlens_k_ptr;
     params.p_ptr = nullptr; // used for `return_softmax`.
+    params.seqused_k = nullptr;
+
+    params.is_causal = is_causal;
+    params.window_size_left = window_size_left;
+    params.window_size_right = window_size_right;
+
+    params.is_seqlens_k_cumulative = true;
+    params.num_splits = 1;
 
     cudaStream_t stream = 0; // Use the default stream.
     run_mha_fwd(params, stream);

View File

@ -1,19 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     if (params.p_dropout == 1.f) {
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,32 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     if (params.p_dropout == 1.f) {
-//         // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, false, elem_type>, false>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, false, true, elem_type>, false>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 64, 4, true, true, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 64, 4, false, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 64, 128, 4, false, false, elem_type>, false>(params, stream);
-//         // 1st ones are good for H100, A100
-//         // 2nd one is good for A6000 bc we get slightly better occupancy
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, false, false, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, false, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<128, 128, 32, 4, true, true, elem_type>, true>(params, stream);
-//         // 1st one is good for H100, A100, A6000
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
 }

View File

@ -1,17 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,27 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 32, 4, false, true, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 128, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 64, 64, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 64, 8, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<160, 128, 128, 8, false, elem_type>>(params, stream);
-//         // For A6000, no-causal, 1st is fastest. causal, 4th is fastest.
-//         // For A100, H100, 1st is fastest.
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
 }

View File

@ -1,16 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//     });
-// }
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,27 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 32, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         // This one is slightly faster for causal?
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 8, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 32, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 64, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 64, 128, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<192, 128, 128, 8, false, elem_type>>(params, stream);
-//     });
-//     // For A100 H100, 1st is faster with dropout, 3rd is faster without dropout
-//     // For A6000, 1st is faster when causal, 3rd is faster when not causal
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
 }

View File

@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-template<> void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
 }

View File

@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-template<> void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,9 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-template<> void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
 }

View File

@ -1,10 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,23 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 128, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         // For dropout there might be a lot of register spilling?
-//         // These two are very slow due to register spilling
-//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 128, 4, false, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 128, 256, 4, false, elem_type>>(params, stream);
-//         // This one is slightly slower
-//         // run_flash_fwd<Flash_fwd_kernel_traits<32, 256, 64, 4, false, elem_type>>(params, stream);
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
 }

View File

@ -1,19 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     if (params.p_dropout == 1.f) {
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,26 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     if (params.p_dropout == 1.f) {
-//         // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
-//         // Using block size (64 x 256) is 27% slower for seqlen=2k
-//         // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 128, 4, false, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, false>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, false>(params, stream);
-//     } else {
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, false, false, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, true, elem_type>, true>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<64, 128, 64, 4, true, false, elem_type>, true>(params, stream);
-//     }
-// }
 template<>
 void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
 }

View File

@ -1,17 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::bfloat16_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
-//     });
-// }
 template<>
 void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
 }

View File

@ -1,23 +1,10 @@
 // Copyright (c) 2023, Tri Dao.
 // Splitting the different head dimensions to different files to speed up compilation.
+// This file is auto-generated. See "generate_kernels.py"
 
 #include "flash_fwd_launch_template.h"
 
-// template<>
-// void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
-//     using elem_type = cutlass::half_t;
-//     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, true, true, elem_type>, Is_dropout>(params, stream);
-//         // This 3rd one is good for H100, and A100, A6000
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, false, elem_type>, Is_dropout>(params, stream);
-//         run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 64, 4, false, true, elem_type>, Is_dropout>(params, stream);
-//         // These two are always slower
-//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, elem_type>>(params, stream);
-//         // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, elem_type>>(params, stream);
-//     });
-// }
-template<> void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
+template<>
+void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
     run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
 }

View File

@ -4,20 +4,18 @@
 #pragma once
 
+#include <cmath>
 #include <cute/algorithm/copy.hpp>
-#include <cute/algorithm/gemm.hpp>
 
 #include <cutlass/cutlass.h>
 #include <cutlass/array.h>
 #include <cutlass/numeric_types.h>
-#include <cutlass/numeric_conversion.h>
 
 #include "block_info.h"
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
-#include "philox.cuh"
+
+#include "alibi.h"
 
 namespace flash {
@@ -25,49 +23,6 @@ using namespace cute;

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template <int MMA_M,
-          class... Args,
-          class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>> >{},
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <int MMA_M,
-          class... Args,
-          class TiledMMA>
-CUTE_HOST_DEVICE
-auto
-make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
-                                  TiledMMA           const& tiled_mma) {
-    using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
-    using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
-    constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
-    constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
-    constexpr int MMAStride_M = MMA_M * AtomShape_M;
-    auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
-                              Stride<_1, Int<MMAStride_M>> >{},
-                       // TODO: Shouldn't this be size<1>?
-                       make_layout(size<2>(TileShape_MNK{})));
-    // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
-    return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
 template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
 inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
                                          Tensor2 &acc_o, float softmax_scale_log2) {
@@ -77,7 +32,7 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
         flash::reduce_sum(scores, scores_sum);
     } else {
         Tensor scores_max_prev = make_fragment_like(scores_max);
-        copy(scores_max, scores_max_prev);
+        cute::copy(scores_max, scores_max_prev);
         flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
         // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
         Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
@@ -103,23 +58,22 @@ inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, T
 template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
 inline __device__ void write_softmax_to_gmem(
-    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_thr_copy_P
+    Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
 ) {
     // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
     Layout l = tOrP.layout();
     Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
     CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
-    // TODO(laurent): reactivate the following
-    // CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
+    CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
     #pragma unroll
     for (int mi = 0; mi < size<1>(tPrP); ++mi) {
-        copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+        cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
     }
 };
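Editor's note, not part of the diff: softmax_rescale_o above implements the online-softmax update. A scalar sketch of the per-row arithmetic (illustrative only; the kernel works on register tiles and uses exp2f with softmax_scale_log2 instead of expf):

#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax step: fold a new block of scores into the running
// row max, row sum, and output accumulator.
void softmax_rescale_o_ref(std::vector<float> &scores, float &row_max,
                           float &row_sum, std::vector<float> &acc_o) {
    float prev_max = row_max;
    for (float s : scores) row_max = std::max(row_max, s);
    float correction = std::exp(prev_max - row_max);  // <= 1, rescales history
    row_sum *= correction;
    for (float &o : acc_o) o *= correction;           // fix up the old accumulator
    for (float &s : scores) { s = std::exp(s - row_max); row_sum += s; }
}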
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn_1rowblock(const Params &params, const int bidb, const int bidh, const int m_block) {

     using Element = typename Kernel_traits::Element;
@@ -138,16 +92,65 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     constexpr int kNWarps = Kernel_traits::kNWarps;
     constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;

-    const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
-    if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

+    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
     int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
-    if (Is_causal) {
-        n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
+    if (Is_causal || Is_local) {
+        n_block_max = std::min(n_block_max,
+                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN));
         // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
         //     printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
         // }
     }
+    // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
+    // Otherwise we might read OOB elements from gK and gV.
+    if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
+        // Save seed and offset for backward. If we don't have this here, the 0-th thread block might
+        // exit early and no one saves the rng state.
+        // if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
+        //     auto seeds = at::cuda::philox::unpack(params.philox_args);
+        //     params.rng_state[0] = std::get<0>(seeds);
+        //     params.rng_state[1] = std::get<1>(seeds);
+        //     params.rng_state[0] = 0;
+        //     params.rng_state[1] = 0;
+        // }
+        const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
+            + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+        const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+        Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
+                                Shape<Int<kBlockM>, Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+        Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
+                                  Shape<Int<kBlockM>>{}, Stride<_1>{});
+
+        typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+        auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+        Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+        Tensor tOrO = make_tensor<Element>(shape(tOgO));
+        clear(tOrO);
+        // Construct identity layout for sO
+        Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
+        // Repeat the partitioning with identity layouts
+        Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
+        Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
+        if (!Is_even_K) {
+            #pragma unroll
+            for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
+        }
+        // Clear_OOB_K must be false since we don't want to write zeros to gmem
+        flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+            gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+        );
+        #pragma unroll
+        for (int m = 0; m < size<1>(tOgO); ++m) {
+            const int row = get<0>(tOcO(0, m, 0));
+            if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; }
+        }
+        return;
+    }
+    // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); }

     // We iterate over the blocks in reverse order. This is because the last block is the only one
     // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
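Editor's note, not part of the diff: the n_block_min / n_block_max arithmetic above decides which key blocks a query block ever touches; an empty range is exactly the early-exit case. A host-side sketch with concrete numbers (assumed: equal query/key lengths, no varlen):

#include <algorithm>
#include <cstdio>

int main() {
    const int kBlockM = 64, kBlockN = 64;
    const int seqlen_q = 256, seqlen_k = 256;
    const int window_left = 128, window_right = 0;  // local attention window
    for (int m_block = 0; m_block < seqlen_q / kBlockM; ++m_block) {
        const int shift = seqlen_k - seqlen_q;  // aligns the two sequences
        int n_min = std::max(0, (m_block * kBlockM + shift - window_left) / kBlockN);
        int n_max = std::min((seqlen_k + kBlockN - 1) / kBlockN,
                             ((m_block + 1) * kBlockM + shift + window_right + kBlockN - 1) / kBlockN);
        std::printf("m_block %d -> key blocks [%d, %d)\n", m_block, n_min, n_max);
    }
}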
@@ -185,8 +188,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
     Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});

-    auto gmem_thr_copy_QKV = typename Kernel_traits::GmemTiledCopyQKV{}.get_thread_slice(tidx);
-    auto gmem_thr_copy_P = typename Kernel_traits::GmemTiledCopyP{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+    auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);

     Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
     Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
@@ -208,16 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     // Copy Atom retiling
     //

-    auto smem_thr_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
-    // auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
     // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
     Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
     // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}

-    auto smem_thr_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);

-    auto smem_thr_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
     Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

     // TODO: this might need to change if we change the mma instruction in SM70
@@ -268,8 +275,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor tQrQ = make_fragment_like(tQgQ);
     // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
-    flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
                                        binfo.actual_seqlen_q - m_block * kBlockM);
     if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }

     // // Copy rmem to smem
@@ -285,14 +292,14 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
         __syncthreads();
     }

     int n_block = n_block_max - 1;
     // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
-    flash::copy<Is_even_N, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+    flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
                                       binfo.actual_seqlen_k - n_block * kBlockN);
     cute::cp_async_fence();

     // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
     // __syncthreads();
@@ -302,7 +309,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         __syncthreads();
         Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
         CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view));            // M
-        copy(smem_thr_copy_Q, tSsQ, tSrQ_copy_view);
+        cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
     }

     // auto seeds = at::cuda::philox::unpack(params.philox_args);
@@ -313,13 +320,19 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);

+    float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
+
     // For performance reason, we separate out two kinds of iterations:
     // those that need masking on S, and those that don't.
     // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
     // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
     // We will have at least 1 "masking" iteration.
-    constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
+    // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+    // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+    constexpr int n_masking_steps = (!Is_causal && !Is_local)
+        ? 1
+        : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
     #pragma unroll
     for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
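Editor's note, not part of the diff: a quick check of the masking-step count introduced above, as a compile-time sketch:

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
// Causal, with seqlen_k a multiple of kBlockN: the causal boundary can cross
// ceil_div(kBlockM, kBlockN) key blocks.
static_assert(ceil_div(128, 64) == 2, "two masked iterations");
// Ragged seqlen_k (!Is_even_MN): the last, partially filled key block also
// needs masking, hence the +1 in the kernel.
static_assert(ceil_div(128, 64) + 1 == 3, "one extra masked iteration");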
@@ -330,28 +343,42 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Advance gV
         if (masking_step > 0) {
             tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         } else {
             // Clear the smem tiles to account for predicated off loads
-            flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
-                gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
+            flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+                gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
             );
         }
         cute::cp_async_fence();

         flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+            smem_thr_copy_Q, smem_thr_copy_K
         );
         // if (cute::thread0()) { print(acc_s); }

         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        // if (cute::thread0()) { print(scores); }
+        // if (cute::thread0()) { print_tensor(scores); }

         // We don't put the masking before the matmul S = Q K^T because we don't clear sK
         // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
         // can produce Inf / NaN.
-        if (!Is_causal) {
-            if (!Is_even_N) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
+
+        if (Has_alibi) {
+            flash::apply_alibi<Is_causal>(
+                scores,
+                n_block * kBlockN,
+                binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q,
+                kNWarps * 16,
+                alibi_slope
+            );
+        }
+
+        if (!Is_causal && !Is_local) {
+            if (!Is_even_MN) { flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
         } else {
             // Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});    // (BLK_M,BLK_N) -> (blk_m,blk_n)
             // Tensor taccScS = thr_mma.partition_C(caccS);                           // (MMA,MMA_M,MMA_N)
@@ -364,20 +391,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
             // Idk why it's get<1> and not get<0> of the stride.
             // if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
             // I can't get the stride from idx_row
-            flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
-                                     // m_block * kBlockM + get<0>(idx_row(0)),
-                                     m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
-                                     kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
-                                     // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
+            flash::apply_mask_local</*HasWSLeft=*/Is_local>(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                // m_block * kBlockM + get<0>(idx_row(0)),
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+                // m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
+                // m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
+            );
+            // if (cute::thread0()) { print_tensor(scores); }
         }

         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
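Editor's note, not part of the diff: the local-mask call above hides a simple visibility predicate; a reference version under the same row/column convention (a sketch, since apply_mask_local itself works on register-tile coordinates):

// A key at position col is visible from a query at position row iff it falls
// inside the window around the diagonal, after aligning the two sequences.
bool is_visible(int row, int col, int seqlen_q, int seqlen_k,
                int window_left, int window_right) {
    const int diag = row + seqlen_k - seqlen_q;  // key index on the diagonal
    return col >= diag - window_left && col <= diag + window_right && col < seqlen_k;
}
// Causal masking is the special case window_right == 0 with an unbounded
// window_left, which is why the kernel treats Is_causal and Is_local together.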
@@ -385,24 +416,24 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // TODO: when we have key_padding_mask we'll need to Check_inf
         masking_step == 0
-            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
-            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+            ? softmax_rescale_o</*Is_first=*/true,  /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+            : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         // Convert scores from fp32 to fp16/bf16
         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor tOrP_copy = make_fragment_like(tOrP);
-            copy(tOrP, tOrP_copy);
+            cute::copy(tOrP, tOrP_copy);
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                 tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
             tPgP.data() = tPgP.data() + (-kBlockN);
         }
         if (Is_dropout) {
@@ -411,37 +442,38 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }

         // if (cute::thread0()) { print(tOrP); }
-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
         // if (cute::thread0()) { print(scores); }

         // This check is at the end of the loop since we always have at least 1 iteration
-        if (n_masking_steps > 1 && n_block <= 0) {
+        if (n_masking_steps > 1 && n_block <= n_block_min) {
             --n_block;
             break;
         }
     }

     // These are the iterations where we don't need masking on S
-    for (; n_block >= 0; --n_block) {
+    for (; n_block >= n_block_min; --n_block) {
         Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
         clear(acc_s);
         flash::cp_async_wait<0>();
         __syncthreads();
         // Advance gV
         tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
-        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+        flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();

         flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
-            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_thr_copy_Q, smem_thr_copy_K
+            acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+            smem_thr_copy_Q, smem_thr_copy_K
         );

         flash::cp_async_wait<0>();
         __syncthreads();
-        if (n_block > 0) {
+        if (n_block > n_block_min) {
             // Advance gK
             tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
-            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+            flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
             // This cp_async_fence needs to be in the if block, otherwise the synchronization
             // isn't right and we get race conditions.
             cute::cp_async_fence();
@@ -449,22 +481,44 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
         Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
-        softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+        if (Has_alibi) {
+            flash::apply_alibi<Is_causal>(
+                scores,
+                n_block * kBlockN,
+                binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q,
+                kNWarps * 16,
+                alibi_slope
+            );
+        }
+
+        if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
+            flash::apply_mask_local(
+                scores, n_block * kBlockN, binfo.actual_seqlen_k,
+                m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+                binfo.actual_seqlen_q, kNWarps * 16,
+                params.window_size_left, params.window_size_right
+            );
+        }
+
+        softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);

         Tensor rP = flash::convert_type<Element>(scores);
         // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
         // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
-        uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
-        uint32_t block_col_idx = n_block * (kBlockN / 32);
+        int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
+        int block_col_idx = n_block * (kBlockN / 32);
         if (Return_softmax) {
             Tensor tOrP_copy = make_fragment_like(tOrP);
-            copy(tOrP, tOrP_copy);
+            cute::copy(tOrP, tOrP_copy);
             flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
                 tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
                 block_row_idx, block_col_idx, kNWarps
             );
-            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+            flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
             tPgP.data() = tPgP.data() + (-kBlockN);
         }
         if (Is_dropout) {
@@ -472,7 +526,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
                              block_row_idx, block_col_idx, kNWarps);
         }

-        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_thr_copy_V);
+        flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
     }

     // Epilogue
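Editor's note, not part of the diff: apply_alibi, used in the loops above, adds a linear position bias to the raw scores before softmax. One common formulation as a sketch (sign and offset conventions vary; note the kernel pre-divides the slope by params.scale_softmax so the bias is unaffected by the later scaling):

// ALiBi: penalize a score in proportion to the query/key distance.
float alibi_score(float score, float slope, int q_pos, int k_pos) {
    const int dist = q_pos >= k_pos ? q_pos - k_pos : k_pos - q_pos;
    return score - slope * static_cast<float>(dist);
}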
@@ -496,15 +550,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor rO = flash::convert_type<Element>(acc_o);
     Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{});    // (SMEM_M,SMEM_N)
     // Partition sO to match the accumulator partitioning
-    auto smem_thr_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
-    // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+    auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+    auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
     Tensor taccOrO = smem_thr_copy_O.retile_S(rO);        // ((Atom,AtomNum), MMA_M, MMA_N)
     Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

     // sO has the same size as sQ, so we don't need to sync here.
     if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }

-    copy(smem_thr_copy_O, taccOrO, taccOsO);
+    cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);

     const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
         + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
@@ -515,14 +569,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
                               Shape<Int<kBlockM>>{}, Stride<_1>{});

-    auto gmem_thr_copy_O = typename Kernel_traits::GmemTiledCopyO{}.get_thread_slice(tidx);
+    typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+    auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
     Tensor tOsO = gmem_thr_copy_O.partition_S(sO);        // ((Atom,AtomNum),ATOM_M,ATOM_N)
     Tensor tOgO = gmem_thr_copy_O.partition_D(gO);

     __syncthreads();

     Tensor tOrO = make_tensor<Element>(shape(tOgO));
-    copy(gmem_thr_copy_O, tOsO, tOrO);
+    cute::copy(gmem_tiled_copy_O, tOsO, tOrO);

     Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (BLK_M,BLK_K) -> (blk_m,blk_k)
     Tensor taccOcO = thr_mma.partition_C(caccO);                                  // (MMA,MMA_M,MMA_K)
@@ -548,14 +603,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
     }
     // Clear_OOB_K must be false since we don't want to write zeros to gmem
-    flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
-        gmem_thr_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
+    flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+        gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
     );
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
 inline __device__ void compute_attn(const Params &params) {
     const int m_block = blockIdx.x;
     // The block index for the batch.
@@ -571,7 +627,7 @@ inline __device__ void compute_attn(const Params &params) {
     // the attention matrix. This way, as long as we have the batch, head, and the location of
     // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.

-    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+    flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -4,15 +4,14 @@

 #pragma once

-// #include <ATen/cuda/CUDAContext.h>
 #include "static_switch.h"
 #include "flash.h"
 #include "flash_fwd_kernel.h"

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
-    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
+    static_assert(!(Is_causal && Is_local));  // If Is_local is true, Is_causal should be false
+    flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
 }

 template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
@@ -26,35 +25,39 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
     dim3 grid(num_m_block, params.b, params.h);

-    // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
-    // for cu_seqlens_q as well.
-    const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
+    const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
     const bool is_even_K = params.d == Kernel_traits::kHeadDim;
     const bool return_softmax = params.p_ptr != nullptr;
-    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
+    BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
-                // Will only return softmax if dropout, to reduce compilation time.
-                auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
-                // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
-                // if (smem_size >= 48 * 1024) {
-                //     C10_CUDA_CHECK(cudaFuncSetAttribute(
-                //         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                // }
-                int ctas_per_sm;
-                cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                    &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                // C10_CUDA_KERNEL_LAUNCH_CHECK();
+            BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
+                BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
+                    BOOL_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
+                        // Will only return softmax if dropout, to reduce compilation time.
+                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                        // If return_softmax, set IsEvenMNConst to false to reduce number of templates
+                        // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+                        // If Is_local, set Is_causal to false
+                        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
+                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
+                        // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                        // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
+                        // int ctas_per_sm;
+                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                    });
+                });
             });
         });
     });
 }

 template<typename T>
 void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 32;
+    constexpr static int Headdim = 32;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
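Editor's note, not part of the diff: the nested BOOL_SWITCH calls above come from static_switch.h. The macro lifts a runtime boolean into a constexpr template argument, so every feature combination gets its own fully specialized kernel, at the cost of compile time. Its shape is roughly:

// Dispatch a runtime bool to a compile-time constant (mirrors static_switch.h).
#define BOOL_SWITCH(COND, CONST_NAME, ...)                   \
    [&] {                                                    \
        if (COND) {                                          \
            constexpr static bool CONST_NAME = true;         \
            return __VA_ARGS__();                            \
        } else {                                             \
            constexpr static bool CONST_NAME = false;        \
            return __VA_ARGS__();                            \
        }                                                    \
    }()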
@@ -64,7 +67,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 64;
+    constexpr static int Headdim = 64;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -86,7 +89,7 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 96;
+    constexpr static int Headdim = 96;
     // auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = true;  // dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -112,7 +115,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 128;
+    constexpr static int Headdim = 128;
     // auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = true;  // dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -149,7 +152,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 160;
+    constexpr static int Headdim = 160;
     // auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm8x = true;  // dprops->major == 8 && dprops->minor > 0;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
@@ -179,7 +182,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 192;
+    constexpr static int Headdim = 192;
     BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
         BOOL_SWITCH(params.is_causal, Is_causal, [&] {
             if constexpr(!Is_dropout) {
@@ -198,7 +201,7 @@ void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 224;
+    constexpr static int Headdim = 224;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_block;
@@ -224,7 +227,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {

 template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr int Headdim = 256;
+    constexpr static int Headdim = 256;
     int device;
     cudaGetDevice(&device);
     int max_smem_per_sm, max_smem_per_block;
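Editor's note, not part of the diff: the hdim224/hdim256 paths need the device's shared-memory ceiling because their tiles exceed the default 48 KB static limit. The query pattern, as a sketch using the standard CUDA runtime calls:

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0, max_smem_per_block = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&max_smem_per_block,
                           cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    std::printf("opt-in smem per block: %d bytes\n", max_smem_per_block);
    // A kernel that wants more than 48 KB of dynamic smem must also opt in:
    // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
}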

View File

@@ -91,17 +91,20 @@ struct Flash_fwd_kernel_traits : public Base {
         SmemLayoutAtomQ{},
         Shape<Int<kBlockN>, Int<kHeadDim>>{}));

+    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
+    using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                      Stride<_1, Int<kBlockKSmem>>>;
     using SmemLayoutAtomVtransposed = decltype(
-        composition(Swizzle<kSwizzle, 3, 3>{},
-                    // This has to be kBlockN and not 8, otherwise we get wrong results for d=128
-                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
-                           Stride<_1, Int<kBlockKSmem>>>{}));
+        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
     using SmemLayoutVtransposed = decltype(tile_to_shape(
         SmemLayoutAtomVtransposed{},
         Shape<Int<kHeadDim>, Int<kBlockN>>{}));
     // Maybe the VtransposeNoSwizzle just needs to have the right shape
     // And the strides don't matter?
-    using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
+    using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
+        SmemLayoutAtomVtransposedNoSwizzle{},
+        Shape<Int<kHeadDim>, Int<kBlockN>>{}));
+    // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());

     using SmemLayoutAtomO = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@@ -110,7 +113,8 @@ struct Flash_fwd_kernel_traits : public Base {
     using SmemLayoutO = decltype(tile_to_shape(
         SmemLayoutAtomO{},
         Shape<Int<kBlockM>, Int<kHeadDim>>{}));
-    using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
+    using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
+    using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;

     static constexpr int kSmemQCount = size(SmemLayoutQ{});
     static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
@@ -138,11 +142,11 @@ struct Flash_fwd_kernel_traits : public Base {
         DefaultCopy
     >;
     using GmemTiledCopyQKV = decltype(
-        make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
+        make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per read
     using GmemTiledCopyO = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtom{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store
     static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
@@ -151,10 +155,30 @@ struct Flash_fwd_kernel_traits : public Base {
                                    Stride<Int<kGmemThreadsPerRowP>, _1>>;
     using GmemTiledCopyP = decltype(
-        make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomP{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per store

+    using GmemLayoutAtomOaccum = std::conditional_t<
+        kBlockKSmem == 32,
+        Layout<Shape <_16, _8>,  // Thread layout, 8 threads per row
+               Stride< _8, _1>>,
+        Layout<Shape <_8, _16>,  // Thread layout, 16 threads per row
+               Stride< _16, _1>>
+    >;
+    using GmemTiledCopyOaccum = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
+                        GmemLayoutAtomOaccum{},
+                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
+    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemTiledCopyRotcossin = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinCont = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
 };

 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
@@ -223,16 +247,19 @@ struct Flash_bwd_kernel_traits : public Base {
         SmemLayoutAtomKV{},
         make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));

+    using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
+                                                      Stride<_1, Int<kBlockKSmem>>>;
     using SmemLayoutAtomKtransposed = decltype(
-        composition(Swizzle<kSwizzle, 3, 3>{},
-                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
-                           Stride<_1, Int<kBlockKSmem>>>{}));
+        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
     using SmemLayoutKtransposed = decltype(tile_to_shape(
         SmemLayoutAtomKtransposed{},
         make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
     // Maybe the KtransposeNoSwizzle just needs to have the right shape
     // And the strides don't matter?
-    using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
+    using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
+        SmemLayoutAtomKtransposedNoSwizzle{},
+        make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
+    // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());

     // TODO: generalize to other values of kBlockN
     // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
@@ -250,24 +277,30 @@ struct Flash_bwd_kernel_traits : public Base {
     using SmemLayoutPdS = decltype(tile_to_shape(
         SmemLayoutAtomPdS{},
         make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
+    using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
+                                                        Stride<_1, Int<kPBlockN>>>;
     using SmemLayoutAtomPdStransposed = decltype(
-        composition(Swizzle<kSwizzlePdS, 3, 3>{},
-                    Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
-                           Stride<_1, Int<kPBlockN>>>{}));
+        composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
     using SmemLayoutPdStransposed = decltype(tile_to_shape(
         SmemLayoutAtomPdStransposed{},
         make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
-    using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
+    using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
+        SmemLayoutAtomPdStransposedNoSwizzle{},
+        make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
+    // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
     using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;

+    using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
+                                                        Stride<_1, Int<kBlockKSmem>>>;
     using SmemLayoutAtomQdOtransposed = decltype(
-        composition(Swizzle<kSwizzle, 3, 3>{},
-                    Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
-                           Stride<_1, Int<kBlockKSmem>>>{}));
+        composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
     using SmemLayoutQdOtransposed = decltype(tile_to_shape(
         SmemLayoutAtomQdOtransposed{},
         make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
-    using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
+    using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
+        SmemLayoutAtomQdOtransposedNoSwizzle{},
+        make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
+    // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());

     using SmemLayoutAtomdKV = decltype(
         composition(Swizzle<kSwizzle, 3, 3>{},
@ -292,13 +325,11 @@ struct Flash_bwd_kernel_traits : public Base {
static constexpr int kSmemdSCount = size(SmemLayoutPdS{}); static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
static constexpr int kSmemPCount = size(SmemLayoutPdS{}); static constexpr int kSmemPCount = size(SmemLayoutPdS{});
static constexpr int kSmemdQCount = size(SmemLayoutdQ{}); static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
static constexpr int kSmemdPsumCount = kBlockM;
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
static constexpr int kSmemSize = kSmemQdOSize static constexpr int kSmemSize = kSmemQdOSize
+ (!Is_V_in_regs + (!Is_V_in_regs
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
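The two dropped kSmemdPsum* constants shave the dP-sum staging buffer off the shared-memory budget, and the P and dQ buffers already alias through the std::max term. A host-side sketch of that kind of accounting, in Rust with illustrative tile sizes (not the kernel's real configuration):

// Illustrative smem budget in the style of kSmemSize above: Q/dO tiles, K/V
// tiles, dS, and one region shared between P and dQ (hence max, not a sum).
fn main() {
    let elem = 2usize; // bytes per fp16/bf16 element
    let (block_m, block_n, head_dim) = (64usize, 64usize, 64usize); // made-up tiles
    let qdo = 2 * block_m * head_dim * elem; // Q and dO
    let kv = 2 * block_n * head_dim * elem; // K and V
    let ds = block_m * block_n * elem;
    let p = block_m * block_n * elem;
    let dq = block_m * head_dim * elem;
    let smem_size = qdo + kv + ds + p.max(dq);
    if smem_size > 48 * 1024 {
        println!("{smem_size} bytes: needs the dynamic shared-memory opt-in");
    } else {
        println!("{smem_size} bytes: fits the default 48 KiB static limit");
    }
}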

@@ -0,0 +1,159 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include "cute/algorithm/copy.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/layout/layout.h"
#include <cutlass/numeric_types.h>
using namespace cute;
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
struct Flash_kernel_traits_sm90 {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
static constexpr bool Has_cp_async = true;
#else
using Element = cutlass::half_t;
static constexpr bool Has_cp_async = false;
#endif
using ElementAccum = float;
using index_t = uint32_t;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MMA_Atom_Arch = std::conditional_t<
std::is_same_v<elem_type, cutlass::half_t>,
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
>;
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
#else
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
#else
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
#endif
};
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
struct Flash_fwd_kernel_traits : public Base {
using Element = typename Base::Element;
using ElementAccum = typename Base::ElementAccum;
using index_t = typename Base::index_t;
static constexpr bool Has_cp_async = Base::Has_cp_async;
using SmemCopyAtom = typename Base::SmemCopyAtom;
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
// The number of threads.
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = TiledMMA<
typename Base::MMA_Atom_Arch,
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
using SmemLayoutAtomQ = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
Layout<Shape<_8, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutKV = decltype(tile_to_shape(
SmemLayoutAtomQ{},
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutAtomVtransposed = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
Stride<_1, Int<kBlockKSmem>>>{}));
using SmemLayoutVtransposed = decltype(tile_to_shape(
SmemLayoutAtomVtransposed{},
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
// Maybe the VtransposeNoSwizzle just needs to have the right shape
// And the strides don't matter?
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
using SmemLayoutAtomO = decltype(
composition(Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
static constexpr int kSmemQCount = size(SmemLayoutQ{});
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
// to the same banks.
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
// from the same address by the same threadblock. This is slightly faster.
using Gmem_copy_struct = std::conditional_t<
Has_cp_async,
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
DefaultCopy
>;
using GmemTiledCopyQKV = decltype(
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemTiledCopyO = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
Stride<Int<kGmemThreadsPerRowP>, _1>>;
using GmemTiledCopyP = decltype(
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
GmemLayoutAtomP{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
};
////////////////////////////////////////////////////////////////////////////////////////////////////
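The derived constants in these traits follow a simple rule: the shared-memory K-extent drops from 64 to 32 when the head dimension is not a multiple of 64, and the swizzle width tracks it. A Rust restatement of that selection logic, handy for checking configurations on the host:

// Mirrors kBlockKSmem / kBlockKGmem / kSwizzle from Flash_fwd_kernel_traits.
fn derived(head_dim: usize) -> (usize, usize, usize) {
    assert!(head_dim % 32 == 0);
    let k_block_k_smem = if head_dim % 64 == 0 { 64 } else { 32 };
    let k_block_k_gmem = if head_dim % 128 == 0 {
        128
    } else if head_dim % 64 == 0 {
        64
    } else {
        32
    };
    let k_swizzle = if k_block_k_smem == 32 { 2 } else { 3 };
    (k_block_k_smem, k_block_k_gmem, k_swizzle)
}

fn main() {
    assert_eq!(derived(128), (64, 128, 3));
    assert_eq!(derived(96), (32, 32, 2));
    println!("ok");
}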

@@ -8,8 +8,7 @@
 #include <cute/tensor.hpp>
-#include <cutlass/cutlass.h>
-#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
 #include "philox.cuh"
 #include "utils.h"
@@ -117,15 +116,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
 }
 template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
+                                  const int col_idx_offset_ = 0) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+        const int col_idx_base = col_idx_offset + nj * 8;
         #pragma unroll
         for (int j = 0; j < size<1, 0>(tensor); ++j) {
-            const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+            const int col_idx = col_idx_base + j;
             if (col_idx >= max_seqlen_k) {
                 // Without the "make_coord" we get wrong results
                 #pragma unroll
@@ -137,30 +139,30 @@ inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t
     }
 }
-template <typename Engine, typename Layout>
-inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
-                                         const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
-                                         const uint32_t warp_row_stride) {
+template <bool HasWSLeft=true, typename Engine, typename Layout>
+inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                        const int max_seqlen_k, const int row_idx_offset,
+                                        const int max_seqlen_q, const int warp_row_stride,
+                                        const int window_size_left, const int window_size_right) {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout::rank == 2, "Only support 2D Tensor");
-    const uint32_t lane_id = threadIdx.x % 32;
-    // const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
-    const uint32_t row_idx_offset = row_idx_offset_;
-    const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
+    const int lane_id = threadIdx.x % 32;
+    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
     #pragma unroll
     for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
-        const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
+        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
         #pragma unroll
         for (int i = 0; i < size<0, 0>(tensor); ++i) {
-            const uint32_t row_idx = row_idx_base + i * 8;
-            const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
+            const int row_idx = row_idx_base + i * 8;
+            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
+            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
             #pragma unroll
             for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-                const uint32_t col_idx_base = col_idx_offset + nj * 8;
+                const int col_idx_base = col_idx_offset + nj * 8;
                 #pragma unroll
                 for (int j = 0; j < size<1, 0>(tensor); ++j) {
-                    const uint32_t col_idx = col_idx_base + j;
-                    if (col_idx >= col_idx_limit) {
+                    const int col_idx = col_idx_base + j;
+                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                         tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                     }
                 }
@@ -174,10 +176,19 @@ inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const u
     }
 }
+template <typename Engine, typename Layout>
+inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
+                                         const int max_seqlen_k, const int row_idx_offset,
+                                         const int max_seqlen_q, const int warp_row_stride) {
+    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
+    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
+                                          max_seqlen_q, warp_row_stride, -1, 0);
+}
 template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
 inline __device__ void apply_mask_causal_w_idx(
     Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
-    const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
+    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
 {
     // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
     static_assert(Layout0::rank == 2, "Only support 2D Tensor");
@@ -186,7 +197,7 @@ inline __device__ void apply_mask_causal_w_idx(
     CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
     #pragma unroll
     for (int mi = 0; mi < size<0>(tensor); ++mi) {
-        const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
+        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
         #pragma unroll
         for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
             if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
@@ -204,8 +215,8 @@ inline __device__ void apply_mask_causal_w_idx(
 template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
 inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                      unsigned long long seed, unsigned long long offset,
-                                     uint32_t block_row_start, uint32_t block_col_start,
-                                     uint32_t block_row_stride) {
+                                     int block_row_start, int block_col_start,
+                                     int block_row_stride) {
     // tensor has shape (8, MMA_M, MMA_N / 2)
     using T = typename Engine::value_type;
     auto encode_dropout = [](bool keep, T val) {
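apply_mask_local generalizes the old causal mask to a (left, right) window around each query row, with the two sequences aligned at their ends; causal masking becomes the (unbounded, 0) special case dispatched just above. The same predicate can be restated on the host, e.g. in Rust (a sketch: window_size_right is assumed already normalized to >= 0, as the candle wrappers below do, and a negative left means unbounded):

// True if (row_idx, col_idx) of the score tile must be set to -inf,
// mirroring col_idx_limit_left / col_idx_limit_right in apply_mask_local.
fn masked(row_idx: i64, col_idx: i64, max_seqlen_q: i64, max_seqlen_k: i64,
          window_left: i64, window_right: i64) -> bool {
    let shift = max_seqlen_k - max_seqlen_q; // right-aligns the two sequences
    let left = if window_left < 0 { 0 } else { (row_idx + shift - window_left).max(0) };
    let right = (row_idx + 1 + shift + window_right).min(max_seqlen_k);
    col_idx < left || col_idx >= right
}

fn main() {
    // Causal on a square 4x4 score matrix: exactly the strict upper triangle is masked.
    for row in 0..4 {
        for col in 0..4 {
            assert_eq!(masked(row, col, 4, 4, -1, 0), col > row);
        }
    }
    println!("ok");
}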

@@ -87,46 +87,6 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename T>
-inline __device__ float2 half2_unpack(uint32_t a);
-template <>
-inline __device__ float2 half2_unpack<__half>(uint32_t a) {
-    return __half22float2(reinterpret_cast<__half2 (&)>(a));
-}
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
-template <>
-inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
-    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
-}
-#endif
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// Convert two half2's or bf162's into float, then take their dot product.
-template <typename T>
-inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
-    float2 af = flash::half2_unpack<T>(a);
-    float2 bf = flash::half2_unpack<T>(b);
-    return af.x * bf.x + af.y * bf.y;
-}
-////////////////////////////////////////////////////////////////////////////////////////////////////
-// Converted two vectors of 8 half's or bf16's into float, then take their dot product.
-template<typename T>
-inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
-    float sum;
-    sum  = flash::hfma2_to_float<T>(a.x, b.x);
-    sum += flash::hfma2_to_float<T>(a.y, b.y);
-    sum += flash::hfma2_to_float<T>(a.z, b.z);
-    sum += flash::hfma2_to_float<T>(a.w, b.w);
-    return sum;
-}
-////////////////////////////////////////////////////////////////////////////////////////////////////
 template<typename T>
 struct MaxOp {
     __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
@@ -173,10 +133,12 @@ static __device__ inline T run(T x, Operator &op) {
 template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
          typename Tensor2, typename Tensor3, typename Tensor4,
-         typename TiledMma, typename TiledCopy0, typename TiledCopy1>
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
 inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
                             Tensor4 const& tCsB, TiledMma tiled_mma,
-                            TiledCopy0 smem_thr_copy_A, TiledCopy1 smem_thr_copy_B) {
+                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
+                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
@@ -184,13 +146,13 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-    if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
-    if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
+    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
+    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            if (!A_in_regs) { copy(smem_thr_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
-            if (!B_in_regs) { copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
+            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
+            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
         }
         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
     }
@@ -199,19 +161,20 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
-         typename TiledMma, typename TiledCopy>
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
 inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
-                                      TiledMma tiled_mma, TiledCopy smem_thr_copy_B) {
+                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
+                                      ThrCopy smem_thr_copy_B) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
     Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-    copy(smem_thr_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            copy(smem_thr_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
         }
         cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
     }
@@ -225,7 +188,10 @@ inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
     static_assert(decltype(size<0>(acc_layout))::value == 4);
     static_assert(decltype(rank(acc_layout))::value == 3);
     auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
-    return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+    // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
+    // "int_tuple.hpp(74): error: conversion to inaccessible base class"
+    // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
+    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -241,9 +207,13 @@ inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
     static_assert(mma_shape_K == 8 || mma_shape_K == 16);
     constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
     auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
-    return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
-                       get<0, 1>(l),
-                       get<1, 1, 1>(l));
+    // TD [2023-08-13]: Same error as above on Cutlass 3.2
+    // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
+    //                    get<0, 1>(l),
+    //                    get<1, 1, 1>(l));
+    return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
+                       get<1>(get<0>(l)),
+                       get<1>(get<1>(get<1>(l))));
 };
 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -319,9 +289,9 @@ void cp_async_wait() {
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
           typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
           typename Engine2, typename Layout2, typename Engine3, typename Layout3>
-inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &S,
+inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                             Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
-                            Tensor<Engine3, Layout3> const &predicate_K, int max_MN=0) {
+                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
     CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
     CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
     CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
@@ -335,13 +305,13 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
             #pragma unroll
             for (int k = 0; k < size<2>(S); ++k) {
                 if (Is_even_K || predicate_K(k)) {
-                    copy(thr_copy, S(_, m, k), D(_, m, k));
+                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                 } else if (Clear_OOB_K) {
-                    clear(D(_, m, k));
+                    cute::clear(D(_, m, k));
                 }
             }
         } else if (Clear_OOB_MN) {
-            clear(D(_, m, _));
+            cute::clear(D(_, m, _));
         }
     }
     // TD [2023-04-13]: Strange that the code below can cause race condition.
@@ -350,7 +320,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
     // #pragma unroll
     // for (int m = 0; m < size<1>(S); ++m) {
     //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, _), D(_, m, _));
+    //         copy(tiled_copy, S(_, m, _), D(_, m, _));
     //     } else if (Clear_OOB_MN) {
     //         clear(D(_, m, _));
     //     }
@@ -362,7 +332,7 @@ inline __device__ void copy(TiledCopy thr_copy, Tensor<Engine0, Layout0> const &
     // #pragma unroll
     // for (int m = 0; m < size<1>(S); ++m) {
     //     if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
-    //         copy(thr_copy, S(_, m, k), D(_, m, k));
+    //         copy(tiled_copy, S(_, m, k), D(_, m, k));
     //     } else if (Clear_OOB_MN) {
     //         clear(D(_, m, k));
     //     }
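The reworked gemm helper keeps the classic software pipeline: the smem-to-register copy for slice i+1 is issued before the MMA on slice i, so memory traffic and tensor-core work overlap. Stripped of the CUTLASS machinery, the control flow looks like this (a schematic Rust sketch over plain slices, not the device code):

// Schematic pipeline: "copy" chunk i+1, then "compute" on chunk i.
fn pipelined_dot(a: &[f32], b: &[f32], chunk: usize) -> f32 {
    assert!(a.len() == b.len() && a.len() % chunk == 0);
    let n = a.len() / chunk;
    let mut acc = 0.0;
    // Prologue: stage the first chunk (the first cute::copy above).
    let mut cur: (Vec<f32>, Vec<f32>) = (a[..chunk].to_vec(), b[..chunk].to_vec());
    for i in 0..n {
        // Issue the next copy before consuming the current chunk.
        let next = (i + 1 < n).then(|| {
            let s = (i + 1) * chunk;
            (a[s..s + chunk].to_vec(), b[s..s + chunk].to_vec())
        });
        // The "MMA" on the staged chunk.
        acc += cur.0.iter().zip(&cur.1).map(|(x, y)| x * y).sum::<f32>();
        if let Some(nxt) = next {
            cur = nxt;
        }
    }
    acc
}

fn main() {
    let a: Vec<f32> = (0..16).map(|x| x as f32).collect();
    assert_eq!(pipelined_dot(&a, &vec![1.0; 16], 4), 120.0);
    println!("ok");
}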

@@ -7,6 +7,8 @@ extern "C" {
         v_ptr: *const c_void,
         o_ptr: *const c_void,
         softmax_lse_ptr: *const c_void,
+        alibi_slopes_ptr: *const c_void,
         cu_seqlens_q_ptr: *const i32,
         cu_seqlens_k_ptr: *const i32,
@@ -14,6 +16,7 @@ extern "C" {
         k_batch_stride: u32,
         v_batch_stride: u32,
         o_batch_stride: u32,
+        alibi_slopes_batch_stride: u32,
         q_row_stride: u32,
         k_row_stride: u32,
@@ -37,8 +40,11 @@ extern "C" {
         seqlen_q_rounded: u32,
         seqlen_k_rounded: u32,
-        is_causal: c_int,
         is_bf16: c_int,
+        is_causal: c_int,
+        window_size_left: c_int,
+        window_size_right: c_int,
     );
 }
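The three new c_int parameters use -1 as the "unlimited" sentinel; on the Rust side (next file) that is produced from Option<usize>, with windows wider than the key length also collapsing to -1. A condensed sketch of that mapping:

// Option<usize> window -> C sentinel, mirroring the wrappers in lib.rs below.
fn to_sentinel(window: Option<usize>, seqlen_k: usize) -> i32 {
    window.filter(|v| *v <= seqlen_k).map(|v| v as i32).unwrap_or(-1)
}

fn main() {
    assert_eq!(to_sentinel(None, 4096), -1);
    assert_eq!(to_sentinel(Some(128), 4096), 128);
    assert_eq!(to_sentinel(Some(8192), 4096), -1); // wider than k: unlimited
    println!("ok");
}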

@@ -3,12 +3,14 @@ mod ffi;
 use candle::backend::BackendStorage;
 use candle::cuda_backend::cudarc::driver::DevicePtr;
 use candle::cuda_backend::WrapErr;
-use candle::{CpuStorage, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
 use half::{bf16, f16};
 pub struct FlashAttn {
     pub softmax_scale: f32,
-    pub causal: bool,
+    pub alibi_slopes: Option<Tensor>,
+    pub window_size_left: Option<usize>,
+    pub window_size_right: Option<usize>,
 }
 fn round_multiple(x: usize, m: usize) -> usize {
@@ -85,6 +87,51 @@ impl FlashAttn {
         candle::bail!("number of k/v heads {num_heads_k} must divide number of heads in query {num_heads}")
     }
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
     let head_size = round_multiple(head_size_og, 8);
     let head_size_rounded = round_multiple(head_size, 32);
     let seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -94,9 +141,22 @@ impl FlashAttn {
     let dst = unsafe { dev.alloc::<T>(elem_count) }.w()?;
     let softmax_lse = dev.alloc_zeros::<f32>(b_sz * num_heads * seqlen_q).w()?;
-    let causal = if self.causal { 1 } else { 0 };
     let is_bf16 = if is_bf16 { 1 } else { 0 };
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = seqlen_k as i32;
}
     unsafe {
         let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
         let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -109,12 +169,14 @@ impl FlashAttn {
             v_ptr,
             dst_ptr,
             softmax_lse_ptr,
+            /* alibi_slopes_ptr */ alibi_slopes_ptr,
             /* cu_seqlens_q_ptr */ std::ptr::null(),
             /* cu_seqlens_k_ptr */ std::ptr::null(),
             /* q_batch_stride */ q_stride[0] as u32,
             /* k_batch_stride */ k_stride[0] as u32,
             /* v_batch_stride */ v_stride[0] as u32,
             /* o_batch_stride */ o_stride[0] as u32,
+            /* alibi_slopes_batch_stride */ 0,
             /* q_row_stride */ q_stride[q_rank - 3] as u32,
             /* k_row_stride */ k_stride[k_rank - 3] as u32,
             /* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -133,8 +195,10 @@ impl FlashAttn {
             /* seqlen_k */ seqlen_k as u32,
             /* seqlen_q_rounded */ seqlen_q_rounded as u32,
             /* seqlen_k_rounded */ seqlen_k_rounded as u32,
-            /* is_causal */ causal,
             /* is_bf16 */ is_bf16,
+            /* is_causal */ is_causal,
+            /* window_size_left */ window_size_left,
+            /* window_size_right */ window_size_right,
         )
     }
@@ -197,20 +261,137 @@ pub fn flash_attn(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
     let op = FlashAttn {
         softmax_scale,
-        causal,
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
};
q.apply_op3(k, v, op)
}
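// Usage sketch (illustrative only, not part of the change): assumes a CUDA
// device and f16 tensors of shape (batch, seq_len, num_heads, head_size).
//
//     let device = Device::new_cuda(0)?;
//     let q = Tensor::randn(0f32, 1., (1, 32, 8, 64), &device)?.to_dtype(DType::F16)?;
//     let (k, v) = (q.clone(), q.clone());
//     let out = flash_attn(&q, &k, &v, 1.0 / (64f32).sqrt(), /*causal=*/ true)?;
//     assert_eq!(out.dims(), &[1, 32, 8, 64]);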
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
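// Usage sketch (illustrative): a causal sliding window where each query sees at
// most the previous 1024 keys plus itself; `Some(0)` on the right keeps causality.
//
//     let out = flash_attn_windowed(&q, &k, &v, softmax_scale, Some(1024), Some(0))?;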
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
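// Usage sketch (illustrative): the standard ALiBi slopes for a power-of-two
// number of heads are 2^(-8 * (i + 1) / num_heads), stored as f32 on the device.
//
//     let n = 8usize;
//     let slopes: Vec<f32> = (0..n).map(|i| 2f32.powf(-8.0 * (i + 1) as f32 / n as f32)).collect();
//     let alibi_slopes = Tensor::from_slice(&slopes, (n,), &device)?;
//     let out = flash_attn_alibi(&q, &k, &v, &alibi_slopes, softmax_scale, true)?;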
/// Flash-attention v2 layer.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(batch, seq_len_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(batch, seq_len_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
///
/// The resulting tensor has dimensions `(batch, seq_len_q, num_heads_q, head_size)`.
pub fn flash_attn_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttn {
softmax_scale,
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
     };
     q.apply_op3(k, v, op)
 }
 struct FlashAttnVarLen {
-    softmax_scale: f32,
-    causal: bool,
-    max_seqlen_q: usize,
-    max_seqlen_k: usize,
-    seqlens_q: Tensor,
-    seqlens_k: Tensor,
+    pub softmax_scale: f32,
+    pub max_seqlen_q: usize,
+    pub max_seqlen_k: usize,
+    pub seqlens_q: Tensor,
+    pub seqlens_k: Tensor,
+    pub alibi_slopes: Option<Tensor>,
+    pub window_size_left: Option<usize>,
+    pub window_size_right: Option<usize>,
 }
 impl FlashAttnVarLen {
@@ -311,7 +492,54 @@ impl FlashAttnVarLen {
     if nseqlens_k != nseqlens_q {
         candle::bail!("seqlens_q and seqlens_k should have the same number of elements {nseqlens_q} <> {nseqlens_k}")
     }
     let batch_size = nseqlens_q - 1;
let alibi_slopes_ptr = if let Some(alibi_slopes) = &self.alibi_slopes {
if alibi_slopes.dtype() != DType::F32 {
candle::bail!(
"DType mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes.dtype(),
DType::F32
);
}
let (alibi_slopes, alibi_slopes_layout) = alibi_slopes.storage_and_layout();
if num_heads != alibi_slopes_layout.shape().dims1()? {
candle::bail!(
"shape mismatch alibi_slopes {:?}, expected {:?}",
alibi_slopes_layout.shape(),
(num_heads)
);
}
let alibi_slopes = match &*alibi_slopes {
candle::Storage::Cuda(c) => c.as_cuda_slice::<f32>()?,
_ => candle::bail!("alibi_slopes must be a cuda tensor"),
};
let alibi_slopes = alibi_slopes.slice(alibi_slopes_layout.start_offset()..);
*alibi_slopes.device_ptr() as *const core::ffi::c_void
} else {
std::ptr::null()
};
// if window_size_left > self.max_seqlen_k or None => -1
let mut window_size_left = self
.window_size_left
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
// if window_size_right > self.max_seqlen_k or None => -1
let mut window_size_right = self
.window_size_right
.filter(|v| v <= &self.max_seqlen_k)
.map(|v| v as i32)
.unwrap_or(-1);
     let head_size = round_multiple(head_size_og, 8);
     let head_size_rounded = round_multiple(head_size, 32);
     let seqlen_q_rounded = round_multiple(self.max_seqlen_q, 128);
@@ -323,9 +551,22 @@ impl FlashAttnVarLen {
         .alloc_zeros::<f32>(batch_size * num_heads * self.max_seqlen_q)
         .w()?;
-    let causal = if self.causal { 1 } else { 0 };
     let is_bf16 = if is_bf16 { 1 } else { 0 };
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
let is_causal = if window_size_left < 0 && window_size_right == 0 {
1
} else {
0
};
if window_size_left < 0 && window_size_right >= 0 {
window_size_left = self.max_seqlen_k as i32;
}
if window_size_left >= 0 && window_size_right < 0 {
window_size_right = self.max_seqlen_k as i32;
}
     unsafe {
         let q_ptr = *q.device_ptr() as *const core::ffi::c_void;
         let k_ptr = *k.device_ptr() as *const core::ffi::c_void;
@@ -340,12 +581,14 @@ impl FlashAttnVarLen {
             v_ptr,
             dst_ptr,
             softmax_lse_ptr,
+            /* alibi_slopes_ptr */ alibi_slopes_ptr,
             /* cu_seqlens_q_ptr */ seqlens_q_ptr,
             /* cu_seqlens_k_ptr */ seqlens_k_ptr,
             /* q_batch_stride */ 0,
             /* k_batch_stride */ 0,
             /* v_batch_stride */ 0,
             /* o_batch_stride */ 0,
+            /* alibi_slopes_batch_stride */ 0,
             /* q_row_stride */ q_stride[q_rank - 3] as u32,
             /* k_row_stride */ k_stride[k_rank - 3] as u32,
             /* v_row_stride */ v_stride[v_rank - 3] as u32,
@@ -364,8 +607,10 @@ impl FlashAttnVarLen {
             /* seqlen_k */ self.max_seqlen_k as u32,
             /* seqlen_q_rounded */ seqlen_q_rounded as u32,
             /* seqlen_k_rounded */ seqlen_k_rounded as u32,
-            /* is_causal */ causal,
             /* is_bf16 */ is_bf16,
+            /* is_causal */ is_causal,
+            /* window_size_left */ window_size_left,
+            /* window_size_right */ window_size_right,
         )
     }
@@ -440,13 +685,176 @@ pub fn flash_attn_varlen(
     softmax_scale: f32,
     causal: bool,
 ) -> Result<Tensor> {
+    let window_size_left = None;
+    let window_size_right = if causal { Some(0) } else { None };
     let op = FlashAttnVarLen {
         softmax_scale,
-        causal,
         max_seqlen_q,
         max_seqlen_k,
         seqlens_q: seqlens_q.clone(),
         seqlens_k: seqlens_k.clone(),
+        alibi_slopes: None,
+        window_size_left,
+        window_size_right,
};
q.apply_op3(k, v, op)
}
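// Usage sketch (illustrative): two packed sequences of lengths 10 and 15 are
// indexed by the cumulative offsets [0, 10, 25] (so total_q = 25, max_seqlen = 15).
//
//     let seqlens_q = Tensor::new(&[0u32, 10, 25], &device)?;
//     let seqlens_k = seqlens_q.clone();
//     let out = flash_attn_varlen(&q, &k, &v, &seqlens_q, &seqlens_k, 15, 15, scale, true)?;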
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: None,
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
pub fn flash_attn_varlen_alibi(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
let window_size_left = None;
let window_size_right = if causal { Some(0) } else { None };
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
};
q.apply_op3(k, v, op)
}
#[allow(clippy::too_many_arguments)]
/// Flash-attention v2 layer with variable-length batching.
///
/// This implements scaled dot-product attention, `softmax(Q @ K^T . softmax_scale) @ V`.
/// Multi-query and grouped-query attention are supported by using tensors k and v with fewer heads
/// than q, the number of heads in k and v has to be divisible by the number of heads in q.
///
/// # Arguments
///
/// * `q` - Query tensor with shape `(total_q, num_heads_q, head_size)`.
/// * `k` - Key tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `v` - Value tensor with shape `(total_kv, num_heads_kv, head_size)`.
/// * `alibi_slopes` - Alibi slopes tensor with shape `(num_heads_q)`.
/// * `seqlens_q` - The cumulative lengths of the sequences in the batch, used to index in q.
/// * `seqlens_k` - The cumulative lengths of the sequences in the batch, used to index in k and v.
/// * `max_seqlen_q` - The maximum query sequence length for q in the batch.
/// * `max_seqlen_k` - The maximum query sequence length for k and v in the batch.
/// * `window_size_left` - Limit left attention to value tokens.
/// * `window_size_right` - Limit right attention to value tokens.
///
/// `seqlens_q` and `seqlens_k` contain `batch_size + 1` elements, typically `0`, `seqlen_1`,
/// `seqlen_1 + seqlen_2`, etc.
///
/// The resulting tensor has dimensions `(total_q, num_heads_q, head_size)`.
///
/// # Causal mask
///
/// `window_size_left=None` with `window_size_right=Some(0)` applies a causal mask to the result
/// of `Q @ K^T`
pub fn flash_attn_varlen_alibi_windowed(
q: &Tensor,
k: &Tensor,
v: &Tensor,
alibi_slopes: &Tensor,
seqlens_q: &Tensor,
seqlens_k: &Tensor,
max_seqlen_q: usize,
max_seqlen_k: usize,
softmax_scale: f32,
window_size_left: Option<usize>,
window_size_right: Option<usize>,
) -> Result<Tensor> {
let op = FlashAttnVarLen {
softmax_scale,
max_seqlen_q,
max_seqlen_k,
seqlens_q: seqlens_q.clone(),
seqlens_k: seqlens_k.clone(),
alibi_slopes: Some(alibi_slopes.clone()),
window_size_left,
window_size_right,
    };
    q.apply_op3(k, v, op)
}
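Across both forward paths the wrappers reduce every (causal, window) combination to the same three kernel parameters: is_causal is the `window_size_left < 0 && window_size_right == 0` special case, and any remaining negative side is widened to seqlen_k. Restated compactly (a sketch mirroring the cuda_fwd logic above):

// (is_causal, window_size_left, window_size_right) as handed to the C kernel.
fn normalize(mut left: i32, mut right: i32, seqlen_k: i32) -> (i32, i32, i32) {
    let is_causal = if left < 0 && right == 0 { 1 } else { 0 };
    if left < 0 && right >= 0 {
        left = seqlen_k;
    }
    if left >= 0 && right < 0 {
        right = seqlen_k;
    }
    (is_causal, left, right)
}

fn main() {
    // causal = true in the public API arrives here as (left = -1, right = 0):
    assert_eq!(normalize(-1, 0, 4096), (1, 4096, 0));
    // a plain 128-key sliding window to the left, unbounded to the right:
    assert_eq!(normalize(128, -1, 4096), (0, 128, 4096));
    println!("ok");
}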

@@ -12,6 +12,4 @@ license = "MIT OR Apache-2.0"
 [dependencies]
 [build-dependencies]
-anyhow = { version = "1", features = ["backtrace"] }
-glob = "0.3.1"
-rayon = "1.7.0"
+bindgen_cuda = "0.1.1"

@@ -1,243 +1,8 @@
-use std::io::Write;
 fn main() {
     println!("cargo:rerun-if-changed=build.rs");
-    cuda::set_include_dir();
-    let (write, kernel_paths) = cuda::build_ptx();
-    if write {
-        let mut file = std::fs::File::create("src/lib.rs").unwrap();
+    let builder = bindgen_cuda::Builder::default();
+    println!("cargo:info={builder:?}");
+    let bindings = builder.build_ptx().unwrap();
+    bindings.write("src/lib.rs").unwrap();
for kernel_path in kernel_paths {
let name = kernel_path.file_stem().unwrap().to_str().unwrap();
file.write_all(
format!(
r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}.ptx"));"#,
name.to_uppercase().replace('.', "_"),
name
)
.as_bytes(),
)
.unwrap();
file.write_all(&[b'\n']).unwrap();
}
}
}
mod cuda {
use anyhow::{Context, Result};
pub fn set_include_dir() {
use std::path::PathBuf;
// NOTE: copied from cudarc build.rs.
// We can't actually set a env!() value from another crate,
// so we have to do that here.
// use PathBuf;
let env_vars = [
"CUDA_PATH",
"CUDA_ROOT",
"CUDA_TOOLKIT_ROOT_DIR",
"CUDNN_LIB",
];
#[allow(unused)]
let env_vars = env_vars
.into_iter()
.map(std::env::var)
.filter_map(Result::ok)
.map(Into::<PathBuf>::into);
let roots = [
"/usr",
"/usr/local/cuda",
"/opt/cuda",
"/usr/lib/cuda",
"C:/Program Files/NVIDIA GPU Computing Toolkit",
"C:/CUDA",
];
#[allow(unused)]
let roots = roots.into_iter().map(Into::<PathBuf>::into);
#[cfg(feature = "ci-check")]
let root: PathBuf = "ci".into();
#[cfg(not(feature = "ci-check"))]
let root = env_vars
.chain(roots)
.find(|path| path.join("include").join("cuda.h").is_file())
.unwrap();
println!(
"cargo:rustc-env=CUDA_INCLUDE_DIR={}",
root.join("include").display()
);
}
pub fn build_ptx() -> (bool, Vec<std::path::PathBuf>) {
use rayon::prelude::*;
use std::path::PathBuf;
let out_dir = std::env::var("OUT_DIR").unwrap();
let kernel_paths: Vec<PathBuf> = glob::glob("src/*.cu")
.unwrap()
.map(|p| p.unwrap())
.collect();
let mut include_directories: Vec<PathBuf> = glob::glob("src/**/*.cuh")
.unwrap()
.map(|p| p.unwrap())
.collect();
println!("cargo:rerun-if-changed=src/");
// for path in &kernel_paths {
// println!("cargo:rerun-if-changed={}", path.display());
// }
for path in &mut include_directories {
// println!("cargo:rerun-if-changed={}", path.display());
let destination =
std::format!("{out_dir}/{}", path.file_name().unwrap().to_str().unwrap());
std::fs::copy(path.clone(), destination).unwrap();
// remove the filename from the path so it's just the directory
path.pop();
}
include_directories.sort();
include_directories.dedup();
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
#[allow(unused)]
let include_options: Vec<String> = include_directories
.into_iter()
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}
})
.collect::<Vec<_>>();
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
.unwrap()
.map(|p| p.unwrap())
.collect();
// We should rewrite `src/lib.rs` only if there are some newly compiled kernels, or removed
// some old ones
let write = !children.is_empty() || kernel_paths.len() < ptx_paths.len();
for (kernel_path, child) in children {
let output = child.expect("nvcc failed to run. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
assert!(
output.status.success(),
"nvcc error while compiling {kernel_path:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}",
String::from_utf8_lossy(&output.stdout),
String::from_utf8_lossy(&output.stderr)
);
}
(write, kernel_paths)
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
}

@@ -1 +1,9 @@
pub const AFFINE: &str = include_str!(concat!(env!("OUT_DIR"), "/affine.ptx"));
pub const BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/binary.ptx"));
pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
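These constants embed the PTX that bindgen_cuda now generates at build time; a consumer loads a blob into a module at runtime. The sketch below assumes cudarc 0.9 (the version pinned in the workspace manifest above), and the module/kernel names are illustrative:

// Sketch: loading one embedded PTX blob with cudarc (names are illustrative).
use cudarc::driver::CudaDevice;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let dev = CudaDevice::new(0)?;
    dev.load_ptx(candle_kernels::AFFINE.into(), "affine", &["affine_f32"])?;
    let _func = dev.get_func("affine", "affine_f32"); // Option<CudaFunction>
    Ok(())
}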

@@ -15,7 +15,6 @@ const CAST: &str = include_str!("cast.metal");
 const REDUCE: &str = include_str!("reduce.metal");
 const CONV: &str = include_str!("conv.metal");
 const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
-const QUANTIZED: &str = include_str!("quantized.metal");
 /// Most kernels apply similarly across the tensors
 /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@@ -63,8 +62,6 @@ macro_rules! primitive {
     };
 }
 primitive!(usize);
-primitive!(i64);
-primitive!(i32);
 primitive!(u32);
 primitive!(f32);
@@ -120,7 +117,6 @@ pub enum Source {
     Reduce,
     Mfa,
     Conv,
-    Quantized,
 }
 macro_rules! ops{
@@ -219,17 +215,17 @@ type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipeline
 pub struct Kernels {
     libraries: RwLock<Libraries>,
     pipelines: RwLock<Pipelines>,
-    // fence: metal::Fence,
+    fence: metal::Fence,
 }
 impl Kernels {
-    pub fn new() -> Self {
+    pub fn new(fence: metal::Fence) -> Self {
         let libraries = RwLock::new(Libraries::new());
         let pipelines = RwLock::new(Pipelines::new());
         Self {
             libraries,
             pipelines,
-            // fence,
+            fence,
         }
     }
@@ -243,7 +239,6 @@ impl Kernels {
             Source::Cast => CAST,
             Source::Reduce => REDUCE,
             Source::Conv => CONV,
-            Source::Quantized => QUANTIZED,
             Source::Mfa => panic!("Invalid lib"),
         }
     }
@ -350,7 +345,7 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output)); set_params!(encoder, (length, input, output));
@ -359,7 +354,7 @@ pub fn call_unary_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -381,7 +376,7 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -403,7 +398,7 @@ pub fn call_unary_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -422,7 +417,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?; let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output)); set_params!(encoder, (length, left, right, output));
@ -433,7 +428,7 @@ pub fn call_binary_contiguous(
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -458,7 +453,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len(); let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product(); let width: usize = shape.iter().product();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -483,7 +478,7 @@ pub fn call_binary_strided(
encoder.use_resource(right_input, metal::MTLResourceUsage::Read); encoder.use_resource(right_input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -502,7 +497,7 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output)); set_params!(encoder, (length, (input, input_offset), output));
@ -511,7 +506,7 @@ pub fn call_cast_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -531,7 +526,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product(); let length: usize = shape.iter().product();
@ -553,7 +548,7 @@ pub fn call_cast_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -573,7 +568,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -602,7 +597,7 @@ pub fn call_reduce_contiguous(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -624,7 +619,7 @@ pub fn call_reduce_strided(
let elements_to_sum = length / out_length; let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -660,7 +655,7 @@ pub fn call_reduce_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -679,7 +674,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?; let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -710,7 +705,7 @@ pub fn call_last_softmax(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -730,7 +725,7 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output)); set_params!(encoder, (size, mul, add, input, output));
@ -739,7 +734,7 @@ pub fn call_affine(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -762,7 +757,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -783,7 +778,7 @@ pub fn call_affine_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -802,7 +797,7 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
@ -811,7 +806,7 @@ pub fn call_powf(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -833,7 +828,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -853,7 +848,7 @@ pub fn call_powf_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -872,7 +867,7 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?; let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output)); set_params!(encoder, (size, mul, input, output));
@ -881,7 +876,7 @@ pub fn call_elu(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -903,7 +898,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -923,7 +918,7 @@ pub fn call_elu_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -945,7 +940,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?; let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product(); let size: usize = shape.iter().product();
@ -974,7 +969,7 @@ pub fn call_where_cond_strided(
encoder.use_resource(right, metal::MTLResourceUsage::Read); encoder.use_resource(right, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1001,7 +996,7 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1024,7 +1019,7 @@ pub fn call_index_select(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1053,7 +1048,7 @@ pub fn call_gather(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1076,7 +1071,7 @@ pub fn call_gather(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1105,7 +1100,7 @@ pub fn call_scatter_add(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1128,7 +1123,7 @@ pub fn call_scatter_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1158,7 +1153,7 @@ pub fn call_index_add(
let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?; let pipeline = kernels.load_pipeline(device, Source::Indexing, name)?;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
@ -1182,7 +1177,7 @@ pub fn call_index_add(
encoder.use_resource(ids, metal::MTLResourceUsage::Read); encoder.use_resource(ids, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
} }
@ -1386,7 +1381,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes; let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
@ -1426,11 +1421,12 @@ pub fn call_gemm(
height: 1, height: 1,
depth: 1, depth: 1,
}; };
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_size, group_size); encoder.dispatch_thread_groups(grid_size, group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1455,7 +1451,7 @@ pub fn call_im2col1d_strided(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1475,7 +1471,7 @@ pub fn call_im2col1d_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1505,7 +1501,7 @@ pub fn call_im2col_strided(
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1527,7 +1523,7 @@ pub fn call_im2col_strided(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding(); encoder.end_encoding();
Ok(()) Ok(())
@ -1553,7 +1549,7 @@ pub fn call_upsample_nearest_2d(
let scale_h = shape[3] as f32 / out_h as f32; let scale_h = shape[3] as f32 / out_h as f32;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el); let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence); encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
set_params!( set_params!(
encoder, encoder,
@ -1571,179 +1567,7 @@ pub fn call_upsample_nearest_2d(
encoder.use_resource(input, metal::MTLResourceUsage::Read); encoder.use_resource(input, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size); encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
//encoder.update_fence(&kernels.fence); encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
Q4_0,
Q4_1,
Q5_0,
Q5_1,
Q8_0,
Q8_1,
Q2K,
Q3K,
Q4K,
Q5K,
Q6K,
Q8K,
F16,
F32,
}
pub fn call_quantized_matmul_t(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
dtype: GgmlDType,
(b, m, n, k): (usize, usize, usize, usize),
lhs: &Buffer,
lhs_offset: usize,
rhs: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
// Everything is in reverse
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let nb10 = 0i64;
let nb11 = 0i64;
let nb12 = 0i64;
let ne0 = n as i64;
let ne1 = m as i64;
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;
let (nth0, nth1, align) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
// Original implem uses rows
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
};
let thread_groups_count = MTLSize {
width: divide(ne01 as usize, align),
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
let name = match dtype {
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder();
//encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
rhs,
(lhs, lhs_offset),
output,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
//encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
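The removed launcher above sizes its grid with a `divide` helper that this hunk never shows. A ceil-division like the following is an assumption consistent with its use for `thread_groups_count.width`, where groups of `align` rows must cover all `ne01` rows:

// Assumed helper (not shown in the hunk): round-up division so that
// divide(n, align) * align >= n, i.e. no row is left without a threadgroup.
fn divide(n: usize, align: usize) -> u64 {
    ((n + align - 1) / align) as u64
}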

File diff suppressed because it is too large

View File

@@ -110,6 +110,7 @@ UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY_OP(recip)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
@@ -128,7 +129,6 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
-BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)

View File

@@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
+candle = { workspace = true }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@@ -20,7 +20,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
+candle-metal-kernels = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }

View File

@@ -222,10 +222,7 @@ impl Benchmark for QMatMul {
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let zeros = vec![candle::quantized::k_quants::BlockQ4_0::zeros(); 4096 * 11008 / 32];
-let mm = candle::quantized::QTensor::new(
-candle::quantized::QStorage::Cpu(Box::new(zeros)),
-(4096, 11008),
-)?;
+let mm = candle::quantized::QTensor::new(zeros, (4096, 11008))?;
let mm = candle::quantized::QMatMul::from_qtensor(mm)?;
let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
Ok((mm, arg))
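For reference, the base side's three-line construction corresponds to this self-contained sketch; the shapes are the benchmark's own and `forward` is assumed from candle's `Module` trait:

// Minimal sketch of the base-side API: zeroed Q4_0 blocks wrapped in CPU
// storage, then used as a quantized matmul over a random f32 input.
use candle::quantized::{k_quants::BlockQ4_0, QMatMul, QStorage, QTensor};
use candle::{Device, Module, Tensor};

fn qmatmul_smoke_test() -> candle::Result<Tensor> {
    let zeros = vec![BlockQ4_0::zeros(); 4096 * 11008 / 32];
    let mm = QTensor::new(QStorage::Cpu(Box::new(zeros)), (4096, 11008))?;
    let mm = QMatMul::from_qtensor(mm)?;
    let arg = Tensor::randn(0f32, 1., (128, 11008), &Device::Cpu)?;
    mm.forward(&arg) // applies the (4096, 11008) quantized weight
}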

View File

@@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { path = "../candle-core", package = "candle-core" }
+candle-nn = { path = "../candle-nn" }
prost = "0.12.1"
[build-dependencies]
@@ -20,4 +20,3 @@ prost-build = "0.12.1"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
clap = { version = "4.2.4", features = ["derive"] }

View File

@@ -15,9 +15,9 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
-candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-onnx = { workspace = true, optional = true }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }

View File

@@ -33,9 +33,7 @@ def has_mkl() -> bool:
pass
@staticmethod
-def load_ggml(
-    path: Union[str, PathLike], device: Optional[Device] = None
-) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
+def load_ggml(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]:
"""
Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
@@ -43,9 +41,7 @@ def load_ggml(
pass
@staticmethod
-def load_gguf(
-    path: Union[str, PathLike], device: Optional[Device] = None
-) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
+def load_gguf(path: Union[str, PathLike]) -> Tuple[Dict[str, QTensor], Dict[str, Any]]:
"""
Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
and the second maps metadata keys to metadata values.

View File

@@ -1074,20 +1074,20 @@ impl PyTensor {
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype.to_lowercase().as_str() {
-"q2k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q2K),
-"q3k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q3K),
-"q4_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_0),
-"q4_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4_1),
-"q4k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q4K),
-"q5_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_0),
-"q5_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5_1),
-"q5k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q5K),
-"q6k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q6K),
-"q8_0" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_0),
-"q8_1" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8_1),
-"q8k" => quantized::QTensor::quantize(self, quantized::GgmlDType::Q8K),
-"f16" => quantized::QTensor::quantize(self, quantized::GgmlDType::F16),
-"f32" => quantized::QTensor::quantize(self, quantized::GgmlDType::F32),
+"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
+"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
+"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
+"q4_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_1>(self),
+"q4k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4K>(self),
+"q5_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_0>(self),
+"q5_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5_1>(self),
+"q5k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ5K>(self),
+"q6k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ6K>(self),
+"q8_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_0>(self),
+"q8_1" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8_1>(self),
+"q8k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ8K>(self),
+"f16" => quantized::QTensor::quantize::<f16>(self),
+"f32" => quantized::QTensor::quantize::<f32>(self),
dt => {
return Err(PyErr::new::<PyValueError, _>(format!(
"unknown quantized-dtype {dt}"
@@ -1278,19 +1278,13 @@ fn save_safetensors(
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors,
/// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]]
-fn load_ggml(
-    path: &str,
-    device: Option<PyDevice>,
-    py: Python<'_>,
-) -> PyResult<(PyObject, PyObject, PyObject)> {
+fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
-let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
-let ggml =
-    ::candle::quantized::ggml_file::Content::read(&mut file, &device).map_err(wrap_err)?;
+let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
let tensors = ggml
.tensors
.into_iter()
@@ -1319,16 +1313,11 @@ fn load_ggml(
}
#[pyfunction]
-#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")]
+#[pyo3(text_signature = "(path:Union[str,PathLike])")]
/// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors,
/// and the second maps metadata keys to metadata values.
/// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]]
-fn load_gguf(
-    path: &str,
-    device: Option<PyDevice>,
-    py: Python<'_>,
-) -> PyResult<(PyObject, PyObject)> {
-let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
+fn load_gguf(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
use ::candle::quantized::gguf_file;
fn gguf_value_to_pyobject(v: &gguf_file::Value, py: Python<'_>) -> PyResult<PyObject> {
let v: PyObject = match v {
@@ -1360,7 +1349,7 @@ fn load_gguf(
.tensor_infos
.keys()
.map(|key| {
-let qtensor = gguf.tensor(&mut file, key, &device)?;
+let qtensor = gguf.tensor(&mut file, key)?;
Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py)))
})
.collect::<::candle::Result<Vec<_>>>()
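Both call shapes in this hunk perform the same quantization; a sketch of each, with the head-side form left as a comment since only one exists in a given checkout:

// Sketch: quantizing an f32 tensor to Q4_0 under each side's API.
use candle::quantized::{GgmlDType, QTensor};

fn to_q4_0(t: &candle::Tensor) -> candle::Result<QTensor> {
    // Base side: the target dtype is an ordinary runtime argument.
    QTensor::quantize(t, GgmlDType::Q4_0)
    // Head side: the dtype is a type parameter instead.
    // QTensor::quantize::<candle::quantized::k_quants::BlockQ4_0>(t)
}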

View File

@@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
-candle-nn = { path = "../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-flash-attn = { workspace = true, optional = true }
+candle-nn = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }

View File

@@ -356,7 +356,6 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
-device: &Device,
) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) {
@@ -384,28 +383,21 @@ impl ModelWeights {
.unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
-let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
+let tok_embeddings = ct.tensor(reader, "token_embd.weight")?;
let tok_embeddings = tok_embeddings.dequantize(cpu)?;
-let norm = RmsNorm::new(
-    ct.tensor(reader, "output_norm.weight", device)?,
-    rms_norm_eps,
-)?;
-let output = ct.tensor(reader, "output.weight", device)?;
+let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight")?, rms_norm_eps)?;
+let output = ct.tensor(reader, "output.weight")?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
let prefix = format!("blk.{layer_idx}");
-let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
-let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
-let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
-let attention_wo =
-    ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
+let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"))?;
+let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
+let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
+let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let mlp_or_moe = if n_expert <= 1 {
-let feed_forward_w1 =
-    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-let feed_forward_w2 =
-    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
-let feed_forward_w3 =
-    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
+let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
+let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
MlpOrMoe::Mlp(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -413,15 +405,15 @@ impl ModelWeights {
})
} else {
let feed_forward_gate_inp =
-    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
+    ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
let mut experts = Vec::with_capacity(n_expert);
for i in 0..n_expert {
let feed_forward_w1 =
-    ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"), device)?;
+    ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
let feed_forward_w2 =
-    ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"), device)?;
+    ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
let feed_forward_w3 =
-    ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"), device)?;
+    ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
experts.push(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
@@ -434,9 +426,8 @@ impl ModelWeights {
experts,
}
};
-let attention_norm =
-    ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
-let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
+let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
+let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
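The net effect of this hunk is that `from_gguf` no longer threads a device through every `ct.tensor(...)` call. A caller-side sketch of the base form, with the head form commented; the file name is hypothetical:

// Sketch: loading one tensor out of a GGUF file under each side's API.
use candle::quantized::gguf_file;

fn peek_embeddings(device: &candle::Device) -> candle::Result<()> {
    let mut file = std::fs::File::open("model.gguf")?; // hypothetical path
    let ct = gguf_file::Content::read(&mut file)?;
    let _t = ct.tensor(&mut file, "token_embd.weight", device)?; // base side
    // let _t = ct.tensor(&mut file, "token_embd.weight")?;      // head side
    Ok(())
}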

View File

@@ -165,13 +165,9 @@ impl Attention {
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
-// let (query_states1, key_states1) =
-//     self.rotary_emb
-//         .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
-// println!("{query_states:?} {query_states1:?}");
-// println!("{key_states:?} {key_states1:?}");
-let query_states = query_states.contiguous()?;
-let key_states = key_states.contiguous()?;
+let (query_states, key_states) =
+    self.rotary_emb
+        .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),

View File

@@ -199,10 +199,7 @@ impl MHA {
Some((prev_k, _)) => prev_k.dim(1)?,
};
// In the python implementation, a single tensor is returned with the third axis of size 3.
-// let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
-let q = qkv.i((.., .., 0))?;
-let k = qkv.i((.., .., 1))?;
-let v = qkv.i((.., .., 2))?;
+let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
@@ -314,7 +311,7 @@ impl MixFormerSequentialForCausalLM {
let mut blocks = Vec::new();
for i in 0..cfg.n_layer {
let block = ParallelBlock::new(cfg, vb.pp(i + 1))?;
-blocks.push(block);
+blocks.push(block)
}
let head = CausalLMHead::new(cfg, vb.pp(cfg.n_layer + 1))?;
Ok(Self {
@@ -335,7 +332,7 @@ impl MixFormerSequentialForCausalLM {
Some(get_mask(seq_len, xs.device())?)
};
for block in self.blocks.iter_mut() {
-xs = block.forward(&xs, mask.as_ref())?;
+xs = block.forward(&xs, mask.as_ref())?
}
xs.narrow(1, seq_len - 1, 1)?.apply(&self.head)?.squeeze(1)
}

View File

@@ -10,33 +10,33 @@ pub struct VarBuilder {
}
impl VarBuilder {
-pub fn from_gguf<P: AsRef<std::path::Path>>(p: P, device: &Device) -> Result<Self> {
+pub fn from_gguf<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let mut file = std::fs::File::open(p)?;
let content = candle::quantized::gguf_file::Content::read(&mut file)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
-let tensor = content.tensor(&mut file, tensor_name, device)?;
+let tensor = content.tensor(&mut file, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
-device: device.clone(),
+device: Device::Cpu,
})
}
-pub fn from_gguf_buffer(buffer: &[u8], device: &Device) -> Result<Self> {
+pub fn from_gguf_buffer(buffer: &[u8]) -> Result<Self> {
let mut cursor = std::io::Cursor::new(buffer);
let content = candle::quantized::gguf_file::Content::read(&mut cursor)?;
let mut data = std::collections::HashMap::new();
for tensor_name in content.tensor_infos.keys() {
-let tensor = content.tensor(&mut cursor, tensor_name, device)?;
+let tensor = content.tensor(&mut cursor, tensor_name)?;
data.insert(tensor_name.to_string(), Arc::new(tensor));
}
Ok(Self {
data: Arc::new(data),
path: Vec::new(),
-device: device.clone(),
+device: Device::Cpu,
})
}
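The same device-threading difference surfaces one level up in `VarBuilder`; a short usage sketch with a hypothetical file name:

use candle_transformers::quantized_var_builder::VarBuilder;

fn open_weights(device: &candle::Device) -> candle::Result<VarBuilder> {
    VarBuilder::from_gguf("model.gguf", device) // base side; hypothetical path
    // VarBuilder::from_gguf("model.gguf")      // head side: CPU placement
}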

View File

@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@@ -61,7 +61,7 @@ impl Model {
let start = Date::now();
let model: SelectedModel = if quantized {
-let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights, &device)?;
+let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
SelectedModel::Q(model)
} else {

View File

@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -31,7 +31,7 @@ js-sys = "0.3.64"
wasm-bindgen = "0.2.87"
wasm-bindgen-futures = "0.4.37"
wasm-logger = "0.2"
-yew-agent = "0.2.0"
+yew-agent = "0.3.0"
yew = { version = "0.20.0", features = ["csr"] }
[dependencies.web-sys]

View File

@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@@ -41,7 +41,6 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
-let device = Device::Cpu;
let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
@@ -51,9 +50,8 @@ impl Model {
let start = Date::now();
console_log!("weights len: {:?}", weights.len());
let model = if quantized {
-let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
-    &weights, &device,
-)?;
+let vb =
+    candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
console_log!("weights loaded");
if name._name_or_path == "microsoft/phi-2" {
let model = QMixFormer::new_v2(&config, vb)?;

View File

@@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
num-traits = { workspace = true }
# App crates.

View File

@@ -9,9 +9,9 @@ categories.workspace = true
 license.workspace = true

 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@@ -7,7 +7,6 @@ pub use candle_transformers::models::quantized_t5::{
 use candle_wasm_example_t5::console_log;
 use tokenizers::Tokenizer;
 use wasm_bindgen::prelude::*;
-const DEVICE: Device = Device::Cpu;

 #[wasm_bindgen]
 pub struct ModelEncoder {
@@ -32,7 +31,7 @@ impl ModelConditionalGeneration {
     ) -> Result<ModelConditionalGeneration, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         let tokenizer =
             Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
@@ -47,7 +46,7 @@ impl ModelConditionalGeneration {
     pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
         let input: ConditionalGenerationParams =
             serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
-        let device = &DEVICE;
+        let device = &Device::Cpu;
         self.model.clear_kv_cache();
         let mut output_token_ids = [self.config.pad_token_id as u32].to_vec();
         let prompt = input.prompt;
@@ -129,7 +128,7 @@ impl ModelEncoder {
     ) -> Result<ModelEncoder, JsError> {
         console_error_panic_hook::set_once();
         console_log!("loading model");
-        let vb = VarBuilder::from_gguf_buffer(&weights, &DEVICE)?;
+        let vb = VarBuilder::from_gguf_buffer(&weights)?;
         let mut config: Config = serde_json::from_slice(&config)?;
         config.use_cache = false;
         let tokenizer =
@@ -139,7 +138,7 @@ impl ModelEncoder {
     }

     pub fn decode(&mut self, input: JsValue) -> Result<JsValue, JsError> {
-        let device = &DEVICE;
+        let device = &Device::Cpu;
         let input: DecoderParams =
             serde_wasm_bindgen::from_value(input).map_err(|m| JsError::new(&m.to_string()))?;
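
With the module-level const DEVICE removed, the device is now constructed inline at each use site, as in the `let device = &Device::Cpu;` lines above. A small sketch of the resulting pattern (the function name and shapes are illustrative only, not part of this diff):

    use candle::{Device, Tensor};

    // Hedged sketch: build an input tensor on an inline CPU device.
    fn tokens_to_tensor(token_ids: &[u32]) -> candle::Result<Tensor> {
        let device = &Device::Cpu;
        // Shape [1, seq_len]: add a batch dimension for the model.
        Tensor::new(token_ids, device)?.unsqueeze(0)
    }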

View File

@@ -9,9 +9,9 @@ categories.workspace = true
 license.workspace = true

 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
-candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
+candle-transformers = { workspace = true }
 num-traits = { workspace = true }
 tokenizers = { workspace = true, features = ["unstable_wasm"] }
@@ -31,7 +31,7 @@ js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 wasm-bindgen-futures = "0.4.37"
 wasm-logger = "0.2"
-yew-agent = "0.2.0"
+yew-agent = "0.3.0"
 yew = { version = "0.20.0", features = ["csr"] }

 [dependencies.web-sys]

View File

@@ -315,7 +315,6 @@ impl Decoder {
         let model = if md.quantized {
             let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(
                 &md.weights,
-                &device,
             )?;
             Model::Quantized(m::quantized_model::Whisper::load(&vb, config)?)
         } else {

View File

@@ -9,8 +9,8 @@ categories.workspace = true
 license.workspace = true

 [dependencies]
-candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
-candle-nn = { path = "../../candle-nn", version = "0.3.3" }
+candle = { workspace = true }
+candle-nn = { workspace = true }
 num-traits = { workspace = true }
 serde = { workspace = true }
 serde_json = { workspace = true }
@@ -31,7 +31,7 @@ js-sys = "0.3.64"
 wasm-bindgen = "0.2.87"
 wasm-bindgen-futures = "0.4.37"
 wasm-logger = "0.2"
-yew-agent = "0.2.0"
+yew-agent = "0.3.0"
 yew = { version = "0.20.0", features = ["csr"] }

 [dependencies.web-sys]

View File

@@ -7,7 +7,7 @@ keywords.workspace = true
 categories.workspace = true

 [dependencies]
-candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
+candle = { workspace = true }
 rand = { workspace = true }
 getrandom = { version = "0.2", features = ["js"] }

View File

@@ -40,7 +40,7 @@ fn quantized_matmul_neg() -> Result<()> {
         ]
     );
-    let qtensor = quantized::QTensor::new(quantized::QStorage::Cpu(Box::new(rhs_t)), (4, 64))?;
+    let qtensor = quantized::QTensor::new(rhs_t, (4, 64))?;
     let matmul = quantized::QMatMul::from_qtensor(qtensor)?;
     let res = matmul.forward(&tensor_lhs)?;
     assert_eq!(
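
The updated test feeds the quantized block buffer straight to QTensor::new together with its logical shape, rather than wrapping it in a storage enum first. A hedged sketch of that call pattern (using Q4_0 blocks is an assumption for illustration; 4 rows of 64 values means two 32-wide Q4_0 blocks per row):

    use candle::quantized::{self, k_quants, QMatMul};
    use candle::Tensor;

    // Hedged sketch: an already-quantized block buffer goes directly into
    // QTensor::new with the logical (rows, cols) shape, then drives a
    // quantized matmul against an f32 activation tensor.
    fn quantized_matmul(
        lhs: &Tensor,
        blocks: Vec<k_quants::BlockQ4_0>,
    ) -> candle::Result<Tensor> {
        let qtensor = quantized::QTensor::new(blocks, (4, 64))?;
        let matmul = QMatMul::from_qtensor(qtensor)?;
        matmul.forward(lhs)
    }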