Compare commits

..

7 Commits

Author SHA1 Message Date
3f3730b657 Preliminary implementation for the vocos model. 2024-02-14 22:16:09 +01:00
058a910d0e Add a readme for rwkv. (#1712) 2024-02-14 15:31:33 +01:00
26fe162ab5 Custom tokenizer for rwkv. (#1711)
* Custom tokenizer for rwkv.

* Custom tokenizer.

* Getting the tokenizer to work.
2024-02-14 15:11:38 +01:00
121a71e01f Fix the silu cuda kernel. (#1710) 2024-02-14 11:08:18 +01:00
2d5f2a728d Add the RWKV model (v5). (#1707)
* Start adding the RWKV model.

* More of the forward step.

* Handle rescaling.

* FeedForward.

* More work on RWKV.

* Better state tracking.

* Finish a first pass on forward.

* Fix the shape mismatches.

* Do not rescale in f32.

* Rename to rwkv-v5.

* Add the new models to the readme.
2024-02-14 10:58:32 +01:00
68f7655895 Add ConvNeXt-V2 and smaller model variants. (#1709) 2024-02-14 10:53:07 +01:00
b60064780d feat: add silu activation function (#1706)
* feat: add silu activation function

* use silu/arg in grad

* update candle-nn

* use node
2024-02-14 10:27:22 +01:00
40 changed files with 1377 additions and 2855 deletions

View File

@ -75,6 +75,9 @@ We also provide a some command line based examples using state of the art models
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
(English/Chinese) general LLMs with 6b and 34b parameters.
@ -193,6 +196,8 @@ If you have an addition to this list, please submit a pull request.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5.
- RWKV.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
@ -210,7 +215,8 @@ If you have an addition to this list, please submit a pull request.
- BLIP.
- TrOCR.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.

View File

@ -380,6 +380,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vvtanh(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vvexpf(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vvexp(y.as_mut_ptr(), y.as_ptr(), &(y.len() as i32)) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -402,6 +412,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
#[inline]

View File

@ -589,6 +589,13 @@ impl Tensor {
let relu_grad = arg.ge(&arg.zeros_like()?)?.to_dtype(arg.dtype())?;
*sum_grad = sum_grad.add(&(&grad * relu_grad)?)?
}
Op::Unary(arg, UnaryOp::Silu) => {
let sum_grad = grads.or_insert(arg)?;
// d/dx silu = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
let sigmoid_arg = (*node / arg)?;
let silu_grad = (&sigmoid_arg * (1. + (arg * (1. - &sigmoid_arg)?)?)?)?;
*sum_grad = sum_grad.add(&(&grad * silu_grad)?)?
}
Op::Elu(arg, alpha) => {
// d/dx elu(x) = 1 for x > 0, alpha * e^x for x <= 0
let sum_grad = grads.or_insert(arg)?;

View File

@ -679,6 +679,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
("uerf", DType::F32) => contiguous::erf::FLOAT,
("usilu", DType::F32) => contiguous::silu::FLOAT,
("uabs", DType::F32) => contiguous::abs::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
@ -696,6 +697,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => contiguous::gelu::HALF,
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
("uerf", DType::F16) => contiguous::erf::HALF,
("usilu", DType::F16) => contiguous::silu::HALF,
("uabs", DType::F16) => contiguous::abs::HALF,
("uceil", DType::F16) => contiguous::ceil::HALF,
("ufloor", DType::F16) => contiguous::floor::HALF,
@ -730,6 +732,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F32) => strided::gelu::FLOAT,
("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT,
("uerf", DType::F32) => strided::erf::FLOAT,
("usilu", DType::F32) => strided::silu::FLOAT,
("uabs", DType::F32) => strided::abs::FLOAT,
("uceil", DType::F32) => strided::ceil::FLOAT,
("ufloor", DType::F32) => strided::floor::FLOAT,
@ -745,6 +748,7 @@ impl BackendStorage for MetalStorage {
("ugelu", DType::F16) => strided::gelu::HALF,
("ugelu_erf", DType::F16) => strided::gelu_erf::HALF,
("uerf", DType::F16) => strided::erf::HALF,
("usilu", DType::F16) => strided::silu::HALF,
("uabs", DType::F16) => strided::abs::HALF,
("uceil", DType::F16) => strided::ceil::HALF,
("ufloor", DType::F16) => strided::floor::HALF,

View File

@ -333,6 +333,16 @@ pub fn vd_tanh_inplace(y: &mut [f64]) {
unsafe { ffi::vdTanh(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_exp_inplace(y: &mut [f32]) {
unsafe { ffi::vsExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vd_exp_inplace(y: &mut [f64]) {
unsafe { ffi::vdExp(y.len() as i32, y.as_ptr(), y.as_mut_ptr()) }
}
#[inline]
pub fn vs_gelu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
@ -355,6 +365,28 @@ pub fn vd_gelu(vs: &[f64], ys: &mut [f64]) {
}
}
#[inline]
pub fn vs_silu(vs: &[f32], ys: &mut [f32]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vs_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
#[inline]
pub fn vd_silu(vs: &[f64], ys: &mut [f64]) {
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = -v
}
vd_exp_inplace(ys);
for (&v, y) in vs.iter().zip(ys.iter_mut()) {
*y = v / (1.0 + *y)
}
}
macro_rules! binary_op {
($fn_name:ident, $ty:ty, $mkl_name:ident) => {
#[inline]

View File

@ -61,6 +61,7 @@ pub enum UnaryOp {
GeluErf,
Erf,
Relu,
Silu,
Tanh,
Floor,
Ceil,
@ -390,6 +391,7 @@ pub(crate) struct Gelu;
pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Silu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
@ -724,6 +726,77 @@ impl UnaryOpT for Erf {
}
}
/// Silu operation
impl UnaryOpT for Silu {
const NAME: &'static str = "silu";
const V: Self = Silu;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v / (bf16::ONE + (-v).exp())
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v / (f16::ONE + (-v).exp())
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v / (1.0 + (-v).exp())
}
#[inline(always)]
fn u8(_: u8) -> u8 {
0
}
#[inline(always)]
fn u32(_: u32) -> u32 {
0
}
#[inline(always)]
fn i64(_: i64) -> i64 {
0
}
const KERNEL: &'static str = "usilu";
#[cfg(feature = "mkl")]
const F32_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::mkl::vs_silu(xs, ys)
}
#[cfg(feature = "mkl")]
const F64_VEC: bool = true;
#[cfg(feature = "mkl")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::mkl::vd_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F32_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f32_vec(xs: &[f32], ys: &mut [f32]) {
crate::accelerate::vs_silu(xs, ys)
}
#[cfg(feature = "accelerate")]
const F64_VEC: bool = true;
#[cfg(feature = "accelerate")]
#[inline(always)]
fn f64_vec(xs: &[f64], ys: &mut [f64]) {
crate::accelerate::vd_silu(xs, ys)
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";

View File

@ -508,6 +508,7 @@ impl Tensor {
unary_op!(gelu_erf, GeluErf);
unary_op!(erf, Erf);
unary_op!(relu, Relu);
unary_op!(silu, Silu);
unary_op!(ceil, Ceil);
unary_op!(floor, Floor);
unary_op!(round, Round);

View File

@ -270,6 +270,19 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000]
);
// testing compared to pytorch nn.Silu()
let y = x.silu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.8577, 0.7311, 3.9281, 0.0806]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0881, 0.9277, 1.0527, 0.5747],
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;

View File

@ -120,6 +120,13 @@ fn unary_op(device: &Device) -> Result<()> {
[0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.silu()?, 4)?,
[
[-0.1423, 0.7311, 3.9281, -0.0475, 0.3112],
[2.53, -0.2553, -0.1205, 1.5447, 2.6395]
]
);
assert_eq!(
test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
[[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]

View File

@ -1,6 +1,7 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) and
[ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](https://arxiv.org/abs/2301.00808).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the

View File

@ -12,38 +12,62 @@ use candle_transformers::models::convnext;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Atto,
Femto,
Pico,
Nano,
Tiny,
Small,
Base,
Large,
AttoV2,
FemtoV2,
PicoV2,
NanoV2,
TinyV2,
BaseV2,
LargeV2,
XLarge,
Huge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Tiny => "tiny",
Self::Small => "small",
Self::Base => "base",
Self::Large => "large",
Self::XLarge => "xlarge",
};
// The XLarge model only has an ImageNet-22K variant
let variant = match self {
Self::XLarge => "fb_in22k_ft_in1k",
_ => "fb_in1k",
Self::Atto => "convnext_atto.d2_in1k",
Self::Femto => "convnext_femto.d1_in1k",
Self::Pico => "convnext_pico.d1_in1k",
Self::Nano => "convnext_nano.d1h_in1k",
Self::Tiny => "convnext_tiny.fb_in1k",
Self::Small => "convnext_small.fb_in1k",
Self::Base => "convnext_base.fb_in1k",
Self::Large => "convnext_large.fb_in1k",
Self::AttoV2 => "convnextv2_atto.fcmae_ft_in1k",
Self::FemtoV2 => "convnextv2_femto.fcmae_ft_in1k",
Self::PicoV2 => "convnextv2_pico.fcmae_ft_in1k",
Self::NanoV2 => "convnextv2_nano.fcmae_ft_in1k",
Self::TinyV2 => "convnextv2_tiny.fcmae_ft_in1k",
Self::BaseV2 => "convnextv2_base.fcmae_ft_in1k",
Self::LargeV2 => "convnextv2_large.fcmae_ft_in1k",
Self::XLarge => "convnext_xlarge.fb_in22k_ft_in1k",
Self::Huge => "convnextv2_huge.fcmae_ft_in1k",
};
format!("timm/convnext_{name}.{variant}")
format!("timm/{name}")
}
fn config(&self) -> convnext::Config {
match self {
Self::Tiny => convnext::Config::tiny(),
Self::Atto | Self::AttoV2 => convnext::Config::atto(),
Self::Femto | Self::FemtoV2 => convnext::Config::femto(),
Self::Pico | Self::PicoV2 => convnext::Config::pico(),
Self::Nano | Self::NanoV2 => convnext::Config::nano(),
Self::Tiny | Self::TinyV2 => convnext::Config::tiny(),
Self::Small => convnext::Config::small(),
Self::Base => convnext::Config::base(),
Self::Large => convnext::Config::large(),
Self::Base | Self::BaseV2 => convnext::Config::base(),
Self::Large | Self::LargeV2 => convnext::Config::large(),
Self::XLarge => convnext::Config::xlarge(),
Self::Huge => convnext::Config::huge(),
}
}
}

View File

@ -0,0 +1,17 @@
## candle-rwkv
The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 version and can be used with Eagle 7B([blog
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
avx: true, neon: false, simd128: false, f16c: true
temp: 0.00 repeat-penalty: 1.10 repeat-last-n: 64
The smallest prime is ϕ(2) = 2.
The smallest composite is ϕ(3) = 3.
The smallest perfect number is ϕ(5) = 5.
The smallest perfect square is ϕ(4) = 4.
The smallest perfect cube is ϕ(6) = 6.
```

View File

@ -0,0 +1,265 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
let mut tokens = self.tokenizer.encode(prompt)?;
let mut generated_tokens = 0usize;
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[[t]], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
print!("{}", self.tokenizer.decode(&[t])?)
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
print!("{}", self.tokenizer.decode(&[next_token])?);
std::io::stdout().flush()?;
let input = Tensor::new(&[[next_token]], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Eagle7b,
World1b5,
World3b,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
Self::World3b => "RWKV/rwkv-5-world-3b",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Eagle7b => "refs/pr/1",
Self::World1b5 | Self::World3b => "refs/pr/2",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "world1b5")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-rwkv".to_string())
.get("rwkv_vocab_v20230424.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::new(tokenizer)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -55,6 +55,11 @@ __device__ __forceinline__ T relu_fwd(T x) {
return maxg(x, zero);
}
template<typename T>
__device__ __forceinline__ T silu_fwd(T x) {
return x / (static_cast<T>(1) + expg(-x));
}
#define UNARY_OP1(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@ -103,6 +108,7 @@ UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, ugelu_erf_bf16, gelu_erf_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
UNARY_OP(__nv_bfloat16, usilu_bf16, silu_fwd(x))
UNARY_OP1(__nv_bfloat16, upowf_bf16, powg(x, param))
#endif
@ -127,6 +133,7 @@ UNARY_OP(__half, ugelu_f16, gelu_fwd(x))
UNARY_OP(__half, ugelu_erf_f16, gelu_erf_fwd(x))
UNARY_OP(__half, urelu_f16, relu_fwd(x))
UNARY_OP1(__half, uelu_f16, elu_fwd(x, param))
UNARY_OP(__half, usilu_f16, silu_fwd(x))
UNARY_OP1(__half, upowf_f16, powg(x, param))
#endif
@ -173,5 +180,7 @@ UNARY_OP(float, urelu_f32, relu_fwd(x))
UNARY_OP(double, urelu_f64, relu_fwd(x))
UNARY_OP1(float, uelu_f32, elu_fwd(x, param))
UNARY_OP1(double, uelu_f64, elu_fwd(x, param))
UNARY_OP(float, usilu_f32, silu_fwd(x))
UNARY_OP(double, usilu_f64, silu_fwd(x))
UNARY_OP1(float, upowf_f32, powg(x, param))
UNARY_OP1(double, upowf_f64, powg(x, param))

View File

@ -1,2 +0,0 @@
xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib

View File

@ -1,317 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
#if defined(__HAVE_BFLOAT__)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
// Check for nan
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
_fp_encoding_traits<float>::inf_mask) {
return uint16_t(as_type<uint32_t>(0x7FC0));
}
// Take bits
uint32_t float_bits = as_type<uint32_t>(x);
// Round to nearest even
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
// Take upper 16 bits
return float_bits >> 16;
}
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
// Upper 16 bits are the data and lower 16 bits are 0s
return as_type<float>((uint32_t)x << 16);
}
struct _MLX_BFloat16;
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////
struct _MLX_BFloat16 {
/////////////////////////////////////////////////////////////////////////////
// Constructors
uint16_t bits_;
_MLX_BFloat16() thread = default;
_MLX_BFloat16() threadgroup = default;
_MLX_BFloat16() device = default;
_MLX_BFloat16() constant = default;
struct bits_to_bfloat_struct {};
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
return bits_to_bfloat_struct();
}
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
: bits_(bits) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions to bfloat
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) device
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions from bfloat
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const thread {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const threadgroup {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const device {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const constant {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
};
/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
return -static_cast<float>(x);
}
/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base( \
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, float, half, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base( \
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, itype rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
} \
constexpr METAL_FUNC addr_space itype& __operator__( \
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
#define bfloat_inplace_op(itype) \
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
bfloat_inplace_op_helper(__op__, __operator__, device); \
bfloat_inplace_op_helper(__op__, __operator__, thread); \
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////
typedef struct _MLX_BFloat16 bfloat16_t;
/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
namespace metal {
template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
static constexpr constant int digits = 8;
static constexpr constant int digits10 = 2;
static constexpr constant int max_digits10 = 4;
static constexpr constant int radix = 2;
static constexpr constant int min_exponent = -125;
static constexpr constant int min_exponent10 = -37;
static constexpr constant int max_exponent = 128;
static constexpr constant int max_exponent10 = 38;
static constexpr bfloat16_t min() {
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t lowest() {
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t max() {
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t epsilon() {
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t round_error() {
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t infinity() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t quiet_NaN() {
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t signaling_NaN() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t denorm_min() {
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
}
};
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
return x != x;
}
} // namespace metal
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#include "gemm/bf16_math.h"

View File

@ -1,394 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "gemm/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
/*
Following the Metal Shading Language Specification (Metal 3.1)
"bfloat is an extended itypeing point type that only allows implicit conversion
to a type of greater itypeing point rank. While bfloat can be implicitly
converted to itype, it cannot be implicitly converted to half, and neither
itype nor half can be implicitly converted to bfloat."
Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling with an argument of type bfloat will result in that
argument getting implicitly converted to itype which then returns an output
that is (likely) a itype which cannot be implicitly converted into a bfloat
This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs return itype
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
For the moment, I will be adding overloaded instantiations of the math
functions to accordingly automatically handle the casting
*/
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
\
METAL_FUNC otype abs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acos(itype x) { \
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acosh(itype x) { \
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asin(itype x) { \
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asinh(itype x) { \
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atan(itype y_over_x) { \
return static_cast<otype>( \
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
} \
METAL_FUNC otype atan2(itype y, itype x) { \
return static_cast<otype>( \
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atanh(itype x) { \
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype ceil(itype x) { \
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cos(itype x) { \
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cosh(itype x) { \
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cospi(itype x) { \
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype divide(itype x, itype y) { \
return static_cast<otype>( \
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype exp(itype x) { \
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp10(itype x) { \
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp2(itype x) { \
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fabs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fdim(itype x, itype y) { \
ctype t = static_cast<ctype>(x - y); \
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
} \
METAL_FUNC otype floor(itype x) { \
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fma(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fma( \
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
} \
METAL_FUNC otype fmax(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmin(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmod(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fract(itype x) { \
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype frexp(itype x, thread int& exp) { \
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
} \
METAL_FUNC otype ldexp(itype x, int k) { \
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
} \
METAL_FUNC otype log(itype x) { \
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log10(itype x) { \
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log2(itype x) { \
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype max(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype max3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype median3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype min(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype min3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype nextafter(itype x, itype y) { \
return static_cast<otype>( \
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
} \
METAL_FUNC otype pow(itype x, itype y) { \
return static_cast<otype>( \
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype powr(itype x, itype y) { \
return static_cast<otype>( \
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype rint(itype x) { \
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype round(itype x) { \
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype rsqrt(itype x) { \
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sin(itype x) { \
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinh(itype x) { \
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinpi(itype x) { \
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sqrt(itype x) { \
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tan(itype x) { \
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanh(itype x) { \
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanpi(itype x) { \
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype trunc(itype x) { \
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
}
namespace metal {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_MAYBE_FAST_MATH__);
namespace fast {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_FAST_MATH__);
} // namespace fast
namespace precise {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_PRECISE_MATH__);
} // namespace precise
} // namespace metal
///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////
#define instantiate_metal_simd_comm_funcs( \
itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
\
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
return ctype_to_otype( \
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
return ctype_to_otype( \
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
return ctype_to_otype( \
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
}
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
\
METAL_FUNC otype simd_max(itype data) { \
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_min(itype data) { \
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_product(itype data) { \
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_sum(itype data) { \
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_xor(itype data) { \
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined(__HAVE_BFLOAT__)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(
bfloat16_t,
bfloat16_t,
uint16_t,
bfloat16_to_uint16,
uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
} // namespace metal

View File

@ -1,131 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
struct complex64_t;
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
!is_same_v<T, complex64_t> &&
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
struct complex64_t {
float real;
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
// Conversions to complex64_t
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) thread : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) device : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) constant : real(x), imag(0) {}
// Conversions from complex64_t
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const thread {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const threadgroup {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const device {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const constant {
return static_cast<T>(real);
}
};
constexpr complex64_t operator-(complex64_t x) {
return {-x.real, -x.imag};
}
constexpr bool operator>=(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}
constexpr bool operator>(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}
constexpr bool operator<=(complex64_t a, complex64_t b) {
return operator>=(b, a);
}
constexpr bool operator<(complex64_t a, complex64_t b) {
return operator>(b, a);
}
constexpr bool operator==(complex64_t a, complex64_t b) {
return a.real == b.real && a.imag == b.imag;
}
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a.real * b.real + a.imag * b.imag;
auto y = a.imag * b.real - a.real * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
if (real != 0 && (real < 0 != b.real < 0)) {
real += b.real;
}
if (imag != 0 && (imag < 0 != b.imag < 0)) {
imag += b.imag;
}
return {real, imag};
}

View File

@ -1,292 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/loader.h"
#include "gemm/mma.h"
#include "gemm/transforms.h"
#include "utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(C, params->ldc);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(C, params->ldc);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "params.h"

View File

@ -1,89 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "gemm/bf16.h"
#include "gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
gemm_kernel::run(
A, B, C,
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant GEMMParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@ -1,254 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
const constant GEMMAddMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel =
GEMMKernel<T, T, BM, BN, BK, WM, WN,
transpose_a, transpose_b,
MN_aligned, K_aligned,
AccumType, Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col * params->fdc;
D += c_row * params->ldd + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(params->alpha, params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMAddMMParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@ -1,280 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device U *C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
if(!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iter_remaining,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device otype *C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
template <typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
D[0] = Epilogue::apply(out);
}
template <typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT *C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
[[kernel]] void gemm_splitk_accum<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype *C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);

View File

@ -1,125 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "utils2.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,264 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/transforms.h"
#include "utils.h"
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,79 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct GEMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int swizzle_log;
const int gemm_k_iterations_aligned;
};
struct GEMMSpiltKParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
const int split_k_partitions;
const int split_k_partition_stride;
const int split_k_partition_size;
const int gemm_k_iterations_aligned;
};
struct GEMMAddMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int ldd;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
const float alpha;
const float beta;
const int fdc;
};
} // namespace steel
} // namespace mlx

View File

@ -1,63 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,276 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
#include "gemm/bf16.h"
#include "gemm/complex.h"
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
template <typename U>
struct Limits {
static const constant U max = metal::numeric_limits<U>::max();
static const constant U min = metal::numeric_limits<U>::min();
static const constant U finite_max = metal::numeric_limits<U>::max();
static const constant U finite_min = metal::numeric_limits<U>::min();
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
template <>
struct Limits<bool> {
static constexpr constant bool max = true;
static constexpr constant bool min = false;
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
inline size_t elem_to_loc(
uint elem,
device const int* shape,
device const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
inline size_t elem_to_loc(
uint elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
return elem * stride;
}
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
// Non templated version to handle arbitrary dims
inline size_t elem_to_loc(
uint3 elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline uint elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides);
template <>
inline uint elem_to_loc_nd<1>(
uint elem,
device const int* shape,
device const size_t* strides) {
return (elem % shape[0]) * strides[0];
}
template <>
inline uint elem_to_loc_nd<2>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Compute ceil((float)N/(float)M) */
inline size_t ceildiv(size_t N, size_t M) {
return (N + M - 1) / M;
}
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
float xp1 = 1.0f + x;
if (xp1 == Limits<float>::max) {
return Limits<float>::max;
}
if (xp1 == 1.0f) {
return x;
}
return x * (metal::log(xp1) / (xp1 - 1.0f));
}
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
if (xp1 == Limits<float>::max) {
return Limits<bfloat16_t>::max;
}
if (xp1 == 1.0f) {
return x;
}
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

View File

@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#include "gemm/host.h"
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@ -1,7 +1,6 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
NSUInteger,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
};
use std::collections::HashMap;
use std::ffi::c_void;
@ -17,7 +16,6 @@ const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply similarly across the tensors
@ -124,7 +122,6 @@ pub enum Source {
Cast,
Reduce,
Mfa,
Gemm,
Conv,
Random,
Quantized,
@ -186,7 +183,7 @@ macro_rules! ops{
pub mod unary {
ops!(
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
tanh, recip
tanh, recip, silu
);
}
pub mod binary {
@ -251,7 +248,6 @@ impl Kernels {
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"),
Source::Gemm => panic!("Invalid lib"),
}
}
@ -275,14 +271,6 @@ impl Kernels {
))
})?
}
Source::Gemm => {
let source_data = GEMM;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
@ -1242,34 +1230,6 @@ impl ConstantValues {
}
}
fn string_to_static_str(s: String) -> &'static str {
Box::leak(s.into_boxed_str())
}
use core::ffi::c_int;
#[repr(C)]
#[derive(Debug)]
struct GEMMParams {
m: c_int,
n: c_int,
k: c_int,
lda: c_int,
ldb: c_int,
ldc: c_int,
tiles_n: c_int,
tiles_m: c_int,
batch_stride_a: c_int,
batch_stride_b: c_int,
batch_stride_c: c_int,
swizzle_log: c_int,
gemm_k_iterations_aligned: c_int,
}
#[allow(clippy::too_many_arguments)]
pub fn call_gemm(
device: &Device,
@ -1291,10 +1251,10 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k {
(false, k as c_int)
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
(true, n as c_int)
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
@ -1302,10 +1262,10 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n {
(false, n as c_int)
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
(true, k as c_int)
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
@ -1313,195 +1273,119 @@ pub fn call_gemm(
mnk: (m, n, k),
})?;
};
// let d_trans = false;
// let alpha = 1.0f32;
// let beta = 0.0f32;
// let batched = b > 1;
// let fused_activation = false;
// let fused_bias = false;
// let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
// let m_simd = 8;
// let n_simd = 8;
// let k_simd = 64;
// let m_splits = 1;
// let n_splits = 1;
// (m_simd, n_simd, k_simd, m_splits, n_splits)
// } else {
// let m_simd = 40;
// let n_simd = 40;
// let k_simd = 32;
// let m_splits = 1;
// let n_splits = 1;
// (m_simd, n_simd, k_simd, m_splits, n_splits)
// };
// let constants = Some(ConstantValues::new(vec![
// (0, Value::USize(m)),
// (1, Value::USize(n)),
// (2, Value::USize(k)),
// (10, Value::Bool(a_trans)),
// (11, Value::Bool(b_trans)),
// (13, Value::Bool(d_trans)),
// (20, Value::F32(alpha)),
// (21, Value::F32(beta)),
// (100, Value::Bool(batched)),
// (101, Value::Bool(fused_activation)),
// // Garbage
// (102, Value::Bool(false)),
// (103, Value::Bool(false)),
// (113, Value::Bool(false)),
// (50_000, Value::Bool(false)),
// // End garbage
// (200, Value::U16(m_simd)),
// (201, Value::U16(n_simd)),
// (202, Value::U16(k_simd)),
// (210, Value::U16(m_splits)),
// (211, Value::U16(n_splits)),
// (50_001, Value::Bool(fused_bias)),
// ]));
let a_trans_name = if a_trans { "t" } else { "n" };
let b_trans_name = if b_trans { "t" } else { "n" };
let (iname, oname) = match name {
"sgemm" => ("float32", "float32"),
"hgemm" => ("float16", "float16"),
"bgemm" => ("bfloat16", "bfloat16"),
let d_trans = false;
let alpha = 1.0f32;
let beta = 0.0f32;
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
let m_simd = 8;
let n_simd = 8;
let k_simd = 64;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
} else {
let m_simd = 40;
let n_simd = 40;
let k_simd = 32;
let m_splits = 1;
let n_splits = 1;
(m_simd, n_simd, k_simd, m_splits, n_splits)
};
let constants = Some(ConstantValues::new(vec![
(0, Value::USize(m)),
(1, Value::USize(n)),
(2, Value::USize(k)),
(10, Value::Bool(a_trans)),
(11, Value::Bool(b_trans)),
(13, Value::Bool(d_trans)),
(20, Value::F32(alpha)),
(21, Value::F32(beta)),
(100, Value::Bool(batched)),
(101, Value::Bool(fused_activation)),
// Garbage
(102, Value::Bool(false)),
(103, Value::Bool(false)),
(113, Value::Bool(false)),
(50_000, Value::Bool(false)),
// End garbage
(200, Value::U16(m_simd)),
(201, Value::U16(n_simd)),
(202, Value::U16(k_simd)),
(210, Value::U16(m_splits)),
(211, Value::U16(n_splits)),
(50_001, Value::Bool(fused_bias)),
]));
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let m_group = m_simd * m_splits;
let n_group = n_simd * n_splits;
let a_block_length = m_group * k_simd;
let b_block_length = k_simd * n_group;
let mut block_elements = a_block_length + b_block_length;
if (m % 8 != 0) && (n % 8 != 0) {
let c_block_length = m_group * n_group;
block_elements = std::cmp::max(c_block_length, block_elements)
}
if fused_bias {
if d_trans {
block_elements = std::cmp::max(block_elements, m_group);
} else {
block_elements = std::cmp::max(block_elements, n_group);
}
}
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
other => {
return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm"
)))
)));
}
};
let mut bm = 32;
let mut bn = 32;
let mut bk = 16;
let wm = 2;
let wn = 2;
if b * m * n >= 1 << 20 {
if !a_trans && b_trans {
bm = 64;
bn = if oname == "float32" { 64 } else { 32 };
bk = if oname == "float32" { 16 } else { 32 };
} else {
bm = 64;
bn = 64;
}
}
let mnaligned = if m % bm == 0 && n % bn == 0 {
"taligned"
} else {
"naligned"
};
let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
// let bytes = match &name[..] {
// "sgemm" => 4,
// "hgemm" => 2,
// other => {
// return Err(MetalKernelError::LoadLibraryError(format!(
// "{other} is not a valid kernel for gemm"
// )));
// }
// };
let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
let name = string_to_static_str(name);
let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
// let m_group = m_simd * m_splits;
// let n_group = n_simd * n_splits;
// let a_block_length = m_group * k_simd;
// let b_block_length = k_simd * n_group;
// let mut block_elements = a_block_length + b_block_length;
// if (m % 8 != 0) && (n % 8 != 0) {
// let c_block_length = m_group * n_group;
// block_elements = std::cmp::max(c_block_length, block_elements)
// }
// if fused_bias {
// if d_trans {
// block_elements = std::cmp::max(block_elements, m_group);
// } else {
// block_elements = std::cmp::max(block_elements, n_group);
// }
// }
// let block_bytes = block_elements * bytes;
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
// encoder.set_threadgroup_memory_length(0, block_bytes.into());
let batch_stride_a: i32 = if lhs_stride.len() > 2 {
lhs_stride[lhs_stride.len() - 3] as i32
} else {
0
};
let batch_stride_b: i32 = if rhs_stride.len() > 2 {
rhs_stride[rhs_stride.len() - 3] as i32
} else {
0
};
let batch_stride_c = (m * n) as i32;
let swizzle_log = 0;
let tiles_n = ((n + bn - 1) / bn) as c_int;
let tiles_m = ((m + bm - 1) / bm) as c_int;
let params = GEMMParams {
m: m as c_int,
n: n as c_int,
k: k as c_int,
lda,
ldb,
ldc: n as c_int,
tiles_m,
tiles_n,
batch_stride_a,
batch_stride_b,
batch_stride_c,
swizzle_log,
gemm_k_iterations_aligned: (k / bk) as c_int,
};
let params_buffer = device.new_buffer_with_data(
&params as *const GEMMParams as *const c_void,
core::mem::size_of::<GEMMParams>() as u64,
MTLResourceOptions::StorageModeShared,
);
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0);
encoder.set_buffer(3, Some(&params_buffer), 0);
// TODO Tensor D
let grid_z = b;
// if batched {
// let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
// let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
// let byte_stride_c = m * n * bytes as usize;
// // TODO byte_stride_d
// let byte_stride_d = 0;
if batched {
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
let byte_stride_c = m * n * bytes as usize;
// TODO byte_stride_d
let byte_stride_d = 0;
// let buffer: Vec<u64> = vec![
// byte_stride_a as _,
// byte_stride_b as _,
// byte_stride_c as _,
// byte_stride_d as _,
// ];
// // encoder.set_bytes(
// // 10,
// // (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
// // buffer.as_ptr() as *const NSUInteger as *const c_void,
// // );
// }
let tile = 1 << swizzle_log;
let tm = (tiles_m + tile - 1) / tile;
let tn = tiles_n * tile;
let buffer: Vec<u64> = vec![
byte_stride_a as _,
byte_stride_b as _,
byte_stride_c as _,
byte_stride_d as _,
];
encoder.set_bytes(
10,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
buffer.as_ptr() as *const NSUInteger as *const c_void,
);
}
let grid_size = MTLSize {
width: tn as u64,
height: tm as u64,
width: divide(n, n_group.into()),
height: divide(m, m_group.into()),
depth: grid_z as NSUInteger,
};
let group_size = MTLSize {
width: 32,
height: wn,
depth: wm,
width: 32 * (m_splits as u64) * (n_splits as u64),
height: 1,
depth: 1,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);

View File

@ -231,6 +231,25 @@ fn gelu_f32() {
assert_eq!(approx(results, 3), expected);
}
#[test]
fn silu_f16() {
let v: Vec<f16> = [-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let expected: Vec<f32> = vec![-0.0, -0.27, 0.0, 0.73, 1.76, 2.86, 10.0, 20.0];
let results = run(&v, unary::contiguous::silu::HALF);
assert_eq!(approx_f16(results, 2), expected);
}
#[test]
fn silu_f32() {
let v: Vec<f32> = vec![-10f32, -1.0, 0., 1., 2., 3., 10.0, 20.0];
let expected: Vec<f32> = vec![-0.0, -0.269, 0.0, 0.731, 1.762, 2.858, 10.0, 20.0];
let results = run(&v, unary::contiguous::silu::FLOAT);
assert_eq!(approx(results, 3), expected);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];

View File

@ -64,6 +64,9 @@ template <typename T> METAL_FUNC T relu(T in){
}
return in;
}
template <typename T> METAL_FUNC T silu(T in){
return in / (static_cast<T>(1) + exp(-in));
}
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
@ -108,6 +111,7 @@ UNARY_OP(neg)
UNARY_OP(exp)
UNARY_OP(log)
UNARY_OP(gelu)
UNARY_OP(silu)
UNARY_OP(abs)
UNARY_OP(ceil)
UNARY_OP(floor)
@ -135,6 +139,7 @@ BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
BFLOAT_UNARY_OP(log)
BFLOAT_UNARY_OP(gelu)
BFLOAT_UNARY_OP(silu)
BFLOAT_UNARY_OP(abs)
BFLOAT_UNARY_OP(ceil)
BFLOAT_UNARY_OP(floor)

View File

@ -30,7 +30,7 @@ impl super::Module for Activation {
Self::Relu => xs.relu(),
Self::Relu2 => xs.relu()?.sqr(),
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Silu => xs.silu(),
Self::Sigmoid => crate::ops::sigmoid(xs),
Self::HardSigmoid => crate::ops::hard_sigmoid(xs),
Self::Swiglu => crate::ops::swiglu(xs),

View File

@ -35,13 +35,12 @@ pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
}
pub fn silu(xs: &Tensor) -> Result<Tensor> {
// TODO: Should we have a specialized op for this?
xs / (xs.neg()?.exp()? + 1.0)?
xs.silu()
}
pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
let xs = xs.chunk(2, candle::D::Minus1)?;
crate::ops::silu(&xs[0])? * &xs[1]
&xs[0].silu()? * &xs[1]
}
pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {

View File

@ -2,10 +2,16 @@
//!
//! See "A ConvNet for the 2020s" Liu et al. 2022
//! <https://arxiv.org/abs/2201.03545>
//! and
//! "ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders" Woo et al. 2023
//! <https://arxiv.org/abs/2301.00808>
//! Original code: https://github.com/facebookresearch/ConvNeXt/
//! Original code:
//! https://github.com/facebookresearch/ConvNeXt/
//! https://github.com/facebookresearch/ConvNeXt-V2/
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
use candle::shape::ShapeWithOneHole;
use candle::{Result, D};
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
@ -13,31 +19,71 @@ use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
pub struct Config {
blocks: [usize; 4],
channels: [usize; 4],
use_conv_mlp: bool,
}
impl Config {
pub fn atto() -> Self {
Self {
blocks: [2, 2, 6, 2],
channels: [40, 80, 160, 320],
use_conv_mlp: true,
}
}
pub fn femto() -> Self {
Self {
blocks: [2, 2, 6, 2],
channels: [48, 96, 192, 384],
use_conv_mlp: true,
}
}
pub fn pico() -> Self {
Self {
blocks: [2, 2, 6, 2],
channels: [64, 128, 256, 512],
use_conv_mlp: true,
}
}
pub fn nano() -> Self {
Self {
blocks: [2, 2, 8, 2],
channels: [80, 160, 320, 640],
use_conv_mlp: true,
}
}
pub fn tiny() -> Self {
Self {
blocks: [3, 3, 9, 3],
channels: [96, 192, 384, 768],
use_conv_mlp: false,
}
}
pub fn small() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [96, 192, 384, 768],
use_conv_mlp: false,
}
}
pub fn base() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [128, 256, 512, 1024],
use_conv_mlp: false,
}
}
pub fn large() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [192, 384, 768, 1536],
use_conv_mlp: false,
}
}
@ -45,8 +91,68 @@ impl Config {
Self {
blocks: [3, 3, 27, 3],
channels: [256, 512, 1024, 2048],
use_conv_mlp: false,
}
}
pub fn huge() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [352, 704, 1408, 2816],
use_conv_mlp: false,
}
}
}
// Layer norm for data in channels-last format.
fn layer_norm_cl(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let norm = layer_norm(dim, 1e-6, vb)?;
Ok(Func::new(move |xs| xs.apply(&norm)))
}
// Layer norm for data in channels-first format.
fn layer_norm_cf(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let norm = layer_norm(dim, 1e-6, vb)?;
Ok(Func::new(move |xs| {
let xs = xs
.permute((0, 2, 3, 1))?
.apply(&norm)?
.permute((0, 3, 1, 2))?;
Ok(xs)
}))
}
// Global response normalization layer
// Based on https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/grn.py
fn convnext2_grn(dim: usize, channels_last: bool, vb: VarBuilder) -> Result<Func<'static>> {
let (shape, spatial_dim, channel_dim) = if channels_last {
((1, 1, 1, ()).into_shape(dim)?, [1, 2], 3)
} else {
((1, (), 1, 1).into_shape(dim)?, [2, 3], 1)
};
let gamma = vb.get(dim, "weight")?.reshape(&shape)?;
let beta = vb.get(dim, "bias")?.reshape(&shape)?;
Ok(Func::new(move |xs| {
let residual = xs;
let gx = xs
.sqr()?
.sum_keepdim(spatial_dim)?
.mean_keepdim(spatial_dim)?
.sqrt()?;
let gxmean = gx.mean_keepdim(channel_dim)?;
let nx = gx.broadcast_div(&(gxmean + 1e-6)?)?;
let xs = xs
.broadcast_mul(&nx)?
.broadcast_mul(&gamma)?
.broadcast_add(&beta)?;
xs + residual
}))
}
// Initial downsampling via a patchify layer.
@ -56,16 +162,9 @@ fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
..Default::default()
};
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
Ok(Func::new(move |xs| {
// The layer norm works with channels-last format.
let xs = xs
.apply(&patchify)?
.permute((0, 2, 3, 1))?
.apply(&norm)?
.permute((0, 3, 1, 2))?;
Ok(xs)
}))
let norm = layer_norm_cf(out_channels, vb.pp(1))?;
Ok(Func::new(move |xs| xs.apply(&patchify)?.apply(&norm)))
}
// Downsampling applied after the stages.
@ -74,31 +173,49 @@ fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
stride: 2,
..Default::default()
};
let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
let norm = layer_norm_cf(dim / 2, vb.pp(0))?;
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
Ok(Func::new(move |xs| {
let xs = xs
.permute((0, 2, 3, 1))?
.apply(&norm)?
.permute((0, 3, 1, 2))?
.apply(&conv)?;
Ok(xs)
}))
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&conv)))
}
// MLP equivalent of pointwise convolutions.
// MLP block from the original paper with optional GRN layer (v2 models).
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
let grn = convnext2_grn(4 * dim, true, vb.pp("grn"));
Ok(Func::new(move |xs| {
let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
if let Ok(g) = &grn {
xs = xs.apply(g)?;
}
xs = xs.apply(&fc2)?;
Ok(xs)
}))
}
// A block consisting of a depthwise convolution, a MLP and layer scaling.
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
// MLP block using pointwise convolutions, with optional GRN layer (v2 models).
fn convnext_conv_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(dim, 4 * dim, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(4 * dim, dim, 1, conv2d_cfg, vb.pp("fc2"))?;
let grn = convnext2_grn(4 * dim, false, vb.pp("grn"));
Ok(Func::new(move |xs| {
let mut xs = xs.apply(&fc1)?.gelu_erf()?;
if let Ok(g) = &grn {
xs = xs.apply(g)?;
}
xs = xs.apply(&fc2)?;
Ok(xs)
}))
}
// A block consisting of a depthwise convolution, a MLP and layer scaling (v1 models only).
fn convnext_block(dim: usize, use_conv_mlp: bool, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
groups: dim,
padding: 3,
@ -106,20 +223,36 @@ fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
};
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
let gamma = vb.get(dim, "gamma");
let gamma = vb.get(dim, "gamma")?;
let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
let (mlp, norm) = if use_conv_mlp {
(
convnext_conv_mlp(dim, vb.pp("mlp"))?,
layer_norm_cf(dim, vb.pp("norm"))?,
)
} else {
(
convnext_mlp(dim, vb.pp("mlp"))?,
layer_norm_cl(dim, vb.pp("norm"))?,
)
};
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs
.apply(&conv_dw)?
.permute((0, 2, 3, 1))?
.apply(&norm)?
.apply(&mlp)?
.broadcast_mul(&gamma)?
.permute((0, 3, 1, 2))?;
let mut xs = xs.apply(&conv_dw)?;
xs = if use_conv_mlp {
xs.apply(&norm)?.apply(&mlp)?
} else {
xs.permute((0, 2, 3, 1))?
.apply(&norm)?
.apply(&mlp)?
.permute((0, 3, 1, 2))?
};
if let Ok(g) = &gamma {
xs = xs.broadcast_mul(&g.reshape((1, (), 1, 1))?)?;
};
xs + residual
}))
@ -137,7 +270,11 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
}
for block_idx in 0..nblocks {
blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
blocks.push(convnext_block(
dim,
cfg.use_conv_mlp,
vb.pp(format!("blocks.{block_idx}")),
)?);
}
Ok(Func::new(move |xs| {
@ -149,8 +286,9 @@ fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func
}))
}
// Classification head.
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
let norm = layer_norm_cl(outputs, vb.pp("norm"))?;
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
}

View File

@ -1,13 +1,12 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct LlamaConfig {
pub hidden_size: usize,
pub intermediate_size: usize,

View File

@ -34,6 +34,7 @@ pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
@ -41,6 +42,7 @@ pub mod t5;
pub mod trocr;
pub mod vgg;
pub mod vit;
pub mod vocos;
pub mod whisper;
pub mod with_tracing;
pub mod wuerstchen;

View File

@ -0,0 +1,409 @@
use super::with_tracing::{layer_norm, linear_no_bias as linear, LayerNorm, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use std::collections::{HashMap, HashSet};
fn default_num_attention_heads() -> usize {
64
}
// https://huggingface.co/RWKV/HF_v5-Eagle-7B/blob/main/configuration_rwkv5.py
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub attention_hidden_size: usize,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
pub head_size: usize,
pub intermediate_size: Option<usize>,
pub layer_norm_epsilon: f64,
pub rescale_every: usize,
}
struct StatePerLayer {
extract_key_value: Tensor,
linear_attention: Tensor,
feed_forward: Tensor,
}
pub struct State {
per_layer: Vec<StatePerLayer>,
pos: usize,
}
impl State {
pub fn new(batch_size: usize, cfg: &Config, dev: &Device) -> Result<Self> {
let mut per_layer = Vec::with_capacity(cfg.num_hidden_layers);
// Certainly a weird convention but taken from modeling_rwkv5.py
let num_attention_heads = cfg.hidden_size / cfg.num_attention_heads;
for _layer_idx in 0..cfg.num_hidden_layers {
let extract_key_value = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
let linear_attention = Tensor::zeros(
(
batch_size,
num_attention_heads,
cfg.hidden_size / num_attention_heads,
cfg.hidden_size / num_attention_heads,
),
DType::F32,
dev,
)?;
let feed_forward = Tensor::zeros((batch_size, cfg.hidden_size), DType::F32, dev)?;
per_layer.push(StatePerLayer {
extract_key_value,
linear_attention,
feed_forward,
});
}
Ok(Self { per_layer, pos: 0 })
}
}
#[derive(Debug, Clone)]
struct SelfAttention {
key: Linear,
receptance: Linear,
value: Linear,
gate: Linear,
output: Linear,
ln_x: candle_nn::GroupNorm,
time_mix_key: Tensor,
time_mix_value: Tensor,
time_mix_receptance: Tensor,
time_decay: Tensor,
time_faaaa: Tensor,
time_mix_gate: Tensor,
layer_id: usize,
n_attn_heads: usize,
}
impl SelfAttention {
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_size = cfg.hidden_size;
let attn_hidden_size = cfg.attention_hidden_size;
let key = linear(hidden_size, attn_hidden_size, vb.pp("key"))?;
let receptance = linear(hidden_size, attn_hidden_size, vb.pp("receptance"))?;
let value = linear(hidden_size, attn_hidden_size, vb.pp("value"))?;
let gate = linear(hidden_size, attn_hidden_size, vb.pp("gate"))?;
let output = linear(attn_hidden_size, hidden_size, vb.pp("output"))?;
let ln_x = candle_nn::group_norm(
hidden_size / cfg.head_size,
hidden_size,
1e-5,
vb.pp("ln_x"),
)?;
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
let time_mix_value = vb.get((1, 1, cfg.hidden_size), "time_mix_value")?;
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
let n_attn_heads = cfg.hidden_size / cfg.head_size;
let time_decay = vb.get((n_attn_heads, cfg.head_size), "time_decay")?;
let time_faaaa = vb.get((n_attn_heads, cfg.head_size), "time_faaaa")?;
let time_mix_gate = vb.get((1, 1, cfg.hidden_size), "time_mix_gate")?;
Ok(Self {
key,
value,
receptance,
gate,
output,
ln_x,
time_mix_key,
time_mix_value,
time_mix_receptance,
time_decay,
time_faaaa,
time_mix_gate,
layer_id,
n_attn_heads,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let h = self.time_decay.dim(0)?;
let (b, t, s) = xs.dims3()?;
let s = s / h;
let (receptance, key, value, gate) = {
// exctract key-value
let shifted = state.per_layer[self.layer_id].extract_key_value.clone();
let shifted = if shifted.rank() == 2 {
shifted.unsqueeze(1)?
} else {
shifted
};
let key = ((xs * &self.time_mix_key)? + &shifted * (1.0 - &self.time_mix_key)?)?;
let value = ((xs * &self.time_mix_value)? + &shifted * (1.0 - &self.time_mix_value)?)?;
let receptance = ((xs * &self.time_mix_receptance)?
+ &shifted * (1.0 - &self.time_mix_receptance)?)?;
let gate = ((xs * &self.time_mix_gate)? + &shifted * (1.0 - &self.time_mix_gate)?)?;
let key = self.key.forward(&key)?;
let value = self.value.forward(&value)?;
let receptance = self.receptance.forward(&receptance)?;
let gate = candle_nn::ops::silu(&self.gate.forward(&gate)?)?;
state.per_layer[self.layer_id].extract_key_value = xs.i((.., t - 1))?;
(receptance, key, value, gate)
};
// linear attention
let mut state_ = state.per_layer[self.layer_id].linear_attention.clone();
let key = key.reshape((b, t, h, s))?.permute((0, 2, 3, 1))?;
let value = value.reshape((b, t, h, s))?.transpose(1, 2)?;
let receptance = receptance.reshape((b, t, h, s))?.transpose(1, 2)?;
let time_decay = self
.time_decay
.exp()?
.neg()?
.exp()?
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let time_faaaa =
self.time_faaaa
.reshape(((), 1, 1))?
.reshape((self.n_attn_heads, (), 1))?;
let mut out: Vec<Tensor> = Vec::with_capacity(t);
for t_ in 0..t {
//
let rt = receptance.i((.., .., t_..t_ + 1))?;
let kt = key.i((.., .., .., t_..t_ + 1))?;
let vt = value.i((.., .., t_..t_ + 1))?;
let at = kt.matmul(&vt)?;
let rhs = (time_faaaa.broadcast_mul(&at)? + &state_)?;
let out_ = rt.matmul(&rhs)?.squeeze(2)?;
state_ = (&at + time_decay.broadcast_mul(&state_))?;
out.push(out_)
}
let out = Tensor::cat(&out, 1)?.reshape((b * t, h * s, 1))?;
let out = out.apply(&self.ln_x)?.reshape((b, t, h * s))?;
let out = (out * gate)?.apply(&self.output)?;
state.per_layer[self.layer_id].linear_attention = state_;
Ok(out)
}
}
#[derive(Debug, Clone)]
struct FeedForward {
time_mix_key: Tensor,
time_mix_receptance: Tensor,
key: Linear,
receptance: Linear,
value: Linear,
layer_id: usize,
}
impl FeedForward {
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let int_size = cfg
.intermediate_size
.unwrap_or(((cfg.hidden_size as f64 * 3.5) as usize) / 32 * 32);
let key = linear(cfg.hidden_size, int_size, vb.pp("key"))?;
let receptance = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("receptance"))?;
let value = linear(int_size, cfg.hidden_size, vb.pp("value"))?;
let time_mix_key = vb.get((1, 1, cfg.hidden_size), "time_mix_key")?;
let time_mix_receptance = vb.get((1, 1, cfg.hidden_size), "time_mix_receptance")?;
Ok(Self {
key,
receptance,
value,
time_mix_key,
time_mix_receptance,
layer_id,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let shifted = &state.per_layer[self.layer_id].feed_forward;
let key = (xs.broadcast_mul(&self.time_mix_key)?
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_key)?)?)?;
let receptance = (xs.broadcast_mul(&self.time_mix_receptance)?
+ shifted.broadcast_mul(&(1.0 - &self.time_mix_receptance)?)?)?;
let key = key.apply(&self.key)?.relu()?.sqr()?;
let value = key.apply(&self.value)?;
let receptance = candle_nn::ops::sigmoid(&receptance.apply(&self.receptance)?)?;
state.per_layer[self.layer_id].feed_forward = xs.i((.., xs.dim(1)? - 1))?;
let xs = (receptance * value)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
struct Block {
pre_ln: Option<LayerNorm>,
ln1: LayerNorm,
ln2: LayerNorm,
attention: SelfAttention,
feed_forward: FeedForward,
}
impl Block {
pub fn new(layer_id: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let ln1 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln1"))?;
let ln2 = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("ln2"))?;
let pre_ln = if layer_id == 0 {
let ln = layer_norm(cfg.hidden_size, cfg.layer_norm_epsilon, vb.pp("pre_ln"))?;
Some(ln)
} else {
None
};
let attention = SelfAttention::new(layer_id, cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(layer_id, cfg, vb.pp("feed_forward"))?;
Ok(Self {
pre_ln,
ln1,
ln2,
attention,
feed_forward,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let xs = match self.pre_ln.as_ref() {
None => xs.clone(),
Some(pre_ln) => xs.apply(pre_ln)?,
};
let attention = self.attention.forward(&xs.apply(&self.ln1)?, state)?;
let xs = (xs + attention)?;
let feed_forward = self.feed_forward.forward(&xs.apply(&self.ln2)?, state)?;
let xs = (xs + feed_forward)?;
Ok(xs)
}
}
#[derive(Debug, Clone)]
pub struct Model {
embeddings: Embedding,
blocks: Vec<Block>,
ln_out: LayerNorm,
head: Linear,
rescale_every: usize,
layers_are_rescaled: bool,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("rwkv");
let embeddings = embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embeddings"))?;
let mut blocks = Vec::with_capacity(cfg.num_hidden_layers);
let vb_b = vb_m.pp("blocks");
for block_index in 0..cfg.num_hidden_layers {
let block = Block::new(block_index, cfg, vb_b.pp(block_index))?;
blocks.push(block)
}
let ln_out = layer_norm(cfg.hidden_size, 1e-5, vb_m.pp("ln_out"))?;
let head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("head"))?;
Ok(Self {
embeddings,
blocks,
ln_out,
head,
rescale_every: cfg.rescale_every,
layers_are_rescaled: false, // This seem to only happen for the f16/bf16 dtypes.
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let (_b_size, _seq_len) = xs.dims2()?;
let mut xs = xs.apply(&self.embeddings)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
xs = block.forward(&xs, state)?;
if self.layers_are_rescaled && (block_idx + 1) % self.rescale_every == 0 {
xs = (xs / 2.)?
}
}
let xs = xs.apply(&self.ln_out)?.apply(&self.head)?;
state.pos += 1;
Ok(xs)
}
}
type Bytes = Vec<u8>;
// https://github.com/BlinkDL/ChatRWKV/blob/095e812aef15a1f74107f6c39d13578a2412dc46/RWKV_v5_demo.py#L14
pub struct Tokenizer {
table: Vec<Vec<Vec<Bytes>>>,
good: Vec<HashSet<u8>>,
idx2token: HashMap<u32, Vec<u8>>,
token2idx: HashMap<Vec<u8>, u32>,
}
impl Tokenizer {
pub fn new<P: AsRef<std::path::Path>>(p: P) -> Result<Self> {
let file = std::fs::File::open(p)?;
let token2idx: HashMap<String, u32> =
serde_json::from_reader(file).map_err(candle::Error::wrap)?;
let token2idx = token2idx
.into_iter()
.map(|(key, value)| (key.into_bytes(), value))
.collect::<HashMap<_, _>>();
let idx2token = token2idx
.iter()
.map(|(key, value)| (*value, key.to_vec()))
.collect::<HashMap<_, _>>();
let max_idx = token2idx.values().copied().max().unwrap_or(0);
let mut table = vec![vec![vec![]; 256]; 256];
let mut good = vec![HashSet::new(); 256];
for idx in (0..(1 + max_idx)).rev() {
let s = match idx2token.get(&idx) {
None => continue,
Some(s) => s,
};
if s.len() >= 2 {
let (s0, s1) = (s[0], s[1]);
table[s0 as usize][s1 as usize].push(s.to_vec());
good[s0 as usize].insert(s1);
}
}
Ok(Self {
table,
good,
idx2token,
token2idx,
})
}
pub fn decode_bytes(&self, tokens: &[u32]) -> Vec<u8> {
let mut v = Vec::new();
for token_id in tokens.iter() {
if let Some(token) = self.idx2token.get(token_id) {
v.extend_from_slice(token.as_slice())
}
}
v
}
pub fn decode(&self, tokens: &[u32]) -> Result<String> {
let bytes = self.decode_bytes(tokens);
String::from_utf8(bytes).map_err(candle::Error::wrap)
}
pub fn encode_bytes(&self, bytes: &[u8]) -> Result<Vec<u32>> {
let mut tokens = Vec::new();
let mut i = 0;
while i < bytes.len() {
let mut s = vec![bytes[i]];
if i + 1 < bytes.len() && self.good[bytes[i] as usize].contains(&bytes[i + 1]) {
let table = &self.table[bytes[i] as usize][bytes[i + 1] as usize];
for table_elem in table.iter() {
if bytes[i..].starts_with(table_elem) {
s = table_elem.to_vec();
break;
}
}
}
i += s.len();
let token = match self.token2idx.get(&s) {
None => candle::bail!("unexpected token '{}' {s:?}", String::from_utf8_lossy(&s)),
Some(token) => *token,
};
tokens.push(token)
}
Ok(tokens)
}
pub fn encode(&self, str: &str) -> Result<Vec<u32>> {
self.encode_bytes(str.as_bytes())
}
}

View File

@ -0,0 +1,156 @@
#![allow(unused)]
use candle::{DType, Module, Result, Tensor, D};
use candle_nn::{conv1d, embedding, linear, Conv1d, Conv1dConfig, Embedding, Linear, VarBuilder};
pub struct AdaLayerNorm {
eps: f64,
dim: usize,
scale: Embedding,
shift: Embedding,
}
fn layer_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
d => d,
};
let hidden_size = x.dim(D::Minus1)?;
let x = x.to_dtype(internal_dtype)?;
let x = {
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
x.broadcast_sub(&mean_x)?
};
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let x_normed = x.broadcast_div(&(norm_x + eps)?.sqrt()?)?;
x_normed.to_dtype(x_dtype)
}
impl AdaLayerNorm {
pub fn new(
num_embeddings: usize,
embedding_dim: usize,
eps: f64,
vb: VarBuilder,
) -> Result<Self> {
let scale = embedding(num_embeddings, embedding_dim, vb.pp("scale"))?;
let shift = embedding(num_embeddings, embedding_dim, vb.pp("shift"))?;
Ok(Self {
eps,
dim: embedding_dim,
scale,
shift,
})
}
pub fn forward(&self, xs: &Tensor, cond_embedding_id: &Tensor) -> Result<Tensor> {
let scale = self.scale.forward(cond_embedding_id)?;
let shift = self.shift.forward(cond_embedding_id)?;
let xs = layer_norm(xs, self.eps)?;
xs * scale + shift
}
}
pub struct ConvNeXtBlock {
dwconv: Conv1d,
pwconv1: Linear,
pwconv2: Linear,
gamma: Option<Tensor>,
}
impl ConvNeXtBlock {
pub fn new(
dim: usize,
intermediate_dim: usize,
layer_scale_init_value: f64,
adanorm_num_embeddings: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let dwconv = {
let cfg = Conv1dConfig {
padding: 3,
groups: dim,
..Default::default()
};
conv1d(dim, dim, 7, cfg, vb.pp("dwconv"))?
};
let pwconv1 = linear(dim, intermediate_dim, vb.pp("pwconv1"))?;
let pwconv2 = linear(intermediate_dim, dim, vb.pp("pwconv2"))?;
let gamma = if layer_scale_init_value > 0. {
Some(vb.get(dim, "gamma")?)
} else {
None
};
Ok(Self {
dwconv,
pwconv1,
pwconv2,
gamma,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.dwconv)?.transpose(1, 2)?;
// TODO: norm
let xs = xs.apply(&self.pwconv1)?.gelu()?.apply(&self.pwconv2)?;
let xs = match self.gamma.as_ref() {
Some(gamma) => (gamma * xs)?,
None => xs,
};
xs.transpose(1, 2)? + residual
}
}
struct VocosBackbone {
embed: Conv1d,
convnext: Vec<ConvNeXtBlock>,
final_layer_norm: candle_nn::LayerNorm,
}
impl VocosBackbone {
pub fn new(
input_channels: usize,
dim: usize,
intermediate_dim: usize,
num_layers: dim,
layer_scale_init_value: f64,
adanorm_num_embeddings: Option<usize>,
vb: VarBuilder,
) -> Result<Self> {
let embed = {
let cfg = Conv1dConfig {
padding: 3,
..Default::default()
};
conv1d(input_channels, dim, 7, cfg, vb.pp("embed"))?
};
let mut convnext = Vec::with_capacity(num_layers);
let vb_c = vb.pp("convnext");
for i in 0..num_layers {
let block = ConvNeXtBlock::new(
dim,
intermediate_dim,
layer_scale_init_value,
adanorm_num_embeddings,
vb_c.pp(i),
)?;
}
let final_layer_norm = candle_nn::layer_norm(dim, 1e-6, vb.pp("final_layer_norm"))?;
Ok(Self {
embed,
convnext,
final_layer_norm,
})
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = xs.apply(&self.embed)?;
// TODO: norm
let mut xs = xs.transpose(1, 2)?;
for conv_block in self.convnext.iter() {
xs = conv_block.forward(&xs)?
}
xs.apply(&self.final_layer_norm)
}
}