Compare commits

..

11 Commits

Author SHA1 Message Date
67d93b4f42 More happy tests. 2024-01-15 18:46:18 +01:00
c35d7d50db Making the CI happy. 2024-01-15 18:31:09 +01:00
9694671bbf Not implementing quantized. 2024-01-15 18:00:43 +01:00
3dbf65ef20 Rebase after phi2 merge + fix replit default to CPU. 2024-01-15 17:52:49 +01:00
b2db5adf82 Bad code removal. 2024-01-15 17:43:00 +01:00
9ef040338d After rebase. 2024-01-15 17:43:00 +01:00
3aefc709c7 Cleanup the fence. 2024-01-15 17:43:00 +01:00
c8c603ce96 Removing the fences speeds everything up and *is* correct this time... 2024-01-15 17:43:00 +01:00
61ad8d91cc Fix the rebase. 2024-01-15 17:43:00 +01:00
2cd1e59c9e Cleanup. 2024-01-15 17:43:00 +01:00
9c4b4f0da0 Metal quantized modifications proposal.
- Add a device param, wherever needed.
- Create new QMetal storage thing that implements QuantizedType.
- Update everywhere needed.

Fix Python.

Fixing examples.

Fix: fmt + clippy + stub.

Moving everything around.

Only missing the actual implems.

Fixing everything + adding dequantized kernels.

More work.

Fixing matmul.

Fmt + Clippy

Some clippy fixes.

Working state.

Q2K Metal -> Bugged (also present in GGML).
Q4K CPU -> Bugged (present previously, new test catch it).
Q5K CPU -> Bugged (present previously).
Q8_1 Both -> Never really implemented it seems
Q8K metal -> Never implemented in metal

Fixing Q2K bug (present in ggml).
2024-01-15 17:42:58 +01:00
105 changed files with 536 additions and 13597 deletions

View File

@ -5,15 +5,49 @@ on:
pull_request: pull_request:
jobs: jobs:
start-runner:
name: Start self-hosted EC2 runner
runs-on: ubuntu-latest
# Don't run on forks, they won't have access to secrets anyway.
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
env:
AWS_REGION: us-east-1
EC2_AMI_ID: ami-03cfed9ea28f4b002
EC2_INSTANCE_TYPE: g5.xlarge
EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
EC2_SECURITY_GROUP: sg-030175c435ac141d6
outputs:
label: ${{ steps.start-ec2-runner.outputs.label }}
ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Start EC2 runner
id: start-ec2-runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: start
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
ec2-image-id: ${{ env.EC2_AMI_ID }}
ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
subnet-id: ${{ env.EC2_SUBNET_ID }}
security-group-id: ${{ env.EC2_SECURITY_GROUP }}
aws-resource-tags: > # optional, requires additional permissions
[
{"Key": "Name", "Value": "ec2-tgi-github-runner"},
{"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
]
test-cuda: test-cuda:
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }} group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true cancel-in-progress: true
runs-on: [single-gpu, nvidia-gpu, t4, ci] needs: start-runner # required to start the main job when the runner is ready
container: runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
options: --gpus 0
if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
permissions: permissions:
contents: write contents: write
packages: write packages: write
@ -24,10 +58,32 @@ jobs:
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: Install dependencies
run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
- name: Install Rust Stable - name: Install Rust Stable
uses: actions-rust-lang/setup-rust-toolchain@v1 run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2 - uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda) - name: Test (cuda)
run: cargo test --features cuda run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
name: Stop self-hosted EC2 runner
needs:
- start-runner
- test-cuda
runs-on: ubuntu-latest
env:
AWS_REGION: us-east-1
if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v1
with:
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
aws-region: ${{ env.AWS_REGION }}
- name: Stop EC2 runner
uses: philschmid/philschmid-ec2-github-runner@main
with:
mode: stop
github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
label: ${{ needs.start-runner.outputs.label }}
ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "0.4.0" version = "0.3.3"
edition = "2021" edition = "2021"
description = "Minimalist ML framework." description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle" repository = "https://github.com/huggingface/candle"
@ -31,14 +31,14 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" } accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] } anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3" byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" } candle = { path = "./candle-core", package = "candle-core" }
candle-datasets = { path = "./candle-datasets", version = "0.4.0" } candle-datasets = { path = "./candle-datasets" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" } candle-flash-attn = { path = "./candle-flash-attn" }
candle-kernels = { path = "./candle-kernels", version = "0.4.0" } candle-kernels = { path = "./candle-kernels" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" } candle-metal-kernels = { path = "./candle-metal-kernels" }
candle-nn = { path = "./candle-nn", version = "0.4.0" } candle-nn = { path = "./candle-nn" }
candle-onnx = { path = "./candle-onnx", version = "0.4.0" } candle-onnx = { path = "./candle-onnx" }
candle-transformers = { path = "./candle-transformers", version = "0.4.0" } candle-transformers = { path = "./candle-transformers" }
clap = { version = "4.2.4", features = ["derive"] } clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false } criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] } cudarc = { version = "0.10.0", features = ["f16"] }
@ -53,12 +53,12 @@ log = "0.4"
memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] } memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
num_cpus = "1.15.0" num_cpus = "1.15.0"
num-traits = "0.2.15" num-traits = "0.2.15"
parquet = { version = "50.0.0" } parquet = { version = "45.0.0" }
rand = "0.8.5" rand = "0.8.5"
rand_distr = "0.4.3" rand_distr = "0.4.3"
rayon = "1.7.0" rayon = "1.7.0"
rusttype = { version = "0.9", default-features = false } rusttype = { version = "0.9", default-features = false }
safetensors = "0.4.1" safetensors = "0.3.1"
serde = { version = "1.0.171", features = ["derive"] } serde = { version = "1.0.171", features = ["derive"] }
serde_plain = "1.0.2" serde_plain = "1.0.2"
serde_json = "1.0.99" serde_json = "1.0.99"

View File

@ -65,9 +65,8 @@ We also provide a some command line based examples using state of the art models
- [Falcon](./candle-examples/examples/falcon/): general LLM. - [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b. - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets. Also supports pre-trained on 1T tokens of English and code datasets.
StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants. - [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
- [Mamba](./candle-examples/examples/mamba/): an inference only
implementation of the Mamba state space model. implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28. better performance than all publicly available 13b models as of 2023-09-28.
@ -112,10 +111,9 @@ We also provide a some command line based examples using state of the art models
evaluation, segmentation). evaluation, segmentation).
- [VGG](./candle-examples/examples/vgg/), - [VGG](./candle-examples/examples/vgg/),
[RepVGG](./candle-examples/examples/repvgg): computer vision models. [RepVGG](./candle-examples/examples/repvgg): computer vision models.
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image. generate captions for an image.
- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
dedicated submodels for hand-writing and printed recognition.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text. model, generates the translated text from the input text.
@ -186,10 +184,10 @@ If you have an addition to this list, please submit a pull request.
- Falcon. - Falcon.
- StarCoder. - StarCoder.
- Phi 1, 1.5, and 2. - Phi 1, 1.5, and 2.
- Mamba, Minimal Mamba - Minimal Mamba
- Mistral 7b v0.1. - Mistral 7b v0.1.
- Mixtral 8x7b v0.1. - Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B. - StableLM-3B-4E1T.
- Replit-code-v1.5-3B. - Replit-code-v1.5-3B.
- Bert. - Bert.
- Yi-6B and Yi-34B. - Yi-6B and Yi-34B.
@ -208,9 +206,8 @@ If you have an addition to this list, please submit a pull request.
- Wurstchen v2. - Wurstchen v2.
- Image to text. - Image to text.
- BLIP. - BLIP.
- TrOCR.
- Computer Vision Models. - Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT. - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
- yolo-v3, yolo-v8. - yolo-v3, yolo-v8.
- Segment-Anything Model (SAM). - Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files. - File formats: load models from safetensors, npz, ggml, or PyTorch files.

View File

@ -2,8 +2,7 @@ mod benchmarks;
use criterion::criterion_main; use criterion::criterion_main;
criterion_main!( criterion_main!(
benchmarks::affine::benches,
benchmarks::matmul::benches, benchmarks::matmul::benches,
benchmarks::random::benches, benchmarks::affine::benches,
benchmarks::where_cond::benches benchmarks::where_cond::benches
); );

View File

@ -1,6 +1,5 @@
pub(crate) mod affine; pub(crate) mod affine;
pub(crate) mod matmul; pub(crate) mod matmul;
pub(crate) mod random;
pub(crate) mod where_cond; pub(crate) mod where_cond;
use candle_core::{Device, Result}; use candle_core::{Device, Result};

View File

@ -1,63 +0,0 @@
use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, Criterion, Throughput};
use std::time::Instant;
fn rand_uniform(a: &Tensor) {
a.rand_like(-1.0, 123.0).unwrap();
}
fn rand_normal(a: &Tensor) {
a.randn_like(100.0, 15.0).unwrap();
}
fn run_random_bench(c: &mut Criterion, device: &Device) {
let b = 1;
let rows = 2048;
let cols = 2048;
let dtype = DType::F32;
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let flops = b * rows * cols * dtype.size_in_bytes();
let mut group = c.benchmark_group(device.bench_name("random_uniform"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_uniform(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
let mut group = c.benchmark_group(device.bench_name("random_normal"));
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |benches| {
benches.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
rand_normal(black_box(&tensor));
}
device.sync().unwrap();
start.elapsed()
})
});
group.finish();
}
fn criterion_benchmark(c: &mut Criterion) {
let handler = BenchDeviceHandler::new().unwrap();
for device in handler.devices {
run_random_bench(c, &device);
}
}
criterion_group!(benches, criterion_benchmark);

View File

@ -196,7 +196,7 @@ fn run_ls(
} }
} }
Format::Pth => { Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?; let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name)); tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() { for tensor_info in tensors.iter() {
println!( println!(

View File

@ -175,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second order // the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment. // derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b); let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach() }; let grad = if do_not_detach { grad } else { grad.detach()? };
if let Some(op) = node.op() { if let Some(op) = node.op() {
match op { match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => { Op::Binary(lhs, rhs, BinaryOp::Add) => {

View File

@ -1149,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
} }
} }
struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
impl<'a> Map2 for ConvTranspose1D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
&self,
inp: &CudaSlice<T>,
inp_l: &Layout,
k: &CudaSlice<T>,
k_l: &Layout,
dev: &CudaDevice,
) -> Result<CudaSlice<T>> {
// Kernel shape: (c_in_k, c_out, l_k)
// Input shape: (b_size, c_in, l_in)
let p = &self.0;
let l_out = p.l_out();
let dst_el = p.c_out * l_out * p.b_size;
let inp = &inp.slice(inp_l.start_offset()..);
let k = &k.slice(k_l.start_offset()..);
let shape = inp_l.shape();
let dims = shape.dims();
let el = shape.elem_count();
// SAFETY: Set later by running the kernel.
let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
let ds = if dims.len() == 3 {
[dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
} else {
crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
};
let ds = dev.htod_copy(ds).w()?;
let params = (
el,
l_out,
p.stride,
p.padding,
p.output_padding,
p.dilation,
&ds,
inp,
k,
&out,
);
// SAFETY: ffi.
unsafe { func.launch(cfg, params) }.w()?;
Ok(out)
}
}
struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D); struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
impl<'a> Map2 for ConvTranspose2D<'a> { impl<'a> Map2 for ConvTranspose2D<'a> {
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>( fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@ -1859,15 +1810,12 @@ impl BackendStorage for CudaStorage {
fn conv_transpose1d( fn conv_transpose1d(
&self, &self,
l: &Layout, _: &Layout,
kernel: &Self, _: &Self,
kernel_l: &Layout, _: &Layout,
params: &crate::conv::ParamsConvTranspose1D, _: &crate::conv::ParamsConvTranspose1D,
) -> Result<Self> { ) -> Result<Self> {
let device = self.device().clone(); todo!()
let slice =
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
Ok(Self { slice, device })
} }
#[cfg(not(feature = "cudnn"))] #[cfg(not(feature = "cudnn"))]

View File

@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
use metal; use metal;
use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, Mutex, RwLock, TryLockError}; use std::sync::{Arc, RwLock, TryLockError};
/// Simple way to catch lock error without /// Simple way to catch lock error without
/// depending on T /// depending on T
@ -102,8 +101,6 @@ pub struct MetalDevice {
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1). /// (strong_count = 1).
buffers: AllocatedBuffers, buffers: AllocatedBuffers,
/// Seed for random number generation.
seed: Arc<Mutex<Buffer>>,
} }
impl std::fmt::Debug for MetalDevice { impl std::fmt::Debug for MetalDevice {
@ -1557,11 +1554,6 @@ impl BackendDevice for MetalDevice {
Ok(val) => val.parse()?, Ok(val) => val.parse()?,
_ => 10, _ => 10,
}; };
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
[299792458].as_ptr() as *const c_void,
4,
MTLResourceOptions::StorageModeManaged,
)));
Ok(Self { Ok(Self {
device, device,
command_queue, command_queue,
@ -1570,10 +1562,13 @@ impl BackendDevice for MetalDevice {
compute_per_buffer, compute_per_buffer,
buffers, buffers,
kernels, kernels,
seed,
}) })
} }
fn set_seed(&self, _seed: u64) -> Result<()> {
crate::bail!("Metal set_seed not implemented")
}
fn location(&self) -> crate::DeviceLocation { fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal { crate::DeviceLocation::Metal {
gpu_id: self.registry_id() as usize, gpu_id: self.registry_id() as usize,
@ -1613,31 +1608,12 @@ impl BackendDevice for MetalDevice {
&self, &self,
shape: &Shape, shape: &Shape,
dtype: DType, dtype: DType,
min: f64, mean: f64,
max: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
let name = match dtype { // TODO is there a better way ?
DType::F32 => "rand_uniform_f32", let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
DType::F16 => "rand_uniform_f16", self.storage_from_cpu_storage(&cpu_storage)
DType::BF16 => "rand_uniform_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_uniform(
&self.device,
&command_buffer,
&self.kernels,
name,
min as f32,
max as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
} }
fn rand_normal( fn rand_normal(
@ -1647,43 +1623,9 @@ impl BackendDevice for MetalDevice {
mean: f64, mean: f64,
stddev: f64, stddev: f64,
) -> Result<Self::Storage> { ) -> Result<Self::Storage> {
let name = match dtype { // TODO is there a better way ?
DType::F32 => "rand_normal_f32", let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
DType::F16 => "rand_normal_f16", self.storage_from_cpu_storage(&cpu_storage)
DType::BF16 => "rand_normal_bf16",
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
};
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
let command_buffer = self.command_buffer()?;
candle_metal_kernels::call_random_normal(
&self.device,
&command_buffer,
&self.kernels,
name,
mean as f32,
stddev as f32,
shape.elem_count(),
&*self.seed.lock().unwrap(),
&buffer,
)
.map_err(MetalError::from)?;
Ok(Self::Storage::new(buffer, self.clone(), dtype))
}
fn set_seed(&self, seed: u64) -> Result<()> {
let seed: u32 = seed.try_into().map_err(|_| {
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
})?;
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
let contents = seed_buffer.contents();
unsafe {
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
}
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
Ok(())
} }
} }

View File

@ -217,13 +217,6 @@ impl Object {
let args = args.remove(1); let args = args.remove(1);
(callable, args) (callable, args)
} }
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
let mut args = args.tuple()?;
args.remove(0).reduce()?
}
_ => (callable, args), _ => (callable, args),
}; };
match callable { match callable {
@ -234,11 +227,13 @@ impl Object {
_ => return Ok(None), _ => return Ok(None),
}; };
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?; let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo { Ok(Some(TensorInfo {
name, name,
dtype, dtype,
layout, layout,
path: format!("{}/{}", dir_name.to_string_lossy(), file_path), path: path.to_string_lossy().into_owned(),
storage_size, storage_size,
})) }))
} }
@ -350,10 +345,8 @@ impl Stack {
module_name, module_name,
class_name, class_name,
} => { } => {
if module_name == "collections" if module_name == "collections" && class_name == "OrderedDict" {
&& (class_name == "OrderedDict" || class_name == "defaultdict") // TODO: have a separate ordered dict.
{
// TODO: have a separate ordered dict and a separate default dict.
Some(Object::Dict(vec![])) Some(Object::Dict(vec![]))
} else { } else {
None None
@ -634,16 +627,9 @@ pub struct TensorInfo {
pub storage_size: usize, pub storage_size: usize,
} }
/// Read the tensor info from a .pth file.
///
/// # Arguments
/// * `file` - The path to the .pth file.
/// * `verbose` - Whether to print debug information.
/// * `key` - Optional key to retrieve `state_dict` from the pth file.
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>( pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P, file: P,
verbose: bool, verbose: bool,
key: Option<&str>,
) -> Result<Vec<TensorInfo>> { ) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?; let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file); let zip_reader = std::io::BufReader::new(file);
@ -665,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
stack.read_loop(&mut reader)?; stack.read_loop(&mut reader)?;
let obj = stack.finalize()?; let obj = stack.finalize()?;
if VERBOSE || verbose { if VERBOSE || verbose {
println!("{obj:#?}"); println!("{obj:?}");
} }
let obj = match obj { let obj = match obj {
Object::Build { callable, args } => match *callable { Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable { Object::Reduce { callable, args: _ } => match *callable {
@ -681,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
}, },
obj => obj, obj => obj,
}; };
// If key is provided, then we need to extract the state_dict from the object.
let obj = if let Some(key) = key {
if let Object::Dict(key_values) = obj {
key_values
.into_iter()
.find(|(k, _)| *k == Object::Unicode(key.to_owned()))
.map(|(_, v)| v)
.ok_or_else(|| E::Msg(format!("key {key} not found")))?
} else {
obj
}
} else {
obj
};
// If the object is a dict, then we can extract the tensor info from it.
// NOTE: We are assuming that the `obj` is state_dict by this stage.
if let Object::Dict(key_values) = obj { if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() { for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) { match value.into_tensor_info(name, &dir_name) {
@ -721,8 +688,8 @@ pub struct PthTensors {
} }
impl PthTensors { impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> { pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?; let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos let tensor_infos = tensor_infos
.into_iter() .into_iter()
.map(|ti| (ti.name.to_string(), ti)) .map(|ti| (ti.name.to_string(), ti))
@ -745,12 +712,10 @@ impl PthTensors {
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?); let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?; let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?; let mut reader = zip.by_name(&tensor_info.path)?;
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
let rank = tensor_info.layout.shape().rank();
// Reading the data is a bit tricky as it can be strided, for now only support the basic // Reading the data is a bit tricky as it can be strided, for now only support the basic
// case and when the tensor is fortran contiguous. // case.
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous { if !tensor_info.layout.is_contiguous() {
crate::bail!( crate::bail!(
"cannot retrieve non-contiguous tensors {:?}", "cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout tensor_info.layout
@ -768,33 +733,13 @@ impl PthTensors {
tensor_info.dtype, tensor_info.dtype,
&mut reader, &mut reader,
)?; )?;
if rank > 1 && is_fortran_contiguous {
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
let tensor = tensor.reshape(shape_reversed)?;
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
let tensor = tensor.permute(dim_indeces_reversed)?;
Ok(Some(tensor))
} else {
Ok(Some(tensor)) Ok(Some(tensor))
} }
} }
}
/// Read all the tensors from a PyTorch pth file with a given key. /// Read all the tensors from a PyTorch pth file.
/// pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
/// # Arguments let pth = PthTensors::new(path)?;
/// * `path` - Path to the pth file.
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
/// contains multiple objects and the state_dict is the one we are interested in.
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
path: P,
key: Option<&str>,
) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path, key)?;
let tensor_names = pth.tensor_infos.keys(); let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len()); let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names { for name in tensor_names {
@ -804,11 +749,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
} }
Ok(tensors) Ok(tensors)
} }
/// Read all the tensors from a PyTorch pth file.
///
/// # Arguments
/// * `path` - Path to the pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
read_all_with_key(path, None)
}

View File

@ -1,43 +0,0 @@
#![allow(unused)]
use super::GgmlDType;
use crate::{Error, MetalDevice, MetalStorage, Result};
pub struct QMetalStorage {
dtype: GgmlDType,
device: MetalDevice,
}
impl QMetalStorage {
pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn dtype(&self) -> GgmlDType {
self.dtype
}
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
pub fn storage_size_in_bytes(&self) -> usize {
0
}
pub fn fwd(
&self,
_self_shape: &crate::Shape,
_storage: &MetalStorage,
_layout: &crate::Layout,
) -> Result<(MetalStorage, crate::Shape)> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -233,7 +233,6 @@ pub struct Content {
pub hparams: HParams, pub hparams: HParams,
pub vocab: Vocab, pub vocab: Vocab,
pub tensors: HashMap<String, super::QTensor>, pub tensors: HashMap<String, super::QTensor>,
pub device: Device,
} }
impl Content { impl Content {
@ -253,13 +252,11 @@ impl Content {
let (name, tensor) = read_one_tensor(reader, magic, device)?; let (name, tensor) = read_one_tensor(reader, magic, device)?;
tensors.insert(name, tensor); tensors.insert(name, tensor);
} }
let device = device.clone();
Ok(Self { Ok(Self {
magic, magic,
hparams, hparams,
vocab, vocab,
tensors, tensors,
device,
}) })
} }

View File

@ -1,6 +1,5 @@
use super::{GgmlDType, QStorage}; use super::{GgmlDType, QStorage};
use crate::backend::BackendStorage; use crate::{DType, MetalDevice, MetalStorage, Result};
use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
use metal::Buffer; use metal::Buffer;
use std::sync::Arc; use std::sync::Arc;
@ -11,28 +10,22 @@ pub struct QMetalStorage {
} }
impl QMetalStorage { impl QMetalStorage {
pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
let size = elem_count * dtype.type_size() / dtype.block_size();
let buffer = device.allocate_zeros(size)?;
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
pub fn dtype(&self) -> GgmlDType { pub fn dtype(&self) -> GgmlDType {
self.dtype self.dtype
} }
pub fn device(&self) -> &MetalDevice {
&self.device
}
pub fn buffer(&self) -> &Buffer { pub fn buffer(&self) -> &Buffer {
&self.buffer &self.buffer
} }
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
Self {
device,
buffer,
dtype,
}
}
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> { pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
let buffer = self.device.new_buffer_managed(self.buffer.length())?; let buffer = self.device.new_buffer_managed(self.buffer.length())?;
let command_buffer = self.device.command_buffer()?; let command_buffer = self.device.command_buffer()?;
@ -137,59 +130,6 @@ impl QMetalStorage {
self.buffer = buffer; self.buffer = buffer;
Ok(()) Ok(())
} }
pub fn storage_size_in_bytes(&self) -> usize {
self.buffer.length() as usize
}
pub fn fwd(
&self,
self_shape: &Shape,
storage: &MetalStorage,
layout: &crate::Layout,
) -> Result<(MetalStorage, Shape)> {
use crate::MetalError;
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
self.dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&self.buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
Ok((dst_storage, dst_shape))
}
} }
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>( pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
@ -211,24 +151,3 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let slice = unsafe { std::slice::from_raw_parts(ptr, n) }; let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec() slice.to_vec()
} }
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
}
}

View File

@ -1,19 +1,16 @@
#[cfg(feature = "metal")]
use crate::{backend::BackendStorage, DType};
use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor}; use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
use k_quants::*; use k_quants::*;
use std::borrow::Cow; use std::borrow::Cow;
#[cfg(target_feature = "avx")] #[cfg(target_feature = "avx")]
pub mod avx; pub mod avx;
mod dummy_metal;
pub mod ggml_file; pub mod ggml_file;
pub mod gguf_file; pub mod gguf_file;
pub mod k_quants; pub mod k_quants;
#[cfg(feature = "metal")] #[cfg(feature = "metal")]
pub mod metal; pub mod metal;
#[cfg(not(feature = "metal"))]
mod metal {
pub use super::dummy_metal::*;
}
#[cfg(target_feature = "neon")] #[cfg(target_feature = "neon")]
pub mod neon; pub mod neon;
#[cfg(target_feature = "simd128")] #[cfg(target_feature = "simd128")]
@ -35,9 +32,19 @@ impl Device {
let storage = dtype.cpu_zeros(elem_count); let storage = dtype.cpu_zeros(elem_count);
Ok(QStorage::Cpu(storage)) Ok(QStorage::Cpu(storage))
} }
#[cfg(feature = "metal")]
Device::Metal(metal) => { Device::Metal(metal) => {
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?; let size = elem_count * dtype.type_size() / dtype.block_size();
Ok(QStorage::Metal(storage)) let buffer = metal.allocate_zeros(size)?;
Ok(QStorage::Metal(metal::QMetalStorage::new(
buffer,
metal.clone(),
dtype,
)))
}
#[cfg(not(feature = "metal"))]
Device::Metal(_metal) => {
crate::bail!("Metal feature not activated");
} }
Device::Cuda(_cuda) => { Device::Cuda(_cuda) => {
crate::bail!("Cuda ggml quantization not supported"); crate::bail!("Cuda ggml quantization not supported");
@ -48,6 +55,7 @@ impl Device {
pub enum QStorage { pub enum QStorage {
Cpu(Box<dyn QuantizedType>), Cpu(Box<dyn QuantizedType>),
#[cfg(feature = "metal")]
Metal(metal::QMetalStorage), Metal(metal::QMetalStorage),
} }
@ -55,6 +63,7 @@ impl QStorage {
fn block_size(&self) -> usize { fn block_size(&self) -> usize {
match self { match self {
QStorage::Cpu(storage) => storage.block_size(), QStorage::Cpu(storage) => storage.block_size(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.dtype().block_size(), QStorage::Metal(storage) => storage.dtype().block_size(),
} }
} }
@ -62,21 +71,16 @@ impl QStorage {
fn dtype(&self) -> GgmlDType { fn dtype(&self) -> GgmlDType {
match self { match self {
QStorage::Cpu(storage) => storage.dtype(), QStorage::Cpu(storage) => storage.dtype(),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.dtype(), QStorage::Metal(storage) => storage.dtype(),
} }
} }
fn device(&self) -> Device {
match self {
QStorage::Cpu(_storage) => Device::Cpu,
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
fn size_in_bytes(&self) -> usize { fn size_in_bytes(&self) -> usize {
match self { match self {
QStorage::Cpu(storage) => storage.storage_size_in_bytes(), QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
QStorage::Metal(storage) => storage.storage_size_in_bytes(), #[cfg(feature = "metal")]
QStorage::Metal(storage) => storage.buffer().length() as usize,
} }
} }
@ -85,6 +89,7 @@ impl QStorage {
(QStorage::Cpu(storage), Storage::Cpu(src)) => { (QStorage::Cpu(storage), Storage::Cpu(src)) => {
storage.from_float(src.as_slice::<f32>()?)?; storage.from_float(src.as_slice::<f32>()?)?;
} }
#[cfg(feature = "metal")]
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?, (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
_ => crate::bail!("Invalid dequantize storage locations do not match"), _ => crate::bail!("Invalid dequantize storage locations do not match"),
} }
@ -94,6 +99,7 @@ impl QStorage {
fn dequantize(&self, elem_count: usize) -> Result<Storage> { fn dequantize(&self, elem_count: usize) -> Result<Storage> {
match self { match self {
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)), QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
#[cfg(feature = "metal")]
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)), QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
} }
} }
@ -106,6 +112,7 @@ impl QStorage {
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) }; let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
Ok(Cow::from(data)) Ok(Cow::from(data))
} }
#[cfg(feature = "metal")]
QStorage::Metal(_storage) => { QStorage::Metal(_storage) => {
crate::bail!("not implemented"); crate::bail!("not implemented");
} }
@ -329,10 +336,6 @@ impl QTensor {
self.storage.dtype() self.storage.dtype()
} }
pub fn device(&self) -> Device {
self.storage.device()
}
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
self.shape.rank() self.shape.rank()
} }
@ -424,7 +427,8 @@ impl crate::CustomOp1 for QTensor {
#[allow(clippy::infallible_destructuring_match)] #[allow(clippy::infallible_destructuring_match)]
let self_storage = match &self.storage { let self_storage = match &self.storage {
QStorage::Cpu(storage) => storage, QStorage::Cpu(storage) => storage,
QStorage::Metal(_) => crate::bail!("Invalid storage"), #[cfg(feature = "metal")]
_ => crate::bail!("Invalid storage"),
}; };
let slice = storage.as_slice::<f32>()?; let slice = storage.as_slice::<f32>()?;
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()]; let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@ -433,16 +437,79 @@ impl crate::CustomOp1 for QTensor {
Ok((crate::CpuStorage::F32(dst_storage), dst_shape)) Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
} }
#[cfg(feature = "metal")]
fn metal_fwd( fn metal_fwd(
&self, &self,
storage: &crate::MetalStorage, storage: &crate::MetalStorage,
layout: &crate::Layout, layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> { ) -> Result<(crate::MetalStorage, Shape)> {
let self_storage = match &self.storage { use crate::MetalError;
QStorage::Metal(metal) => metal,
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let (n, k) = self.shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
let device = storage.device().clone();
let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
let (buffer, dtype) = match &self.storage {
QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
_ => unreachable!("Cannot call metal matmul on non metal QTensor"), _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
}; };
self_storage.fwd(&self.shape, storage, layout) let command_buffer = device.command_buffer()?;
candle_metal_kernels::call_quantized_matmul_t(
device.device(),
&command_buffer,
device.kernels(),
dtype.into(),
(b, m, n, k),
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
buffer,
&dst,
)
.map_err(MetalError::from)?;
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
Ok((dst_storage, dst_shape))
}
}
#[cfg(feature = "metal")]
impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
fn from(value: GgmlDType) -> Self {
match value {
GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
}
} }
} }

View File

@ -804,35 +804,6 @@ impl Tensor {
} }
} }
/// Roll the tensor input along the given dimension.
/// Elements that are shifted beyond the last position are re-introduced at the first position.
///
/// ```rust
/// # use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let tensor = tensor.roll(-1, 0)?;
/// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
where
D: Dim + Clone,
{
let dim = dim.to_index(self.shape(), "roll")?;
let dim_size = self.dim(dim)?;
let shift = shift.rem_euclid(dim_size as i32) as usize;
if shift == 0 {
Ok(self.clone())
} else {
let a = self.narrow(dim, 0, dim_size - shift)?;
let b = self.narrow(dim, dim_size - shift, shift)?;
Tensor::cat(&[&b, &a], dim)
}
}
/// Returns the sum of all elements in the input tensor. The sum is performed over all the /// Returns the sum of all elements in the input tensor. The sum is performed over all the
/// input dimensions. /// input dimensions.
/// ///
@ -1882,9 +1853,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor. /// this new node. The storage of this tensor is shared with the initial tensor.
/// ///
/// If the tensor is already detached from the computation graph, the same tensor is returned. /// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Tensor { pub fn detach(&self) -> Result<Tensor> {
if self.op.is_none() && !self.is_variable { if self.op.is_none() && !self.is_variable {
self.clone() Ok(self.clone())
} else { } else {
let tensor_ = Tensor_ { let tensor_ = Tensor_ {
id: TensorId::new(), id: TensorId::new(),
@ -1895,7 +1866,7 @@ impl Tensor {
dtype: self.dtype, dtype: self.dtype,
device: self.device.clone(), device: self.device.clone(),
}; };
Tensor(Arc::new(tensor_)) Ok(Tensor(Arc::new(tensor_)))
} }
} }

View File

@ -107,10 +107,6 @@ impl Var {
Ok(Self(inner)) Ok(Self(inner))
} }
pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}
pub fn as_tensor(&self) -> &Tensor { pub fn as_tensor(&self) -> &Tensor {
&self.0 &self.0
} }

View File

@ -50,6 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> {
test_utils::to_vec1_round(&res.flatten_all()?, 4)?, test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352] [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
); );
if dev.is_cpu() {
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?; let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
assert_eq!(res.dims(), [1, 2, 7]); assert_eq!(res.dims(), [1, 2, 7]);
assert_eq!( assert_eq!(
@ -59,6 +60,7 @@ fn conv1d(dev: &Device) -> Result<()> {
4.7076, -5.9745, -0.8276, 1.621 4.7076, -5.9745, -0.8276, 1.621
], ],
); );
}
Ok(()) Ok(())
} }

View File

@ -1,37 +0,0 @@
import torch
from collections import OrderedDict
# Write a trivial tensor to a pt file
a= torch.tensor([[1,2,3,4], [5,6,7,8]])
o = OrderedDict()
o["test"] = a
# Write a trivial tensor to a pt file
torch.save(o, "test.pt")
############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")
############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np
# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS']) # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS']) # Should be False
# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)
# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride()) # Stride will reflect the Fortran memory layout
# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
print("3D Tensor saved with Fortran layout.")

View File

@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_with_key() {
let tensors =
candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
.unwrap();
tensors.get("test").unwrap().unwrap();
}
#[test]
fn test_pth_fortran_congiguous() {
let tensors =
candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
assert_eq!(
tensor.to_vec3::<i64>().unwrap(),
[
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
]
);
}

Binary file not shown.

Binary file not shown.

View File

@ -30,9 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true } safetensors = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"] }
tokenizers = { workspace = true, features = ["onig"] } tokenizers = { workspace = true, features = ["onig"] }
cpal= { version = "0.15.2", optional = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }
@ -45,6 +43,7 @@ rusttype = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-chrome = { workspace = true } tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 # Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1" tokio = "1.29.1"
@ -62,7 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"] nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"] onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"] metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
[[example]] [[example]]
name = "llama_multiprocess" name = "llama_multiprocess"
@ -79,7 +77,3 @@ required-features = ["onnx"]
[[example]] [[example]]
name = "onnx_basics" name = "onnx_basics"
required-features = ["onnx"] required-features = ["onnx"]
[[example]]
name = "whisper-microphone"
required-features = ["microphone"]

View File

@ -1,237 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::chatglm::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/chatglm3-6b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("lmz/candle-chatglm".to_string())
.get("chatglm-tokenizer.json")?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::glm3_6b();
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,22 +0,0 @@
# candle-convnext
[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).
This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 84.09%
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
maillot : 0.74%
crash helmet : 0.54%
unicycle, monocycle : 0.44%
```

View File

@ -1,102 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convnext;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Tiny,
Small,
Base,
Large,
XLarge,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::Tiny => "tiny",
Self::Small => "small",
Self::Base => "base",
Self::Large => "large",
Self::XLarge => "xlarge",
};
// The XLarge model only has an ImageNet-22K variant
let variant = match self {
Self::XLarge => "fb_in22k_ft_in1k",
_ => "fb_in1k",
};
format!("timm/convnext_{name}.{variant}")
}
fn config(&self) -> convnext::Config {
match self {
Self::Tiny => convnext::Config::tiny(),
Self::Small => convnext::Config::small(),
Self::Base => convnext::Config::base(),
Self::Large => convnext::Config::large(),
Self::XLarge => convnext::Config::xlarge(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::Tiny)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -2,9 +2,6 @@
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal). This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
Compared to the mamba example, this version can handle training but is much
slower.
## Running the example ## Running the example
```bash ```bash

View File

@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation
Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).
- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).
Compared to the mamba-minimal example, this version is far more efficient but
would only work for inference.
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
```

View File

@ -1,299 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use candle_transformers::models::mamba::{Config, Model, State};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
config: Config,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
config: Config,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
config,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let mut state = State::new(1, &self.config, &self.device)?;
let mut next_logits = None;
for &t in tokens.iter() {
let input = Tensor::new(&[t], &self.device)?;
let logits = self.model.forward(&input, &mut state)?;
next_logits = Some(logits);
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let logits = match next_logits.as_ref() {
Some(logits) => logits,
None => anyhow::bail!("cannot work on an empty prompt"),
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
let input = Tensor::new(&[next_token], &self.device)?;
next_logits = Some(self.model.forward(&input, &mut state)?)
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
config,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,22 +0,0 @@
# candle-mobileone
[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).
This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2
loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%
```

View File

@ -1,96 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::mobileone;
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
S0,
S1,
S2,
S3,
S4,
}
impl Which {
fn model_filename(&self) -> String {
let name = match self {
Self::S0 => "s0",
Self::S1 => "s1",
Self::S2 => "s2",
Self::S3 => "s3",
Self::S4 => "s4",
};
format!("timm/mobileone_{}.apple_in1k", name)
}
fn config(&self) -> mobileone::Config {
match self {
Self::S0 => mobileone::Config::s0(),
Self::S1 => mobileone::Config::s1(),
Self::S2 => mobileone::Config::s2(),
Self::S3 => mobileone::Config::s3(),
Self::S4 => mobileone::Config::s4(),
}
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(value_enum, long, default_value_t=Which::S0)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let model_name = args.which.model_filename();
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name);
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,39 +1,10 @@
## Using ONNX models in Candle ## Using ONNX models in Candle
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle. This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small sequeezenet variant.
It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf). You can run the example with the following command:
You can run the examples with following commands:
```bash ```bash
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
Use the `--which` flag to specify explicitly which network to use, i.e.
```bash
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.21s
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 3, 224, 224; f32]
unicycle, monocycle : 83.23%
ballplayer, baseball player : 3.68%
bearskin, busby, shako : 1.54%
military uniform : 0.78%
cowboy hat, ten-gallon hat : 0.76%
```
```bash
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg
Finished release [optimized] target(s) in 0.20s
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 224, 224, 3; f32]
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
mountain bike, all-terrain bike, off-roader : 0.60%
unicycle, monocycle : 0.17%
crash helmet : 0.02%
alp : 0.02%
``` ```

View File

@ -1,281 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::qwen2::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the <|endoftext|> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
enum WhichModel {
#[value(name = "0.5b")]
W0_5b,
#[value(name = "1.8b")]
W1_8b,
#[value(name = "4b")]
W4b,
#[value(name = "7b")]
W7b,
#[value(name = "14b")]
W14b,
#[value(name = "72b")]
W72b,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 10000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
#[arg(long, default_value = "0.5b")]
model: WhichModel,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
};
format!("Qwen/Qwen1.5-{size}")
}
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config_file = repo.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -411,7 +411,7 @@ impl DDPG<'_> {
pub fn actions(&mut self, state: &Tensor) -> Result<f32> { pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self let actions = self
.actor .actor
.forward(&state.detach().unsqueeze(0)?)? .forward(&state.detach()?.unsqueeze(0)?)?
.squeeze(0)?; .squeeze(0)?;
let actions = if self.train { let actions = if self.train {
(actions + self.ou_noise.sample()?)? (actions + self.ou_noise.sample()?)?

View File

@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
loop { loop {
let action = { let action = {
let action_probs: Vec<f32> = let action_probs: Vec<f32> =
softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)? softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
.squeeze(0)? .squeeze(0)?
.to_vec1()?; .to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64 weighted_sample(action_probs, &mut rng)? as i64
@ -109,7 +109,7 @@ pub fn run() -> Result<()> {
let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)? let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)? .to_dtype(DType::F32)?
.detach(); .detach()?;
let actions_mask = { let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect(); let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
.unwrap() .unwrap()
}) })
.collect(); .collect();
Tensor::stack(&actions_mask, 0)?.detach() Tensor::stack(&actions_mask, 0)?.detach()?
}; };
let states = { let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect(); let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
Tensor::stack(&states, 0)?.detach() Tensor::stack(&states, 0)?.detach()?
}; };
let log_probs = actions_mask let log_probs = actions_mask

View File

@ -8,13 +8,6 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated so you will have to request access on the Hub in Note that this model is gated so you will have to request access on the Hub in
order to be able to use it. order to be able to use it.
Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.
StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.
## Running some example ## Running some example
```bash ```bash

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src; extern crate accelerate_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum}; use clap::Parser;
use candle_transformers::models::quantized_stable_lm::Model as QStableLM; use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
use candle_transformers::models::stable_lm::{Config, Model as StableLM}; use candle_transformers::models::stable_lm::{Config, Model as StableLM};
@ -122,16 +122,6 @@ impl TextGeneration {
} }
} }
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
enum Which {
V1Orig,
V1,
V1Zephyr,
V2,
V2Zephyr,
Code,
}
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)] #[command(author, version, about, long_about = None)]
struct Args { struct Args {
@ -162,18 +152,15 @@ struct Args {
seed: u64, seed: u64,
/// The length of the sample to generate (in tokens). /// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 1000)] #[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize, sample_len: usize,
#[arg(long)] #[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
model_id: Option<String>, model_id: String,
#[arg(long, default_value = "main")] #[arg(long, default_value = "main")]
revision: String, revision: String,
#[arg(long, default_value = "v2")]
which: Which,
#[arg(long)] #[arg(long)]
tokenizer_file: Option<String>, tokenizer_file: Option<String>,
@ -220,80 +207,33 @@ fn main() -> Result<()> {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let api = Api::new()?; let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => match args.which {
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
Which::Code => "stabilityai/stable-code-3b".to_string(),
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
},
};
let repo = api.repo(Repo::with_revision( let repo = api.repo(Repo::with_revision(
model_id, args.model_id,
RepoType::Model, RepoType::Model,
args.revision, args.revision,
)); ));
let tokenizer_filename = match args.tokenizer_file { let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file), Some(file) => std::path::PathBuf::from(file),
None => match args.which { None => repo.get("tokenizer.json")?,
Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
repo.get("tokenizer.json")?
}
Which::V2 | Which::V2Zephyr => api
.model("lmz/candle-stablelm".to_string())
.get("tokenizer-gpt4.json")?,
},
}; };
let filenames = match args.weight_files { let filenames = match args.weight_files {
Some(files) => files Some(files) => files
.split(',') .split(',')
.map(std::path::PathBuf::from) .map(std::path::PathBuf::from)
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
None => match (args.which, args.quantized) { None => {
(Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?], if args.quantized {
(Which::V2, true) => { vec![repo.get("model-q4k.gguf")?]
let gguf = api } else {
.model("lmz/candle-stablelm".to_string())
.get("stablelm-2-1_6b-q4k.gguf")?;
vec![gguf]
}
(Which::V2Zephyr, true) => {
let gguf = api
.model("lmz/candle-stablelm".to_string())
.get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
vec![gguf]
}
(Which::V1Zephyr | Which::Code, true) => {
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
}
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
vec![repo.get("model.safetensors")?] vec![repo.get("model.safetensors")?]
} }
(Which::Code, false) => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
} }
},
}; };
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let config = match args.which { let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
config.set_use_flash_attn(args.use_flash_attn);
config
}
};
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let (model, device) = if args.quantized { let (model, device) = if args.quantized {
let filename = &filenames[0]; let filename = &filenames[0];

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.5 KiB

View File

@ -10,36 +10,15 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor}; use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream; use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use candle_transformers::models::{trocr, vit}; use candle_transformers::models::trocr;
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod image_processor; mod image_processor;
#[derive(Clone, Debug, Copy, ValueEnum)] #[derive(Clone, Debug, Copy, ValueEnum)]
enum Which { enum Which {
#[value(name = "base")] Base,
BaseHandwritten, Large,
#[value(name = "large")]
LargeHandwritten,
BasePrinted,
LargePrinted,
}
impl Which {
fn repo_and_branch_name(&self) -> (&str, &str) {
match self {
Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
}
}
}
#[derive(Debug, Clone, serde::Deserialize)]
struct Config {
encoder: vit::Config,
decoder: trocr::TrOCRConfig,
} }
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
@ -55,64 +34,63 @@ struct Args {
#[arg(long)] #[arg(long)]
cpu: bool, cpu: bool,
/// The image file to be processed. /// Text to be translated
#[arg(long)] #[arg(long)]
image: String, image: String,
/// Tokenization config.
#[arg(long)]
tokenizer: Option<String>,
} }
pub fn main() -> anyhow::Result<()> { pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse(); let args = Args::parse();
let api = hf_hub::api::sync::Api::new()?;
let mut tokenizer_dec = { let tokenizer_dec = {
let tokenizer_file = match args.tokenizer { let tokenizer = Api::new()?
None => api
.model(String::from("ToluClassics/candle-trocr-tokenizer")) .model(String::from("ToluClassics/candle-trocr-tokenizer"))
.get("tokenizer.json")?, .get("tokenizer.json")?;
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
}; Tokenizer::from_file(&tokenizer).map_err(E::msg)?
let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
TokenOutputStream::new(tokenizer)
}; };
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let vb = { let vb = {
let model = match args.model { let model = match args.model {
Some(model) => std::path::PathBuf::from(model), Some(model) => std::path::PathBuf::from(model),
None => { None => match args.which {
let (repo, branch) = args.which.repo_and_branch_name(); Which::Base => Api::new()?
api.repo(hf_hub::Repo::with_revision( .repo(hf_hub::Repo::with_revision(
repo.to_string(), "microsoft/trocr-base-handwritten".to_string(),
hf_hub::RepoType::Model, hf_hub::RepoType::Model,
branch.to_string(), "refs/pr/3".to_string(),
)) ))
.get("model.safetensors")? .get("model.safetensors")?,
} Which::Large => Api::new()?
.repo(hf_hub::Repo::with_revision(
"microsoft/trocr-large-handwritten".to_string(),
hf_hub::RepoType::Model,
"refs/pr/6".to_string(),
))
.get("model.safetensors")?,
},
}; };
println!("model: {:?}", model); println!("model: {:?}", model);
unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? } unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
}; };
let (encoder_config, decoder_config) = { let encoder_config = match args.which {
let (repo, branch) = args.which.repo_and_branch_name(); Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
let config_filename = api Which::Large => {
.repo(hf_hub::Repo::with_revision( candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
repo.to_string(), }
hf_hub::RepoType::Model,
branch.to_string(),
))
.get("config.json")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
(config.encoder, config.decoder)
}; };
let decoder_config = trocr::TrOCRConfig::default();
let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?; let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;
let processor_config = image_processor::ProcessorConfig::default(); let config = image_processor::ProcessorConfig::default();
let processor = image_processor::ViTImageProcessor::new(&processor_config); let processor = image_processor::ViTImageProcessor::new(&config);
let image = vec![args.image.as_str()]; let image = vec![args.image.as_str()];
let image = processor.preprocess(image)?; let image = processor.preprocess(image)?;

View File

@ -5,27 +5,12 @@ transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself. the model itself.
Supported models include:
- `--which base`: small handwritten OCR model.
- `--which large`: large handwritten OCR model.
- `--which base-printed`: small printed OCR model.
- `--which large-printed`: large printed OCR model.
## Running an example ## Running an example
```bash ```bash
cargo run --example trocr --release -- --image candle-examples/examples/trocr/assets/trocr.png cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which large --image candle-examples/examples/trocr/assets/trocr.png
cargo run --example trocr --release -- --which base-printed --image candle-examples/examples/trocr/assets/noto.png
cargo run --example trocr --release -- --which large-printed --image candle-examples/examples/trocr/assets/noto.png
``` ```
### Outputs
``` ```
industry , Mr. Brown commented icily . " Let us have a <s> industry , Mr. Brown commented icily . " Let us have a</s>
industry , " Mr. Brown commented icily . " Let us have a
THE QUICK BROWN FOR JUMPS OVER THE LAY DOG
THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG
``` ```

View File

@ -1,673 +0,0 @@
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use anyhow::{Error as E, Result};
use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng};
use std::iter;
use tokenizers::Tokenizer;
mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, Config};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex};
pub enum Model {
Normal(m::model::Whisper),
Quantized(m::quantized_model::Whisper),
}
// Maybe we should use some traits rather than doing the dispatch for all these.
impl Model {
pub fn config(&self) -> &Config {
match self {
Self::Normal(m) => &m.config,
Self::Quantized(m) => &m.config,
}
}
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.encoder.forward(x, flush),
Self::Quantized(m) => m.encoder.forward(x, flush),
}
}
pub fn decoder_forward(
&mut self,
x: &Tensor,
xa: &Tensor,
flush: bool,
) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.forward(x, xa, flush),
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
}
}
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Normal(m) => m.decoder.final_linear(x),
Self::Quantized(m) => m.decoder.final_linear(x),
}
}
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct DecodingResult {
tokens: Vec<u32>,
text: String,
avg_logprob: f64,
no_speech_prob: f64,
temperature: f64,
compression_ratio: f64,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
struct Segment {
start: f64,
duration: f64,
dr: DecodingResult,
}
struct Decoder {
model: Model,
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
verbose: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
translate_token: u32,
eot_token: u32,
no_speech_token: u32,
no_timestamps_token: u32,
language_token: Option<u32>,
}
impl Decoder {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
device: &Device,
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
verbose: bool,
) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
.map(|i| {
if model.config().suppress_tokens.contains(&i)
|| timestamps && i == no_timestamps_token
{
f32::NEG_INFINITY
} else {
0f32
}
})
.collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = m::NO_SPEECH_TOKENS
.iter()
.find_map(|token| token_id(&tokenizer, token).ok());
let no_speech_token = match no_speech_token {
None => anyhow::bail!("unable to find any non-speech token"),
Some(n) => n,
};
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer,
task,
timestamps,
verbose,
suppress_tokens,
sot_token,
transcribe_token,
translate_token,
eot_token,
no_speech_token,
language_token,
no_timestamps_token,
})
}
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &mut self.model;
let audio_features = model.encoder_forward(mel, true)?;
if self.verbose {
println!("audio features: {:?}", audio_features.dims());
}
let sample_len = model.config().max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
if let Some(language_token) = self.language_token {
tokens.push(language_token);
}
match self.task {
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
}
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
// The model expects a batch dim but this inference loop does not handle
// it so we add it at this point.
let tokens_t = tokens_t.unsqueeze(0)?;
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
// Extract the no speech probability on the first iteration by looking at the first
// token logits and the probability for the according token.
if i == 0 {
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
no_speech_prob = softmax(&logits, 0)?
.i(self.no_speech_token as usize)?
.to_scalar::<f32>()? as f64;
}
let (_, seq_len, _) = ys.dims3()?;
let logits = model
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
.i(0)?
.i(0)?;
// TODO: Besides suppress tokens, we should apply the heuristics from
// ApplyTimestampRules, i.e.:
// - Timestamps come in pairs, except before EOT.
// - Timestamps should be non-decreasing.
// - If the sum of the probabilities of timestamps is higher than any other tokens,
// only consider timestamps when sampling.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
let logits = logits.broadcast_add(&self.suppress_tokens)?;
let next_token = if t > 0f64 {
let prs = softmax(&(&logits / t)?, 0)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
distr.sample(&mut self.rng) as u32
} else {
let logits_v: Vec<f32> = logits.to_vec1()?;
logits_v
.iter()
.enumerate()
.max_by(|(_, u), (_, v)| u.total_cmp(v))
.map(|(i, _)| i as u32)
.unwrap()
};
tokens.push(next_token);
let prob = softmax(&logits, candle::D::Minus1)?
.i(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
break;
}
sum_logprob += prob.ln();
}
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
let avg_logprob = sum_logprob / tokens.len() as f64;
Ok(DecodingResult {
tokens,
text,
avg_logprob,
no_speech_prob,
temperature: t,
compression_ratio: f64::NAN,
})
}
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t);
if i == m::TEMPERATURES.len() - 1 {
return dr;
}
// On errors, we try again with a different temperature.
match dr {
Ok(dr) => {
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr);
}
}
Err(err) => {
println!("Error running at {t}: {err}")
}
}
}
unreachable!()
}
fn run(&mut self, mel: &Tensor, times: Option<(f64, f64)>) -> Result<Vec<Segment>> {
let (_, _, content_frames) = mel.dims3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
let start = std::time::Instant::now();
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size;
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}");
continue;
}
let segment = Segment {
start: time_offset,
duration: segment_duration,
dr,
};
if self.timestamps {
println!(
"{:.1}s -- {:.1}s",
segment.start,
segment.start + segment.duration,
);
let mut tokens_to_decode = vec![];
let mut prev_timestamp_s = 0f32;
for &token in segment.dr.tokens.iter() {
if token == self.sot_token || token == self.eot_token {
continue;
}
// The no_timestamp_token is the last before the timestamp ones.
if token > self.no_timestamps_token {
let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
if !tokens_to_decode.is_empty() {
let text = self
.tokenizer
.decode(&tokens_to_decode, true)
.map_err(E::msg)?;
println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
tokens_to_decode.clear()
}
prev_timestamp_s = timestamp_s;
} else {
tokens_to_decode.push(token)
}
}
if !tokens_to_decode.is_empty() {
let text = self
.tokenizer
.decode(&tokens_to_decode, true)
.map_err(E::msg)?;
if !text.is_empty() {
println!(" {:.1}s-...: {}", prev_timestamp_s, text);
}
tokens_to_decode.clear()
}
} else {
match times {
Some((start, end)) => {
println!("{:.1}s -- {:.1}s: {}", start, end, segment.dr.text)
}
None => {
println!(
"{:.1}s -- {:.1}s: {}",
segment.start,
segment.start + segment.duration,
segment.dr.text,
)
}
}
}
if self.verbose {
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
}
segments.push(segment)
}
Ok(segments)
}
fn set_language_token(&mut self, language_token: Option<u32>) {
self.language_token = language_token;
}
#[allow(dead_code)]
fn reset_kv_cache(&mut self) {
match &mut self.model {
Model::Normal(m) => m.reset_kv_cache(),
Model::Quantized(m) => m.reset_kv_cache(),
}
}
fn model(&mut self) -> &mut Model {
&mut self.model
}
}
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
match tokenizer.token_to_id(token) {
None => candle::bail!("no token-id for {token}"),
Some(id) => Ok(id),
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Task {
Transcribe,
Translate,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
enum WhichModel {
Tiny,
#[value(name = "tiny.en")]
TinyEn,
Base,
#[value(name = "base.en")]
BaseEn,
Small,
#[value(name = "small.en")]
SmallEn,
Medium,
#[value(name = "medium.en")]
MediumEn,
Large,
LargeV2,
LargeV3,
#[value(name = "distil-medium.en")]
DistilMediumEn,
#[value(name = "distil-large-v2")]
DistilLargeV2,
}
impl WhichModel {
fn is_multilingual(&self) -> bool {
match self {
Self::Tiny
| Self::Base
| Self::Small
| Self::Medium
| Self::Large
| Self::LargeV2
| Self::LargeV3
| Self::DistilLargeV2 => true,
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
false
}
}
}
fn model_and_revision(&self) -> (&'static str, &'static str) {
match self {
Self::Tiny => ("openai/whisper-tiny", "main"),
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
Self::Base => ("openai/whisper-base", "refs/pr/22"),
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
Self::Small => ("openai/whisper-small", "main"),
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
Self::Medium => ("openai/whisper-medium", "main"),
Self::MediumEn => ("openai/whisper-medium.en", "main"),
Self::Large => ("openai/whisper-large", "refs/pr/36"),
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
#[arg(long)]
model_id: Option<String>,
/// The model to use, check out available models:
/// https://huggingface.co/models?search=whisper
#[arg(long)]
revision: Option<String>,
/// The model to be used, can be tiny, small, medium.
#[arg(long, default_value = "tiny.en")]
model: WhichModel,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
quantized: bool,
/// Language.
#[arg(long)]
language: Option<String>,
/// Task, when no task is specified, the input tokens contain only the sot token which can
/// improve things when in no-timestamp mode.
#[arg(long)]
task: Option<Task>,
/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
timestamps: bool,
/// Print the full DecodingResult structure rather than just the text.
#[arg(long)]
verbose: bool,
}
pub fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let device = candle_examples::device(args.cpu)?;
let (default_model, default_revision) = if args.quantized {
("lmz/candle-whisper", "main")
} else {
args.model.model_and_revision()
};
let default_model = default_model.to_string();
let default_revision = default_revision.to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let (config, tokenizer, model) = if args.quantized {
let ext = match args.model {
WhichModel::TinyEn => "tiny-en",
WhichModel::Tiny => "tiny",
_ => unimplemented!("no quantized support for {:?}", args.model),
};
(
repo.get(&format!("config-{ext}.json"))?,
repo.get(&format!("tokenizer-{ext}.json"))?,
repo.get(&format!("model-{ext}-q80.gguf"))?,
)
} else {
let config = repo.get("config.json")?;
let tokenizer = repo.get("tokenizer.json")?;
let model = repo.get("model.safetensors")?;
(config, tokenizer, model)
};
(config, tokenizer, model)
};
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let model = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
&weights_filename,
&device,
)?;
Model::Quantized(m::quantized_model::Whisper::load(&vb, config.clone())?)
} else {
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
};
let language_token = None;
let mut dc = Decoder::new(
model,
tokenizer.clone(),
args.seed,
&device,
language_token,
args.task,
args.timestamps,
args.verbose,
)?;
let mel_bytes = match config.num_mel_bins {
80 => include_bytes!("../whisper/melfilters.bytes").as_slice(),
128 => include_bytes!("../whisper/melfilters128.bytes").as_slice(),
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
};
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
// Set up the input device and stream with the default input config.
let host = cpal::default_host();
let _device = "default";
let _device = if _device == "default" {
host.default_input_device()
} else {
host.input_devices()?
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
}
.expect("failed to find input device");
let _config = _device
.default_input_config()
.expect("Failed to get default input config");
let channel_count = _config.channels() as usize;
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
let audio_ring_buffer_2 = audio_ring_buffer.clone();
std::thread::spawn(move || loop {
let data = record_audio(&_device, &_config, 300).unwrap();
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
let max_len = data.len() * 16;
let data_len = data.len();
let len = audio_ring_buffer.lock().unwrap().len();
if len > max_len {
let mut data = audio_ring_buffer.lock().unwrap();
let new_data = data[data_len..].to_vec();
*data = new_data;
}
});
// loop to process the audio data forever (until the user stops the program)
println!("Transcribing audio...");
for (i, _) in iter::repeat(()).enumerate() {
std::thread::sleep(std::time::Duration::from_millis(1000));
let data = audio_ring_buffer_2.lock().unwrap().clone();
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len();
let mel = Tensor::from_vec(
mel,
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
&device,
)?;
// on the first iteration, we detect the language and set the language token.
if i == 0 {
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
(false, None) => None,
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
Ok(token_id) => Some(token_id),
Err(_) => anyhow::bail!("language {language} is not supported"),
},
(false, Some(_)) => {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
println!("language_token: {:?}", language_token);
dc.set_language_token(language_token);
}
dc.run(
&mel,
Some((
i as f64,
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
)),
)?;
dc.reset_kv_cache();
}
Ok(())
}
fn record_audio(
device: &cpal::Device,
config: &cpal::SupportedStreamConfig,
milliseconds: u64,
) -> Result<Vec<i16>> {
let writer = Arc::new(Mutex::new(Vec::new()));
let writer_2 = writer.clone();
let stream = device.build_input_stream(
&config.config(),
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let processed = data
.iter()
.map(|v| (v * 32768.0) as i16)
.collect::<Vec<i16>>();
writer_2.lock().unwrap().extend_from_slice(&processed);
},
move |err| {
eprintln!("an error occurred on stream: {}", err);
},
None,
)?;
stream.play()?;
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
drop(stream);
let data = writer.lock().unwrap().clone();
let step = 3;
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
Ok(data)
}

View File

@ -1,137 +0,0 @@
use crate::{token_id, Model};
use candle::{IndexOp, Result, Tensor, D};
use candle_transformers::models::whisper::{self as m};
use tokenizers::Tokenizer;
const LANGUAGES: [(&str, &str); 99] = [
("en", "english"),
("zh", "chinese"),
("de", "german"),
("es", "spanish"),
("ru", "russian"),
("ko", "korean"),
("fr", "french"),
("ja", "japanese"),
("pt", "portuguese"),
("tr", "turkish"),
("pl", "polish"),
("ca", "catalan"),
("nl", "dutch"),
("ar", "arabic"),
("sv", "swedish"),
("it", "italian"),
("id", "indonesian"),
("hi", "hindi"),
("fi", "finnish"),
("vi", "vietnamese"),
("he", "hebrew"),
("uk", "ukrainian"),
("el", "greek"),
("ms", "malay"),
("cs", "czech"),
("ro", "romanian"),
("da", "danish"),
("hu", "hungarian"),
("ta", "tamil"),
("no", "norwegian"),
("th", "thai"),
("ur", "urdu"),
("hr", "croatian"),
("bg", "bulgarian"),
("lt", "lithuanian"),
("la", "latin"),
("mi", "maori"),
("ml", "malayalam"),
("cy", "welsh"),
("sk", "slovak"),
("te", "telugu"),
("fa", "persian"),
("lv", "latvian"),
("bn", "bengali"),
("sr", "serbian"),
("az", "azerbaijani"),
("sl", "slovenian"),
("kn", "kannada"),
("et", "estonian"),
("mk", "macedonian"),
("br", "breton"),
("eu", "basque"),
("is", "icelandic"),
("hy", "armenian"),
("ne", "nepali"),
("mn", "mongolian"),
("bs", "bosnian"),
("kk", "kazakh"),
("sq", "albanian"),
("sw", "swahili"),
("gl", "galician"),
("mr", "marathi"),
("pa", "punjabi"),
("si", "sinhala"),
("km", "khmer"),
("sn", "shona"),
("yo", "yoruba"),
("so", "somali"),
("af", "afrikaans"),
("oc", "occitan"),
("ka", "georgian"),
("be", "belarusian"),
("tg", "tajik"),
("sd", "sindhi"),
("gu", "gujarati"),
("am", "amharic"),
("yi", "yiddish"),
("lo", "lao"),
("uz", "uzbek"),
("fo", "faroese"),
("ht", "haitian creole"),
("ps", "pashto"),
("tk", "turkmen"),
("nn", "nynorsk"),
("mt", "maltese"),
("sa", "sanskrit"),
("lb", "luxembourgish"),
("my", "myanmar"),
("bo", "tibetan"),
("tl", "tagalog"),
("mg", "malagasy"),
("as", "assamese"),
("tt", "tatar"),
("haw", "hawaiian"),
("ln", "lingala"),
("ha", "hausa"),
("ba", "bashkir"),
("jw", "javanese"),
("su", "sundanese"),
];
/// Returns the token id for the selected language.
pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
let (_bsize, _, seq_len) = mel.dims3()?;
let mel = mel.narrow(
2,
0,
usize::min(seq_len, model.config().max_source_positions),
)?;
let device = mel.device();
let language_token_ids = LANGUAGES
.iter()
.map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?;
let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
let audio_features = model.encoder_forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
let ys = model.decoder_forward(&tokens, &audio_features, true)?;
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
let logits = logits.index_select(&language_token_ids, 0)?;
let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
let probs = probs.to_vec1::<f32>()?;
let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for ((_, language), p) in probs.iter().take(5) {
println!("{language}: {p}")
}
let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
Ok(language)
}

View File

@ -18,8 +18,6 @@ use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod multilingual; mod multilingual;
mod pcm_decode;
use candle_transformers::models::whisper::{self as m, audio, Config}; use candle_transformers::models::whisper::{self as m, audio, Config};
pub enum Model { pub enum Model {
@ -537,10 +535,17 @@ fn main() -> Result<()> {
let mut mel_filters = vec![0f32; mel_bytes.len() / 4]; let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters); <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?; let mut input = std::fs::File::open(input)?;
if sample_rate != m::SAMPLE_RATE as u32 { let (header, data) = wav::read(&mut input)?;
anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE) println!("loaded wav data: {header:?}");
if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
} }
let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
.iter()
.map(|v| *v as f32 / 32768.)
.collect();
println!("pcm data loaded {}", pcm_data.len()); println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters); let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
let mel_len = mel.len(); let mel_len = mel.len();

View File

@ -1,74 +0,0 @@
use symphonia::core::audio::{AudioBufferRef, Signal};
use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
use symphonia::core::conv::FromSample;
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
T: symphonia::core::sample::Sample,
f32: symphonia::core::conv::FromSample<T>,
{
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
// Open the media source.
let src = std::fs::File::open(path)?;
// Create the media source stream.
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
// Create a probe hint using the file's extension. [Optional]
let hint = symphonia::core::probe::Hint::new();
// Use the default options for metadata and format readers.
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
// Probe the media source.
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
// Get the instantiated format reader.
let mut format = probed.format;
// Find the first audio track with a known (decodeable) codec.
let track = format
.tracks()
.iter()
.find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
.expect("no supported audio tracks");
// Use the default options for the decoder.
let dec_opts: DecoderOptions = Default::default();
// Create a decoder for the track.
let mut decoder = symphonia::default::get_codecs()
.make(&track.codec_params, &dec_opts)
.expect("unsupported codec");
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
let mut pcm_data = Vec::new();
// The decode loop.
while let Ok(packet) = format.next_packet() {
// Consume any new metadata that has been read since the last packet.
while !format.metadata().is_latest() {
format.metadata().pop();
}
// If the packet does not belong to the selected track, skip over it.
if packet.track_id() != track_id {
continue;
}
match decoder.decode(&packet)? {
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
}
}
Ok((pcm_data, sample_rate))
}

View File

@ -104,7 +104,6 @@ impl TextGeneration {
break; break;
} }
if let Some(t) = self.tokenizer.next_token(next_token)? { if let Some(t) = self.tokenizer.next_token(next_token)? {
let t = t.replace("<|im_end|>", "\n");
print!("{t}"); print!("{t}");
std::io::stdout().flush()?; std::io::stdout().flush()?;
} }

View File

@ -216,7 +216,7 @@ fn detect(
xs: &Tensor, xs: &Tensor,
image_height: usize, image_height: usize,
classes: usize, classes: usize,
anchors: &[(usize, usize)], anchors: &Vec<(usize, usize)>,
) -> Result<Tensor> { ) -> Result<Tensor> {
let (bsize, _channels, height, _width) = xs.dims4()?; let (bsize, _channels, height, _width) = xs.dims4()?;
let stride = image_height / height; let stride = image_height / height;

View File

@ -40,7 +40,7 @@ impl TokenOutputStream {
}; };
self.tokens.push(token); self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?; let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() { if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
let text = text.split_at(prev_text.len()); let text = text.split_at(prev_text.len());
self.prev_index = self.current_index; self.prev_index = self.current_index;
self.current_index = self.tokens.len(); self.current_index = self.tokens.len();

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-flash-attn" name = "candle-flash-attn"
version = "0.4.0" version = "0.3.3"
edition = "2021" edition = "2021"
description = "Flash attention layer for the candle ML framework." description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md" readme = "README.md"
[dependencies] [dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" } candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] } half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies] [build-dependencies]

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-kernels" name = "candle-kernels"
version = "0.4.0" version = "0.3.3"
edition = "2021" edition = "2021"
description = "CUDA kernels for Candle" description = "CUDA kernels for Candle"

View File

@ -71,6 +71,7 @@ __device__ void im2col1d(
} }
const size_t *src_dims = info; const size_t *src_dims = info;
const size_t *src_s = info + 3; const size_t *src_s = info + 3;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1]; const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2]; const size_t l_in = src_dims[2];
@ -119,6 +120,7 @@ __device__ void im2col(
} }
const size_t *src_dims = info; const size_t *src_dims = info;
const size_t *src_s = info + 4; const size_t *src_s = info + 4;
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1]; const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2]; const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3]; const size_t w_in = src_dims[3];
@ -223,60 +225,6 @@ __device__ void conv2d(
dst[dst_i] = static_cast<T>(d); dst[dst_i] = static_cast<T>(d);
} }
// Naive implementation of conv_transpose1d.
template <typename T, typename A>
__device__ void conv_transpose1d(
const size_t src_numel,
const size_t l_out,
const size_t stride,
const size_t padding,
const size_t out_padding,
const size_t dilation,
const size_t *info,
const T *src,
const T *kernel,
T *dst
) {
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
// src: (b_size, c_in, l_in)
// k: (c_in, c_out, l_k)
const size_t *src_dims = info;
const size_t *src_s = info + 3;
const size_t *k_dims = info + 6;
const size_t *k_s = info + 9;
const size_t l_k = k_dims[2];
const size_t c_out = k_dims[1];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
if (dst_i >= src_dims[0] * c_out * l_out) {
return;
}
// TODO
const size_t b_idx = dst_i / (l_out * c_out);
const size_t dst_c_idx = (dst_i / l_out) % c_out;
// NCL layout.
const size_t out_x = dst_i % l_out;
const size_t src_idx0 = b_idx * src_s[0];
A d = 0;
for (int k_x = 0; k_x < (int)l_k; ++k_x) {
// let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
if (inp_x_stride < 0 || inp_x_stride % stride) {
continue;
}
int inp_x = inp_x_stride / stride;
if (inp_x >= l_in) continue;
for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2];
const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2];
d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
}
}
dst[dst_i] = static_cast<T>(d);
}
// Naive implementation of conv_transpose2d. // Naive implementation of conv_transpose2d.
template <typename T, typename A> template <typename T, typename A>
__device__ void conv_transpose2d( __device__ void conv_transpose2d(
@ -559,22 +507,6 @@ extern "C" __global__ void FN_NAME( \
im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \ im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \ } \
#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t src_numel, \
const size_t l_out, \
const size_t stride, \
const size_t padding, \
const size_t out_padding, \
const size_t dilation, \
const size_t *info, \
const TYPENAME *src, \
const TYPENAME *kernel, \
TYPENAME *dst \
) { \
conv_transpose1d<TYPENAME, TYPEACC>(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
} \
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \ #define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \ extern "C" __global__ void FN_NAME( \
const size_t src_numel, \ const size_t src_numel, \
@ -636,7 +568,6 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800 #if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16) CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16) CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16)
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16) CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16) AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16) MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
@ -648,7 +579,6 @@ IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#if __CUDA_ARCH__ >= 530 #if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16) CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16) CONV2D_OP(__half, float, conv2d_f16)
CONVT1D_OP(__half, float, conv_transpose1d_f16)
CONVT2D_OP(__half, float, conv_transpose2d_f16) CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16) AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16) MAX_POOL2D_OP(__half, max_pool2d_f16)
@ -667,11 +597,6 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8) CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32) CONV2D_OP(uint32_t, uint32_t, conv2d_u32)
CONVT1D_OP(float, float, conv_transpose1d_f32)
CONVT1D_OP(double, double, conv_transpose1d_f64)
CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
CONVT2D_OP(float, float, conv_transpose2d_f32) CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(double, double, conv_transpose2d_f64) CONVT2D_OP(double, double, conv_transpose2d_f64)
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8) CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-metal-kernels" name = "candle-metal-kernels"
version = "0.4.0" version = "0.3.3"
edition = "2021" edition = "2021"
description = "Metal kernels for Candle" description = "Metal kernels for Candle"

View File

@ -1,2 +0,0 @@
xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib

View File

@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);
#define INT64_BINARY_OP_OUT(NAME, FN) \ #define INT64_BINARY_OP_OUT(NAME, FN) \
BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided); BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);
BINARY_OP(x + y, add) BINARY_OP(x + y, add)
BINARY_OP(x - y, sub) BINARY_OP(x - y, sub)

View File

@ -1,317 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
#if defined(__HAVE_BFLOAT__)
typedef bfloat bfloat16_t;
#else
/////////////////////////////////////////////////////////////////////////////
// Helpers
/////////////////////////////////////////////////////////////////////////////
constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) {
// Check for nan
if ((as_type<uint32_t>(x) & ~_fp_encoding_traits<float>::sign_mask) >
_fp_encoding_traits<float>::inf_mask) {
return uint16_t(as_type<uint32_t>(0x7FC0));
}
// Take bits
uint32_t float_bits = as_type<uint32_t>(x);
// Round to nearest even
float_bits += ((float_bits >> 16) & 1) + as_type<uint32_t>(0x7FFF);
// Take upper 16 bits
return float_bits >> 16;
}
constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) {
// Upper 16 bits are the data and lower 16 bits are 0s
return as_type<float>((uint32_t)x << 16);
}
struct _MLX_BFloat16;
template <typename T>
static constexpr constant bool can_convert_to_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_bfloat =
!is_same_v<T, _MLX_BFloat16> && is_convertible_v<float, T>;
/////////////////////////////////////////////////////////////////////////////
// Bfloat struct
/////////////////////////////////////////////////////////////////////////////
struct _MLX_BFloat16 {
/////////////////////////////////////////////////////////////////////////////
// Constructors
uint16_t bits_;
_MLX_BFloat16() thread = default;
_MLX_BFloat16() threadgroup = default;
_MLX_BFloat16() device = default;
_MLX_BFloat16() constant = default;
struct bits_to_bfloat_struct {};
static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() {
return bits_to_bfloat_struct();
}
constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct)
: bits_(bits) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions to bfloat
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) thread
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) device
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
template <
typename T,
typename = typename enable_if<can_convert_to_bfloat<T>>::type>
constexpr METAL_FUNC _MLX_BFloat16(T x) constant
: bits_(float_to_bfloat_bits(static_cast<float>(x))) {}
/////////////////////////////////////////////////////////////////////////////
// Conversions from bfloat
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const thread {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const threadgroup {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const device {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
template <
typename T,
typename = typename enable_if<can_convert_from_bfloat<T>>::type>
constexpr METAL_FUNC operator T() const constant {
return static_cast<T>(bfloat_bits_to_float(bits_));
}
};
/////////////////////////////////////////////////////////////////////////////
// Bfloat operators
/////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////
// Unary ops
constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) {
return -static_cast<float>(x);
}
/////////////////////////////////////////////////////////////////////////////
// Binary operators
#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \
constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
} \
constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \
return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs); \
}
/////////////////////////////////////////////////////////////////////////////
// Arithmetic Operators
#define bfloat_binop(_op_, _operator_) \
bfloat_binop_base( \
_op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(_op_, _operator_, float, float, float); \
bfloat_binop_helper(_op_, _operator_, float, half, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float);
bfloat_binop(+, operator+);
bfloat_binop(-, operator-);
bfloat_binop(*, operator*);
bfloat_binop(/, operator/);
/////////////////////////////////////////////////////////////////////////////
// Comparison ops
#define bfloat_compop(__op__, __operator__) \
bfloat_binop_base( \
__op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \
bfloat_binop_helper(__op__, __operator__, bool, float, float); \
bfloat_binop_helper(__op__, __operator__, bool, half, float); \
bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \
bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float);
bfloat_compop(>, operator>);
bfloat_compop(<, operator<);
bfloat_compop(>=, operator>=);
bfloat_compop(<=, operator<=);
bfloat_compop(==, operator==);
bfloat_compop(!=, operator!=);
#undef bfloat_compop
#undef bfloat_binop_base
#undef bfloat_binop_helper
#undef bfloat_binop
/////////////////////////////////////////////////////////////////////////////
// Inplace Operators
#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, itype rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
} \
constexpr METAL_FUNC addr_space itype& __operator__( \
addr_space itype& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \
bfloat_inplace_op_helper(__op__, __operator__, itype, device); \
bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup);
#define bfloat_inplace_op(itype) \
bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \
bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \
bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \
bfloat_inplace_op_addr_space_helper(/, operator/=, itype);
bfloat_inplace_op(float);
bfloat_inplace_op(half);
bfloat_inplace_op(int16_t);
bfloat_inplace_op(int32_t);
bfloat_inplace_op(int64_t);
bfloat_inplace_op(uint16_t);
bfloat_inplace_op(uint32_t);
bfloat_inplace_op(uint64_t);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
#undef bfloat_inplace_op
#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \
constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \
addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \
lhs = static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
return lhs; \
}
#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \
bfloat_inplace_op_helper(__op__, __operator__, device); \
bfloat_inplace_op_helper(__op__, __operator__, thread); \
bfloat_inplace_op_helper(__op__, __operator__, threadgroup);
bfloat_inplace_op_addr_space_helper(+, operator+=);
bfloat_inplace_op_addr_space_helper(-, operator-=);
bfloat_inplace_op_addr_space_helper(*, operator*=);
bfloat_inplace_op_addr_space_helper(/, operator/=);
#undef bfloat_inplace_op_helper
#undef bfloat_inplace_op_addr_space_helper
/////////////////////////////////////////////////////////////////////////////
// Bfloat typedef
/////////////////////////////////////////////////////////////////////////////
typedef struct _MLX_BFloat16 bfloat16_t;
/////////////////////////////////////////////////////////////////////////////
// Bfloat numeric limits
/////////////////////////////////////////////////////////////////////////////
#pragma METAL internals : enable
namespace metal {
template <>
struct _numeric_limits_impl<bfloat16_t> : _fp_numeric_limits_impl_base {
static constexpr constant int digits = 8;
static constexpr constant int digits10 = 2;
static constexpr constant int max_digits10 = 4;
static constexpr constant int radix = 2;
static constexpr constant int min_exponent = -125;
static constexpr constant int min_exponent10 = -37;
static constexpr constant int max_exponent = 128;
static constexpr constant int max_exponent10 = 38;
static constexpr bfloat16_t min() {
return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t lowest() {
return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t max() {
return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t epsilon() {
return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t round_error() {
return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t infinity() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t quiet_NaN() {
return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t signaling_NaN() {
return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat());
}
static constexpr bfloat16_t denorm_min() {
return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat());
}
};
METAL_FUNC bool isnan(_MLX_BFloat16 x) {
return x != x;
}
} // namespace metal
#pragma METAL internals : disable
#endif // defined(__HAVE_BFLOAT__)
#include "gemm/bf16_math.h"

View File

@ -1,394 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "gemm/bf16.h"
///////////////////////////////////////////////////////////////////////////////
// Metal math for bfloat16
///////////////////////////////////////////////////////////////////////////////
/*
Following the Metal Shading Language Specification (Metal 3.1)
"bfloat is an extended itypeing point type that only allows implicit conversion
to a type of greater itypeing point rank. While bfloat can be implicitly
converted to itype, it cannot be implicitly converted to half, and neither
itype nor half can be implicitly converted to bfloat."
Further, as far as I can tell, the stdlib math/simd functions are not defined
for bfloat and calling with an argument of type bfloat will result in that
argument getting implicitly converted to itype which then returns an output
that is (likely) a itype which cannot be implicitly converted into a bfloat
This leads to situations where
bfloat a = 5.0bf;
bfloat b = metal::abs(a); // this will throw an error since abs return itype
bfloat c = static_cast<bfloat>(metal::abs(a)); // this is fine
For the moment, I will be adding overloaded instantiations of the math
functions to accordingly automatically handle the casting
*/
#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \
\
METAL_FUNC otype abs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acos(itype x) { \
return static_cast<otype>(__metal_acos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype acosh(itype x) { \
return static_cast<otype>(__metal_acosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asin(itype x) { \
return static_cast<otype>(__metal_asin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype asinh(itype x) { \
return static_cast<otype>(__metal_asinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atan(itype y_over_x) { \
return static_cast<otype>( \
__metal_atan(static_cast<ctype>(y_over_x), mfast)); \
} \
METAL_FUNC otype atan2(itype y, itype x) { \
return static_cast<otype>( \
__metal_atan2(static_cast<ctype>(y), static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype atanh(itype x) { \
return static_cast<otype>(__metal_atanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype ceil(itype x) { \
return static_cast<otype>(__metal_ceil(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cos(itype x) { \
return static_cast<otype>(__metal_cos(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cosh(itype x) { \
return static_cast<otype>(__metal_cosh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype cospi(itype x) { \
return static_cast<otype>(__metal_cospi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype divide(itype x, itype y) { \
return static_cast<otype>( \
__metal_divide(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype exp(itype x) { \
return static_cast<otype>(__metal_exp(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp10(itype x) { \
return static_cast<otype>(__metal_exp10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype exp2(itype x) { \
return static_cast<otype>(__metal_exp2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fabs(itype x) { \
return static_cast<otype>(__metal_fabs(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fdim(itype x, itype y) { \
ctype t = static_cast<ctype>(x - y); \
return static_cast<otype>(select(t, ctype(0), t < ctype(0) || x == y)); \
} \
METAL_FUNC otype floor(itype x) { \
return static_cast<otype>(__metal_floor(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype fma(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fma( \
static_cast<ctype>(x), static_cast<ctype>(y), static_cast<ctype>(z))); \
} \
METAL_FUNC otype fmax(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmax3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmin(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fmin3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype fmod(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmod(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype fract(itype x) { \
return static_cast<otype>(__metal_fract(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype frexp(itype x, thread int& exp) { \
return static_cast<otype>(__metal_frexp(static_cast<ctype>(x), &exp)); \
} \
METAL_FUNC otype ldexp(itype x, int k) { \
return static_cast<otype>(__metal_ldexp(static_cast<ctype>(x), k, mfast)); \
} \
METAL_FUNC otype log(itype x) { \
return static_cast<otype>(__metal_log(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log10(itype x) { \
return static_cast<otype>(__metal_log10(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype log2(itype x) { \
return static_cast<otype>(__metal_log2(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype max(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmax(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype max3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmax3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype median3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmedian3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype min(itype x, itype y) { \
return static_cast<otype>( \
__metal_fmin(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype min3(itype x, itype y, itype z) { \
return static_cast<otype>(__metal_fmin3( \
static_cast<ctype>(x), \
static_cast<ctype>(y), \
static_cast<ctype>(z), \
mfast)); \
} \
METAL_FUNC otype nextafter(itype x, itype y) { \
return static_cast<otype>( \
__metal_nextafter(static_cast<ctype>(x), static_cast<ctype>(y))); \
} \
METAL_FUNC otype pow(itype x, itype y) { \
return static_cast<otype>( \
__metal_pow(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype powr(itype x, itype y) { \
return static_cast<otype>( \
__metal_powr(static_cast<ctype>(x), static_cast<ctype>(y), mfast)); \
} \
METAL_FUNC otype rint(itype x) { \
return static_cast<otype>(__metal_rint(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype round(itype x) { \
return static_cast<otype>(__metal_round(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype rsqrt(itype x) { \
return static_cast<otype>(__metal_rsqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sin(itype x) { \
return static_cast<otype>(__metal_sin(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinh(itype x) { \
return static_cast<otype>(__metal_sinh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sinpi(itype x) { \
return static_cast<otype>(__metal_sinpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype sqrt(itype x) { \
return static_cast<otype>(__metal_sqrt(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tan(itype x) { \
return static_cast<otype>(__metal_tan(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanh(itype x) { \
return static_cast<otype>(__metal_tanh(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype tanpi(itype x) { \
return static_cast<otype>(__metal_tanpi(static_cast<ctype>(x), mfast)); \
} \
METAL_FUNC otype trunc(itype x) { \
return static_cast<otype>(__metal_trunc(static_cast<ctype>(x), mfast)); \
}
namespace metal {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_MAYBE_FAST_MATH__);
namespace fast {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_FAST_MATH__);
} // namespace fast
namespace precise {
instantiate_metal_math_funcs(
bfloat16_t,
bfloat16_t,
float,
__METAL_PRECISE_MATH__);
} // namespace precise
} // namespace metal
///////////////////////////////////////////////////////////////////////////////
// Metal simd for bfloat16
///////////////////////////////////////////////////////////////////////////////
#define instantiate_metal_simd_comm_funcs( \
itype, otype, ctype, itype_to_ctype, ctype_to_otype) \
\
METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \
return ctype_to_otype( \
__metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \
return ctype_to_otype( \
__metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_down( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta, ushort modulo) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \
} \
\
METAL_FUNC otype simd_shuffle_and_fill_up( \
itype data, itype filling_data, ushort delta) { \
return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \
itype_to_ctype(data), \
itype_to_ctype(filling_data), \
delta, \
__metal_get_simdgroup_size(ushort()))); \
} \
\
METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \
return ctype_to_otype( \
__metal_simd_shuffle_up(itype_to_ctype(data), delta)); \
} \
\
METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \
return ctype_to_otype( \
__metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \
}
#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \
\
METAL_FUNC otype simd_max(itype data) { \
return static_cast<otype>(__metal_simd_max(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_min(itype data) { \
return static_cast<otype>(__metal_simd_min(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_exclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \
return static_cast<otype>( \
__metal_simd_prefix_inclusive_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_product(itype data) { \
return static_cast<otype>(__metal_simd_product(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_sum(itype data) { \
return static_cast<otype>(__metal_simd_sum(static_cast<ctype>(data))); \
} \
\
METAL_FUNC otype simd_xor(itype data) { \
return static_cast<otype>(__metal_simd_xor(static_cast<ctype>(data))); \
}
#if defined(__HAVE_BFLOAT__)
#define bfloat16_to_uint16(x) as_type<uint16_t>(x)
#define uint16_to_bfloat16(x) as_type<bfloat16_t>(x)
#else
#define bfloat16_to_uint16(x) x.bits_
#define uint16_to_bfloat16(x) _MLX_BFloat16(x, _MLX_BFloat16::bits_to_bfloat())
#endif
namespace metal {
instantiate_metal_simd_comm_funcs(
bfloat16_t,
bfloat16_t,
uint16_t,
bfloat16_to_uint16,
uint16_to_bfloat16);
instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float);
} // namespace metal

View File

@ -1,131 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
struct complex64_t;
template <typename T>
static constexpr constant bool can_convert_to_complex64 =
!is_same_v<T, complex64_t> && is_convertible_v<T, float>;
template <typename T>
static constexpr constant bool can_convert_from_complex64 =
!is_same_v<T, complex64_t> &&
(is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);
struct complex64_t {
float real;
float imag;
// Constructors
constexpr complex64_t(float real, float imag) : real(real), imag(imag){};
// Conversions to complex64_t
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) thread : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) device : real(x), imag(0) {}
template <
typename T,
typename = typename enable_if<can_convert_to_complex64<T>>::type>
constexpr complex64_t(T x) constant : real(x), imag(0) {}
// Conversions from complex64_t
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const thread {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const threadgroup {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const device {
return static_cast<T>(real);
}
template <
typename T,
typename = typename enable_if<can_convert_from_complex64<T>>::type>
constexpr operator T() const constant {
return static_cast<T>(real);
}
};
constexpr complex64_t operator-(complex64_t x) {
return {-x.real, -x.imag};
}
constexpr bool operator>=(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}
constexpr bool operator>(complex64_t a, complex64_t b) {
return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}
constexpr bool operator<=(complex64_t a, complex64_t b) {
return operator>=(b, a);
}
constexpr bool operator<(complex64_t a, complex64_t b) {
return operator>(b, a);
}
constexpr bool operator==(complex64_t a, complex64_t b) {
return a.real == b.real && a.imag == b.imag;
}
constexpr complex64_t operator+(complex64_t a, complex64_t b) {
return {a.real + b.real, a.imag + b.imag};
}
constexpr complex64_t operator-(complex64_t a, complex64_t b) {
return {a.real - b.real, a.imag - b.imag};
}
constexpr complex64_t operator*(complex64_t a, complex64_t b) {
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}
constexpr complex64_t operator/(complex64_t a, complex64_t b) {
auto denom = b.real * b.real + b.imag * b.imag;
auto x = a.real * b.real + a.imag * b.imag;
auto y = a.imag * b.real - a.real * b.imag;
return {x / denom, y / denom};
}
constexpr complex64_t operator%(complex64_t a, complex64_t b) {
auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
if (real != 0 && (real < 0 != b.real < 0)) {
real += b.real;
}
if (imag != 0 && (imag < 0 != b.imag < 0)) {
imag += b.imag;
}
return {real, imag};
}

View File

@ -1,292 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/loader.h"
#include "gemm/mma.h"
#include "gemm/transforms.h"
#include "utils.h"
using namespace metal;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel class
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <bool M_aligned, bool N_aligned, bool K_aligned>
struct LoopAlignment {};
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = typename AccumHelper<T>::accum_type,
typename Epilogue = TransformNone<U, AccumType>>
struct GEMMKernel {
STEEL_CONST short tgp_padding_a = 16 / sizeof(T);
STEEL_CONST short tgp_padding_b = 16 / sizeof(T);
STEEL_CONST short tgp_mem_size_a =
transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a);
STEEL_CONST short tgp_mem_size_b =
transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b);
STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b;
STEEL_CONST short tgp_size = WM * WN * 32;
using loader_a_t = BlockLoader<
T,
transpose_a ? BK : BM,
transpose_a ? BM : BK,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
!transpose_a,
tgp_size>;
using loader_b_t = BlockLoader<
T,
transpose_b ? BN : BK,
transpose_b ? BK : BN,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
transpose_b,
tgp_size>;
using mma_t = BlockMMA<
T,
U,
BM,
BN,
BK,
WM,
WN,
transpose_a,
transpose_b,
transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a,
transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b,
AccumType,
Epilogue>;
/* Main kernel function */
template <bool M_aligned, bool N_aligned, bool K_aligned_>
static METAL_FUNC void gemm_loop(
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
const int gemm_k_iterations,
thread loader_a_t& loader_a,
thread loader_b_t& loader_b,
thread mma_t& mma_op,
thread const short& tgp_bm,
thread const short& tgp_bn,
thread const short& lbk,
LoopAlignment<M_aligned, N_aligned, K_aligned_> l = {}) {
// Appease the compiler
(void)l;
short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm);
short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK);
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
if (M_aligned) {
loader_a.load_unsafe();
} else {
loader_a.load_safe(tile_dims_A);
}
if (N_aligned) {
loader_b.load_unsafe();
} else {
loader_b.load_safe(tile_dims_B);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
if (!K_aligned_) {
threadgroup_barrier(mem_flags::mem_threadgroup);
short2 tile_dims_A_last =
transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm);
short2 tile_dims_B_last =
transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk);
loader_a.load_safe(tile_dims_A_last);
loader_b.load_safe(tile_dims_B_last);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
}
/* Main kernel function */
static METAL_FUNC void run(
const device T* A [[buffer(0)]],
const device T* B [[buffer(1)]],
device U* C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
threadgroup T* As [[threadgroup(0)]],
threadgroup T* Bs [[threadgroup(1)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(C, params->ldc);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_loop<true, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result(C, params->ldc);
return;
} else if (tgp_bn == BN) {
gemm_loop<false, true, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else if (tgp_bm == BM) {
gemm_loop<true, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
} else {
gemm_loop<false, false, K_aligned>(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk);
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
return;
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "params.h"

View File

@ -1,89 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "gemm/bf16.h"
#include "gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device T *C [[buffer(2)]],
const constant GEMMParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
gemm_kernel::run(
A, B, C,
params,
As, Bs,
simd_lane_id, simd_group_id, tid, lid
);
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device itype *C [[buffer(2)]], \
const constant GEMMParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@ -1,254 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned,
typename AccumType = float,
typename Epilogue = TransformAdd<T, AccumType>>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void addmm(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
const device T *C [[buffer(2)]],
device T *D [[buffer(3)]],
const constant GEMMAddMMParams* params [[buffer(4)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
// Pacifying compiler
(void)lid;
using gemm_kernel =
GEMMKernel<T, T, BM, BN, BK, WM, WN,
transpose_a, transpose_b,
MN_aligned, K_aligned,
AccumType, Epilogue>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
// Adjust for batch
A += params->batch_stride_a * tid.z;
B += params->batch_stride_b * tid.z;
C += params->batch_stride_c * tid.z;
D += params->batch_stride_d * tid.z;
const int tid_y = ((tid.y) << params->swizzle_log) +
((tid.x) & ((1 << params->swizzle_log) - 1));
const int tid_x = (tid.x) >> params->swizzle_log;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
threadgroup_barrier(mem_flags::mem_none);
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
A += transpose_a ? c_row : c_row * params->lda;
B += transpose_b ? c_col * params->ldb : c_col;
C += c_row * params->ldc + c_col * params->fdc;
D += c_row * params->ldd + c_col;
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
const Epilogue epilogue_op(params->alpha, params->beta);
///////////////////////////////////////////////////////////////////////////////
// MNK aligned loop
if (MN_aligned) {
for (int k = 0; k < gemm_k_iterations; k++) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// Load elements into threadgroup
loader_a.load_unsafe();
loader_b.load_unsafe();
threadgroup_barrier(mem_flags::mem_threadgroup);
// Multiply and accumulate threadgroup elements
mma_op.mma(As, Bs);
// Prepare for next iteration
loader_a.next();
loader_b.next();
}
threadgroup_barrier(mem_flags::mem_none);
// Loop tail
if (!K_aligned) {
int lbk = params->K - params->gemm_k_iterations_aligned * BK;
short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM);
short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk);
loader_a.load_safe(tile_dims_A);
loader_b.load_safe(tile_dims_B);
threadgroup_barrier(mem_flags::mem_threadgroup);
mma_op.mma(As, Bs);
}
// Store results to device memory
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
}
///////////////////////////////////////////////////////////////////////////////
// MN unaligned loop
else { // Loop over K - unaligned case
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK;
if (tgp_bm == BM && tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, K_aligned>{});
mma_op.store_result(D, params->ldd, C, params->ldc, params->fdc, epilogue_op);
return;
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
return mma_op.store_result_safe(
D, params->ldd,
C, params->ldc, params->fdc,
short2(tgp_bn, tgp_bm),
epilogue_op);
}
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, ep_name, epilogue) \
template [[host_name("steel_addmm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname "_" #ep_name)]] \
[[kernel]] void addmm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned, float, epilogue<itype, float>>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
const device itype *C [[buffer(2)]], \
device itype *D [[buffer(3)]], \
const constant GEMMAddMMParams* params [[buffer(4)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, add, TransformAdd) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned, axpby, TransformAxpby)
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm_bias_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
instantiate_gemm_shapes_helper(float32, float, float32, float);

View File

@ -1,280 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////
template <typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
bool MN_aligned,
bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
const device T *A [[buffer(0)]],
const device T *B [[buffer(1)]],
device U *C [[buffer(2)]],
const constant GEMMSpiltKParams* params [[buffer(3)]],
uint simd_lane_id [[thread_index_in_simdgroup]],
uint simd_group_id [[simdgroup_index_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]]) {
(void)lid;
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
using loader_a_t = typename gemm_kernel::loader_a_t;
using loader_b_t = typename gemm_kernel::loader_b_t;
using mma_t = typename gemm_kernel::mma_t;
threadgroup T As[gemm_kernel::tgp_mem_size_a];
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
const int tid_x = tid.x;
const int tid_y = tid.y;
const int tid_z = tid.z;
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
return;
}
// Find block in A, B, C
const int c_row = tid_y * BM;
const int c_col = tid_x * BN;
const int k_start = params->split_k_partition_size * tid_z;
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
// Prepare threadgroup loading operations
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
// Prepare threadgroup mma operation
thread mma_t mma_op(simd_group_id, simd_lane_id);
int gemm_k_iterations = params->gemm_k_iterations_aligned;
short tgp_bm = min(BM, params->M - c_row);
short tgp_bn = min(BN, params->N - c_col);
short leftover_bk = params->K % BK;
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, true, true>{});
} else if (tgp_bn == BN) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, true, true>{});
} else if (tgp_bm == BM) {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<true, false, true>{});
} else {
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iterations,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, true>{});
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if ((tid_z + 1) == (params->split_k_partitions)) {
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
if(!K_aligned || gemm_k_iter_remaining > 0)
gemm_kernel::gemm_loop(
As,
Bs,
gemm_k_iter_remaining,
loader_a,
loader_b,
mma_op,
tgp_bm,
tgp_bn,
leftover_bk,
LoopAlignment<false, false, K_aligned>{});
}
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
mma_op.store_result(C, params->ldc);
} else {
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
}
}
///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
const device itype *A [[buffer(0)]], \
const device itype *B [[buffer(1)]], \
device otype *C [[buffer(2)]], \
const constant GEMMSpiltKParams* params [[buffer(3)]], \
uint simd_lane_id [[thread_index_in_simdgroup]], \
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
uint3 tid [[threadgroup_position_in_grid]], \
uint3 lid [[thread_position_in_threadgroup]]);
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
instantiate_gemm_shapes_helper(float16, half, float32, float);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
instantiate_gemm_shapes_helper(float32, float, float32, float);
///////////////////////////////////////////////////////////////////////////////
// Split k accumulation kernel
///////////////////////////////////////////////////////////////////////////////
template <typename AccT,
typename OutT,
typename Epilogue = TransformNone<OutT, AccT>>
[[kernel]] void gemm_splitk_accum(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
D[0] = Epilogue::apply(out);
}
template <typename AccT,
typename OutT,
typename Epilogue = TransformAxpby<OutT, AccT>>
[[kernel]] void gemm_splitk_accum_axpby(
const device AccT *C_split [[buffer(0)]],
device OutT *D [[buffer(1)]],
const constant int& k_partitions [[buffer(2)]],
const constant int& partition_stride [[buffer(3)]],
const constant int& ldd [[buffer(4)]],
const device OutT *C [[buffer(5)]],
const constant int& ldc [[buffer(6)]],
const constant int& fdc [[buffer(7)]],
const constant float& alpha [[buffer(8)]],
const constant float& beta [[buffer(9)]],
uint2 gid [[thread_position_in_grid]]) {
// Ajust D and C
C += gid.x * fdc + gid.y * ldc;
D += gid.x + gid.y * ldd;
C_split += gid.x + gid.y * ldd;
int offset = 0;
AccT out = 0;
for(int i = 0; i < k_partitions; i++) {
out += C_split[offset];
offset += partition_stride;
}
// Write output
Epilogue op(alpha, beta);
D[0] = op.apply(out, *C);
}
#define instantiate_accum(oname, otype, aname, atype) \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
[[kernel]] void gemm_splitk_accum<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
uint2 gid [[thread_position_in_grid]]); \
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
const device atype *C_split [[buffer(0)]], \
device otype *D [[buffer(1)]], \
const constant int& k_partitions [[buffer(2)]], \
const constant int& partition_stride [[buffer(3)]], \
const constant int& ldd [[buffer(4)]], \
const device otype *C [[buffer(5)]], \
const constant int& ldc [[buffer(6)]], \
const constant int& fdc [[buffer(7)]], \
const constant float& alpha [[buffer(8)]], \
const constant float& beta [[buffer(9)]], \
uint2 gid [[thread_position_in_grid]]);
instantiate_accum(bfloat16, bfloat16_t, float32, float);
instantiate_accum(float16, half, float32, float);
instantiate_accum(float32, float, float32, float);

View File

@ -1,125 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "utils2.h"
///////////////////////////////////////////////////////////////////////////////
// Loading helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
short BROWS,
short BCOLS,
short dst_ld,
short reduction_dim,
short tgp_size,
short alignment = 1,
short n_reads = (BCOLS * BROWS) / (tgp_size),
short TCOLS = BCOLS / n_reads,
short TROWS = tgp_size / TCOLS>
struct BlockLoader {
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
STEEL_CONST short vec_size = n_reads;
// Leading dimension for src
const int src_ld;
const int tile_stride;
// Thread location indices
const short thread_idx;
const short bi;
const short bj;
// threadgroup and device memory
threadgroup T* dst;
const device T* src;
struct alignas(alignment * sizeof(T)) ReadVector {
uint8_t v[sizeof(T) * vec_size];
};
/* Constructor */
METAL_FUNC BlockLoader(
const device T* src_,
const int src_ld_,
threadgroup T* dst_,
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: src_ld(src_ld_),
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
thread_idx(simd_group_id * 32 + simd_lane_id),
bi(thread_idx / TCOLS),
bj(vec_size * (thread_idx % TCOLS)),
dst(dst_ + bi * dst_ld + bj),
src(src_ + bi * src_ld + bj) {}
/* Load from device memory into threadgroup memory - without bound checking */
METAL_FUNC void load_unsafe() const {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
*((const device ReadVector*)(&src[i * src_ld]));
}
}
/* Load from device memory into threadgroup memory - with bound checking */
METAL_FUNC void load_safe(short2 src_tile_dim) const {
src_tile_dim = src_tile_dim - short2(bj, bi);
// Skip loading if thread has no valid reads
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = T(0);
}
}
return;
}
// Use fast thread memory for bound checks
bool tmp_idx[vec_size];
T tmp_val[vec_size];
STEEL_PRAGMA_UNROLL
for (short i = 0; i < BROWS; i += TROWS) {
// Make sure tmp_idx only contains valid indices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
}
// Read valid indices into tmp_val
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
}
// Zero out uneeded values
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
}
// Copy values to threadgroup memory
STEEL_PRAGMA_UNROLL
for (short j = 0; j < vec_size; j++) {
dst[i * dst_ld + j] = tmp_val[j];
}
}
}
/* Iteration helper */
METAL_FUNC void next() {
src += tile_stride;
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,264 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/transforms.h"
#include "utils.h"
///////////////////////////////////////////////////////////////////////////////
// MMA helper
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <
typename T,
typename U,
int BM,
int BN,
int BK,
int WM,
int WN,
bool transpose_a,
bool transpose_b,
short lda_tgp,
short ldb_tgp,
typename AccumType = float,
typename Epilogue = TransformNone<U, AccumType>>
struct BlockMMA {
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TM_stride = 8 * WM;
// Warp tile simdgroup matrix strides along M
STEEL_CONST short TN_stride = 8 * WN;
// Warp tile size along M
STEEL_CONST short TM = BM / TM_stride;
// Warp tile size along N
STEEL_CONST short TN = BN / TN_stride;
// Strides of A, B along reduction axis
STEEL_CONST short simd_stride_a = {
transpose_a ? TM_stride : TM_stride * lda_tgp};
STEEL_CONST short simd_stride_b = {
transpose_b ? TN_stride * ldb_tgp : TN_stride};
// Jump between elements
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
// Simdgroup matrices
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
simdgroup_matrix<AccumType, 8, 8>(0)};
// Offsets within threadgroup
const short tm;
const short tn;
short sm;
short sn;
short As_offset;
short Bs_offset;
/* Constructor */
METAL_FUNC BlockMMA(
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
ushort simd_lane_id [[thread_index_in_simdgroup]])
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
// Determine thread position in simdgroup matrix
short qid = simd_lane_id / 4;
sm = (qid & 4) + (simd_lane_id / 2) % 4;
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
// Determine thread and simdgroup offset
As_offset =
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
Bs_offset =
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
}
/* (BM, BK) X (BK, BN) multiply accumulate function */
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
// Adjust for simdgroup and thread location
As += As_offset;
Bs += Bs_offset;
// Iterate over BK in blocks of 8
STEEL_PRAGMA_UNROLL
for (short kk = 0; kk < BK; kk += 8) {
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup A as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
Asimd[i].thread_elements()[0] =
static_cast<AccumType>(As[i * simd_stride_a + 0]);
Asimd[i].thread_elements()[1] =
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
}
simdgroup_barrier(mem_flags::mem_none);
// Load elements from threadgroup B as simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
Bsimd[j].thread_elements()[0] =
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
Bsimd[j].thread_elements()[1] =
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
}
simdgroup_barrier(mem_flags::mem_none);
// Multiply and accumulate into result simdgroup matrices
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
short j_serp = (i % 2) ? (TN - 1 - j) : j;
simdgroup_multiply_accumulate(
results[i * TN + j_serp],
Asimd[i],
Bsimd[j_serp],
results[i * TN + j_serp]);
}
}
// Progress to next simdgroup tile
As += tile_stride_a;
Bs += tile_stride_b;
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(device U* C, const int ldc) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
// Write out C
C[offset] = outs[0];
C[offset + 1] = outs[1];
}
}
}
METAL_FUNC void
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn);
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset = (i * TM_stride) * ldc + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
C[offset] = Epilogue::apply(accum[0]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
C[offset + 1] = Epilogue::apply(accum[1]);
}
}
}
}
}
/* Store results from simdgroup_matrix results into device memory */
METAL_FUNC void store_result(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
// Loop over all simdgroup tiles
STEEL_PRAGMA_UNROLL
for (short i = 0; i < TM; i++) {
STEEL_PRAGMA_UNROLL
for (short j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue
U outs[2] = {
epilogue_op.apply(accum[0], C[offset_c]),
epilogue_op.apply(accum[1], C[offset_c + fdc])};
// Write out D
D[offset_d] = outs[0];
D[offset_d + 1] = outs[1];
}
}
}
METAL_FUNC void store_result_safe(
device U* D,
const int ldd,
const device U* C,
const int ldc,
const int fdc,
short2 dst_tile_dims,
thread const Epilogue& epilogue_op) const {
// Adjust for simdgroup and thread location
C += (sm + tm) * ldc + (tn + sn) * fdc;
D += (sm + tm) * ldd + tn + sn;
dst_tile_dims -= short2(tn + sn, sm + tm);
STEEL_PRAGMA_UNROLL
for (int i = 0; i < TM; i++) {
if (i * TM_stride < dst_tile_dims.y) {
STEEL_PRAGMA_UNROLL
for (int j = 0; j < TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum = results[i * TN + j].thread_elements();
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
// Apply epilogue and output C
if (j * TN_stride < dst_tile_dims.x) {
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
}
if (j * TN_stride + 1 < dst_tile_dims.x) {
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
}
}
}
}
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,79 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
///////////////////////////////////////////////////////////////////////////////
// GEMM param classes
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
struct GEMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int swizzle_log;
const int gemm_k_iterations_aligned;
};
struct GEMMSpiltKParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int tiles_n;
const int tiles_m;
const int split_k_partitions;
const int split_k_partition_stride;
const int split_k_partition_size;
const int gemm_k_iterations_aligned;
};
struct GEMMAddMMParams {
const int M;
const int N;
const int K;
const int lda;
const int ldb;
const int ldc;
const int ldd;
const int tiles_n;
const int tiles_m;
const int batch_stride_a;
const int batch_stride_b;
const int batch_stride_c;
const int batch_stride_d;
const int swizzle_log;
const int gemm_k_iterations_aligned;
const float alpha;
const float beta;
const int fdc;
};
} // namespace steel
} // namespace mlx

View File

@ -1,63 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "utils.h"
///////////////////////////////////////////////////////////////////////////////
// Transforms and Epilogues
///////////////////////////////////////////////////////////////////////////////
namespace mlx {
namespace steel {
template <typename OutT, typename InT>
struct TransformNone {
static METAL_FUNC OutT apply(InT x) {
return static_cast<OutT>(x);
}
static METAL_FUNC OutT apply(InT x, OutT) {
return static_cast<OutT>(x);
}
};
template <typename OutT, typename InT>
struct TransformAdd {
TransformAdd(const float, const float) {}
static METAL_FUNC OutT apply(InT x, OutT c) {
return static_cast<OutT>(x) + c;
}
};
template <typename OutT, typename InT>
struct TransformAxpby {
const float alpha;
const float beta;
TransformAxpby(const float alpha_, const float beta_)
: alpha(alpha_), beta(beta_) {}
METAL_FUNC OutT apply(InT x, OutT c) const {
return static_cast<OutT>(x * alpha + (beta * c));
}
};
template <typename T>
struct AccumHelper {
typedef float accum_type;
};
struct BlockSwizzle {
static METAL_FUNC int2
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
const int tid_x = (tid.x) >> swizzle_log;
const int tid_y =
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
return int2(tid_x, tid_y);
}
};
} // namespace steel
} // namespace mlx

View File

@ -1,276 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_math>
#include "gemm/bf16.h"
#include "gemm/complex.h"
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
template <typename U>
struct Limits {
static const constant U max = metal::numeric_limits<U>::max();
static const constant U min = metal::numeric_limits<U>::min();
static const constant U finite_max = metal::numeric_limits<U>::max();
static const constant U finite_min = metal::numeric_limits<U>::min();
};
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
template <>
struct Limits<bool> {
static constexpr constant bool max = true;
static constexpr constant bool min = false;
};
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
inline size_t elem_to_loc(
uint elem,
device const int* shape,
device const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
inline size_t elem_to_loc(
uint elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = 0;
for (int i = ndim - 1; i >= 0; --i) {
loc += (elem % shape[i]) * strides[i];
elem /= shape[i];
}
return loc;
}
template <int NDIM>
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t a_strides[NDIM],
constant const size_t b_strides[NDIM]) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
static_cast<uint>(
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
for (int d = NDIM - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline size_t elem_to_loc_nd(
uint3 elem,
constant const int shape[NDIM],
constant const size_t strides[NDIM]) {
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
for (int d = NDIM - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
return elem * stride;
}
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
return elem.x * strides[1] + elem.y * strides[0];
}
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}
// Non templated version to handle arbitrary dims
inline size_t elem_to_loc(
uint3 elem,
constant const int* shape,
constant const size_t* strides,
int ndim) {
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
for (int d = ndim - 3; d >= 0; --d) {
loc += (elem.z % shape[d]) * strides[d];
elem.z /= shape[d];
}
return loc;
}
inline uint2 elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
int ndim) {
uint2 loc = {
static_cast<uint>(
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
static_cast<uint>(
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * a_strides[d];
loc.y += l * b_strides[d];
elem.z /= shape[d];
}
return loc;
}
template <int NDIM>
inline uint elem_to_loc_nd(
uint elem,
device const int* shape,
device const size_t* strides);
template <>
inline uint elem_to_loc_nd<1>(
uint elem,
device const int* shape,
device const size_t* strides) {
return (elem % shape[0]) * strides[0];
}
template <>
inline uint elem_to_loc_nd<2>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<3>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
template <>
inline uint elem_to_loc_nd<4>(
uint elem,
device const int* shape,
device const size_t* strides) {
uint loc = (elem % shape[3]) * strides[3];
elem /= shape[3];
loc += (elem % shape[2]) * strides[2];
elem /= shape[2];
loc += (elem % shape[1]) * strides[1];
elem /= shape[1];
loc += (elem % shape[0]) * strides[0];
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/** Compute ceil((float)N/(float)M) */
inline size_t ceildiv(size_t N, size_t M) {
return (N + M - 1) / M;
}
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
inline float log1p(float x) {
float xp1 = 1.0f + x;
if (xp1 == Limits<float>::max) {
return Limits<float>::max;
}
if (xp1 == 1.0f) {
return x;
}
return x * (metal::log(xp1) / (xp1 - 1.0f));
}
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
if (xp1 == Limits<float>::max) {
return Limits<bfloat16_t>::max;
}
if (xp1 == 1.0f) {
return x;
}
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}

View File

@ -1,9 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include <metal_stdlib>
#include "gemm/host.h"
#define STEEL_CONST static constant constexpr const
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")

View File

@ -1,7 +1,6 @@
use metal::{ use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
NSUInteger,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
@ -13,12 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal"); const BINARY: &str = include_str!("binary.metal");
const TERNARY: &str = include_str!("ternary.metal"); const TERNARY: &str = include_str!("ternary.metal");
const CAST: &str = include_str!("cast.metal"); const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const REDUCE: &str = include_str!("reduce.metal"); const REDUCE: &str = include_str!("reduce.metal");
const RANDOM: &str = include_str!("random.metal"); const CONV: &str = include_str!("conv.metal");
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
/// Most kernels apply similarly across the tensors /// Most kernels apply similarly across the tensors
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the /// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
@ -65,12 +61,10 @@ macro_rules! primitive {
} }
}; };
} }
primitive!(bool);
primitive!(usize); primitive!(usize);
primitive!(i32);
primitive!(i64); primitive!(i64);
primitive!(i32);
primitive!(u32); primitive!(u32);
primitive!(u64);
primitive!(f32); primitive!(f32);
impl<T> EncoderParam for &[T] { impl<T> EncoderParam for &[T] {
@ -124,9 +118,7 @@ pub enum Source {
Cast, Cast,
Reduce, Reduce,
Mfa, Mfa,
Gemm,
Conv, Conv,
Random,
Quantized, Quantized,
} }
@ -248,10 +240,7 @@ impl Kernels {
Source::Cast => CAST, Source::Cast => CAST,
Source::Reduce => REDUCE, Source::Reduce => REDUCE,
Source::Conv => CONV, Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Mfa => panic!("Invalid lib"), Source::Mfa => panic!("Invalid lib"),
Source::Gemm => panic!("Invalid lib"),
} }
} }
@ -275,14 +264,6 @@ impl Kernels {
)) ))
})? })?
} }
Source::Gemm => {
let source_data = GEMM;
device.new_library_with_data(source_data).map_err(|e| {
MetalKernelError::LoadLibraryError(format!(
"Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
))
})?
}
source => { source => {
let source_content = self.get_library_source(source); let source_content = self.get_library_source(source);
device device
@ -1242,34 +1223,6 @@ impl ConstantValues {
} }
} }
fn string_to_static_str(s: String) -> &'static str {
Box::leak(s.into_boxed_str())
}
use core::ffi::c_int;
#[repr(C)]
#[derive(Debug)]
struct GEMMParams {
m: c_int,
n: c_int,
k: c_int,
lda: c_int,
ldb: c_int,
ldc: c_int,
tiles_n: c_int,
tiles_m: c_int,
batch_stride_a: c_int,
batch_stride_b: c_int,
batch_stride_c: c_int,
swizzle_log: c_int,
gemm_k_iterations_aligned: c_int,
}
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn call_gemm( pub fn call_gemm(
device: &Device, device: &Device,
@ -1291,10 +1244,10 @@ pub fn call_gemm(
let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2]; let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k { let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
(false, k as c_int) false
} else if lhs_m1 == m && lhs_m2 == 1 { } else if lhs_m1 == m && lhs_m2 == 1 {
(true, n as c_int) true
} else { } else {
return Err(MetalKernelError::MatMulNonContiguous { return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(), lhs_stride: lhs_stride.to_vec(),
@ -1302,10 +1255,10 @@ pub fn call_gemm(
mnk: (m, n, k), mnk: (m, n, k),
})?; })?;
}; };
let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n { let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
(false, n as c_int) false
} else if rhs_m1 == k && rhs_m2 == 1 { } else if rhs_m1 == k && rhs_m2 == 1 {
(true, k as c_int) true
} else { } else {
return Err(MetalKernelError::MatMulNonContiguous { return Err(MetalKernelError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(), lhs_stride: lhs_stride.to_vec(),
@ -1313,195 +1266,120 @@ pub fn call_gemm(
mnk: (m, n, k), mnk: (m, n, k),
})?; })?;
}; };
// let d_trans = false; let d_trans = false;
// let alpha = 1.0f32; let alpha = 1.0f32;
// let beta = 0.0f32; let beta = 0.0f32;
// let batched = b > 1; let batched = b > 1;
// let fused_activation = false; let fused_activation = false;
// let fused_bias = false; let fused_bias = false;
// let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
// let m_simd = 8; let m_simd = 8;
// let n_simd = 8; let n_simd = 8;
// let k_simd = 64; let k_simd = 64;
// let m_splits = 1; let m_splits = 1;
// let n_splits = 1; let n_splits = 1;
// (m_simd, n_simd, k_simd, m_splits, n_splits) (m_simd, n_simd, k_simd, m_splits, n_splits)
// } else { } else {
// let m_simd = 40; let m_simd = 40;
// let n_simd = 40; let n_simd = 40;
// let k_simd = 32; let k_simd = 32;
// let m_splits = 1; let m_splits = 1;
// let n_splits = 1; let n_splits = 1;
// (m_simd, n_simd, k_simd, m_splits, n_splits) (m_simd, n_simd, k_simd, m_splits, n_splits)
// }; };
// let constants = Some(ConstantValues::new(vec![ let constants = Some(ConstantValues::new(vec![
// (0, Value::USize(m)), (0, Value::USize(m)),
// (1, Value::USize(n)), (1, Value::USize(n)),
// (2, Value::USize(k)), (2, Value::USize(k)),
// (10, Value::Bool(a_trans)), (10, Value::Bool(a_trans)),
// (11, Value::Bool(b_trans)), (11, Value::Bool(b_trans)),
// (13, Value::Bool(d_trans)), (13, Value::Bool(d_trans)),
// (20, Value::F32(alpha)), (20, Value::F32(alpha)),
// (21, Value::F32(beta)), (21, Value::F32(beta)),
// (100, Value::Bool(batched)), (100, Value::Bool(batched)),
// (101, Value::Bool(fused_activation)), (101, Value::Bool(fused_activation)),
// // Garbage // Garbage
// (102, Value::Bool(false)), (102, Value::Bool(false)),
// (103, Value::Bool(false)), (103, Value::Bool(false)),
// (113, Value::Bool(false)), (113, Value::Bool(false)),
// (50_000, Value::Bool(false)), (50_000, Value::Bool(false)),
// // End garbage // End garbage
// (200, Value::U16(m_simd)), (200, Value::U16(m_simd)),
// (201, Value::U16(n_simd)), (201, Value::U16(n_simd)),
// (202, Value::U16(k_simd)), (202, Value::U16(k_simd)),
// (210, Value::U16(m_splits)), (210, Value::U16(m_splits)),
// (211, Value::U16(n_splits)), (211, Value::U16(n_splits)),
// (50_001, Value::Bool(fused_bias)), (50_001, Value::Bool(fused_bias)),
// ])); ]));
let a_trans_name = if a_trans { "t" } else { "n" }; let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let b_trans_name = if b_trans { "t" } else { "n" }; let m_group = m_simd * m_splits;
let (iname, oname) = match name { let n_group = n_simd * n_splits;
"sgemm" => ("float32", "float32"),
"hgemm" => ("float16", "float16"), let a_block_length = m_group * k_simd;
"bgemm" => ("bfloat16", "bfloat16"), let b_block_length = k_simd * n_group;
let mut block_elements = a_block_length + b_block_length;
if (m % 8 != 0) && (n % 8 != 0) {
let c_block_length = m_group * n_group;
block_elements = std::cmp::max(c_block_length, block_elements)
}
if fused_bias {
if d_trans {
block_elements = std::cmp::max(block_elements, m_group);
} else {
block_elements = std::cmp::max(block_elements, n_group);
}
}
let bytes = match name {
"sgemm" => 4,
"hgemm" => 2,
other => { other => {
return Err(MetalKernelError::LoadLibraryError(format!( return Err(MetalKernelError::LoadLibraryError(format!(
"{other} is not a valid kernel for gemm" "{other} is not a valid kernel for gemm"
))) )));
} }
}; };
let mut bm = 32; let block_bytes = block_elements * bytes;
let mut bn = 32;
let mut bk = 16;
let wm = 2;
let wn = 2;
if b * m * n >= 1 << 20 {
if !a_trans && b_trans {
bm = 64;
bn = if oname == "float32" { 64 } else { 32 };
bk = if oname == "float32" { 16 } else { 32 };
} else {
bm = 64;
bn = 64;
}
}
let mnaligned = if m % bm == 0 && n % bn == 0 {
"taligned"
} else {
"naligned"
};
let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
// let bytes = match &name[..] {
// "sgemm" => 4,
// "hgemm" => 2,
// other => {
// return Err(MetalKernelError::LoadLibraryError(format!(
// "{other} is not a valid kernel for gemm"
// )));
// }
// };
let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
let name = string_to_static_str(name);
let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
// let m_group = m_simd * m_splits;
// let n_group = n_simd * n_splits;
// let a_block_length = m_group * k_simd;
// let b_block_length = k_simd * n_group;
// let mut block_elements = a_block_length + b_block_length;
// if (m % 8 != 0) && (n % 8 != 0) {
// let c_block_length = m_group * n_group;
// block_elements = std::cmp::max(c_block_length, block_elements)
// }
// if fused_bias {
// if d_trans {
// block_elements = std::cmp::max(block_elements, m_group);
// } else {
// block_elements = std::cmp::max(block_elements, n_group);
// }
// }
// let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder(); let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline); encoder.set_compute_pipeline_state(&pipeline);
// encoder.set_threadgroup_memory_length(0, block_bytes.into()); encoder.set_threadgroup_memory_length(0, block_bytes.into());
let batch_stride_a: i32 = if lhs_stride.len() > 2 {
lhs_stride[lhs_stride.len() - 3] as i32
} else {
0
};
let batch_stride_b: i32 = if rhs_stride.len() > 2 {
rhs_stride[rhs_stride.len() - 3] as i32
} else {
0
};
let batch_stride_c = (m * n) as i32;
let swizzle_log = 0;
let tiles_n = ((n + bn - 1) / bn) as c_int;
let tiles_m = ((m + bm - 1) / bm) as c_int;
let params = GEMMParams {
m: m as c_int,
n: n as c_int,
k: k as c_int,
lda,
ldb,
ldc: n as c_int,
tiles_m,
tiles_n,
batch_stride_a,
batch_stride_b,
batch_stride_c,
swizzle_log,
gemm_k_iterations_aligned: (k / bk) as c_int,
};
let params_buffer = device.new_buffer_with_data(
&params as *const GEMMParams as *const c_void,
core::mem::size_of::<GEMMParams>() as u64,
MTLResourceOptions::StorageModeShared,
);
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger); encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger); encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0); encoder.set_buffer(2, Some(output), 0);
encoder.set_buffer(3, Some(&params_buffer), 0);
// TODO Tensor D // TODO Tensor D
let grid_z = b; let grid_z = b;
// if batched { if batched {
// let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize; let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
// let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize; let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
// let byte_stride_c = m * n * bytes as usize; let byte_stride_c = m * n * bytes as usize;
// // TODO byte_stride_d // TODO byte_stride_d
// let byte_stride_d = 0; let byte_stride_d = 0;
// let buffer: Vec<u64> = vec![ let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
// byte_stride_a as _, for i in 0..b {
// byte_stride_b as _, buffer.push((i * byte_stride_a) as u64);
// byte_stride_c as _, buffer.push((i * byte_stride_b) as u64);
// byte_stride_d as _, buffer.push((i * byte_stride_c) as u64);
// ]; buffer.push((i * byte_stride_d) as u64);
// // encoder.set_bytes( }
// // 10, encoder.set_bytes(
// // (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, 10,
// // buffer.as_ptr() as *const NSUInteger as *const c_void, (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
// // ); buffer.as_ptr() as *const NSUInteger as *const c_void,
// } );
let tile = 1 << swizzle_log; }
let tm = (tiles_m + tile - 1) / tile;
let tn = tiles_n * tile;
let grid_size = MTLSize { let grid_size = MTLSize {
width: tn as u64, width: divide(n, n_group.into()),
height: tm as u64, height: divide(m, m_group.into()),
depth: grid_z as NSUInteger, depth: grid_z as NSUInteger,
}; };
let group_size = MTLSize { let group_size = MTLSize {
width: 32, width: 32 * (m_splits as u64) * (n_splits as u64),
height: wn, height: 1,
depth: wm, depth: 1,
}; };
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read); encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
@ -1647,73 +1525,6 @@ pub fn call_upsample_nearest_2d(
Ok(()) Ok(())
} }
#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
min: f32,
max: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
if min >= max {
return Err(MetalKernelError::LoadLibraryError(
"min must be less than max".to_string(),
));
}
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, min, max, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: &'static str,
mean: f32,
stddev: f32,
length: usize,
seed: &Buffer,
buffer: &Buffer,
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
let encoder = command_buffer.new_compute_command_encoder();
let odd = (length % 2 != 0) as usize;
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, mean, stddev, seed, buffer));
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub enum GgmlDType { pub enum GgmlDType {
Q4_0, Q4_0,
@ -1743,145 +1554,7 @@ pub fn call_quantized_matmul_t(
rhs: &Buffer, rhs: &Buffer,
output: &Buffer, output: &Buffer,
) -> Result<(), MetalKernelError> { ) -> Result<(), MetalKernelError> {
// Everything is in reverse todo!("Not implemented yet");
let ne00 = k as i64;
let ne01 = n as i64;
let ne02 = b as i64;
let ne03 = 1 as i64;
let nb00 = 0i64;
let nb01 = 0 as i64;
let nb02 = 0 as i64;
let ne10 = k as i64;
let ne11 = m as i64;
let ne12 = b as i64;
let ne13 = 1 as i64;
let nb10 = 0i64;
let nb11 = 0i64;
let nb12 = 0i64;
let ne0 = n as i64;
let ne1 = m as i64;
let r2: u32 = (ne12 / ne02) as u32;
let r3: u32 = (ne13 / ne03) as u32;
let (nth0, nth1, align) = match dtype {
GgmlDType::Q4_0
| GgmlDType::Q4_1
| GgmlDType::Q5_0
| GgmlDType::Q5_1
| GgmlDType::Q8_0
| GgmlDType::Q8_1 => {
let nth0 = 8;
let nth1 = 8;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::Q2K => {
// Fixing a bug in Metal for GGML
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q4K => {
let nth0 = 4;
let nth1 = 8;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q3K | GgmlDType::Q5K => {
let nth0 = 2;
let nth1 = 32;
let align = 4;
(nth0, nth1, align)
}
GgmlDType::Q6K => {
let nth0 = 2;
let nth1 = 32;
let align = 2;
(nth0, nth1, align)
}
GgmlDType::F16 | GgmlDType::Q8K => {
// Original implem uses rows
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
GgmlDType::F32 => {
let nth0 = 32;
let nth1 = 1;
let align = 8;
(nth0, nth1, align)
}
};
let thread_groups_count = MTLSize {
width: divide(ne01 as usize, align),
height: ne11 as u64,
depth: (ne12 * ne13) as u64,
};
let threads_per_threadgroup = MTLSize {
width: nth0,
height: nth1,
depth: 1,
};
let name = match dtype {
GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
GgmlDType::F16 => "kernel_mul_mv_f16_f32",
GgmlDType::F32 => "kernel_mul_mv_f32_f32",
};
let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
rhs,
(lhs, lhs_offset),
output,
ne00,
ne01,
ne02,
nb00,
nb01,
nb02,
ne10,
ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
ne1,
r2,
r3
)
);
encoder.set_threadgroup_memory_length(0, 8192);
encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
encoder.use_resource(output, metal::MTLResourceUsage::Write);
encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
encoder.end_encoding();
Ok(())
} }
fn divide(m: usize, b: usize) -> NSUInteger { fn divide(m: usize, b: usize) -> NSUInteger {

File diff suppressed because it is too large Load Diff

View File

@ -1,206 +0,0 @@
#include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>
using namespace metal;
// Constants
// 2^32 and 1/2^32. Useful for converting between float and uint.
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
// 2 * pi
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17};
// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = {
0x9E3779B97F4A7C15,
0xF39CC0605CEDC834,
0x1082276BF3A27251,
0xF86C6A11D0C18E95,
0x2767F0B153D27B7F,
0x0347045B5BF1827F,
0x01886F0928403002,
0xC1D64BA40F335E36,
0xF06AD7AE9717877E,
0x85839D6EFFBD7DC6,
0x64D325D1C5371682,
0xCADD0CCCFDFFBBE1,
0x626E33B8D04B4331,
0xBBF73C790D94F79D,
0x471C4AB3ED3D82A5,
0xFEC507705E4AE6E5,
};
// Combined Tausworthe and LCG Random Number Generator.
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
struct HybridTaus {
float state;
HybridTaus() thread = default;
HybridTaus() threadgroup = default;
HybridTaus() device = default;
HybridTaus() constant = default;
// Generate seeds for each thread.
METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
}
// Tausworthe generator.
METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
uint b = (((z << s.x) ^ z) >> s.y);
return (((z & M) << s.z) ^ b);
}
// LCG generator.
METAL_FUNC static uint lcg(const uint z) {
return (1664525 * z + 1013904223UL);
}
// Initialize the RNG state.
METAL_FUNC static HybridTaus init(const ulong4 seeds) {
uint4 seed = seed_per_thread(seeds);
// Seed #1
uint z1 = taus(seed.x, S1, 4294967294UL);
uint z2 = taus(seed.y, S2, 4294967288UL);
uint z3 = taus(seed.z, S3, 4294967280UL);
uint z4 = lcg(seed.x);
// Seed #2
uint r1 = (z1^z2^z3^z4^seed.y);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
// Seed #3
r1 = (z1^z2^z3^z4^seed.z);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
// Seed #4
r1 = (z1^z2^z3^z4^seed.w);
z1 = taus(r1, S1, 429496729UL);
z2 = taus(r1, S2, 4294967288UL);
z3 = taus(r1, S3, 429496280UL);
z4 = lcg(r1);
HybridTaus rng;
rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
return rng;
}
METAL_FUNC float rand() {
uint seed = this->state * UNIF01_NORM32;
uint z1 = taus(seed, S1, 429496729UL);
uint z2 = taus(seed, S2, 4294967288UL);
uint z3 = taus(seed, S3, 429496280UL);
uint z4 = lcg(seed);
thread float result = this->state;
this->state = (z1^z2^z3^z4) * UNIF01_INV32;
return result;
}
};
template<typename T> METAL_FUNC void rand_uniform(
constant size_t &size,
constant float &min,
constant float &max,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
float diff = abs(min - max);
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
out[tid] = static_cast<T>(rng.rand() * diff + min);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
}
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}
// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/BoxMuller_transform
template<typename T> METAL_FUNC void normal(
constant size_t &size,
constant float &mean,
constant float &stddev,
device atomic_uint *seed,
device T *out,
uint tid [[thread_position_in_grid]]
) {
if (tid >= size) {
return;
}
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
float u1 = rng.rand();
float u2 = rng.rand();
float cosval;
float sinval = sincos(TWO_PI * u2, cosval);
float mag = stddev * sqrt(-2.0 * log(u1));
float z0 = mag * cosval + mean;
float z1 = mag * sinval + mean;
out[tid] = static_cast<T>(z0);
if (tid == 0) {
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
// Return early if tid == 0, otherwise we will write to out[size].
return;
}
// Use symmetry to fill the other half of the array.
out[size - tid] = static_cast<T>(z1);
}
#define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \
constant size_t &size, \
constant float &min, \
constant float &max, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
rand_uniform<T>(size, min, max, seed, out, tid); \
} \
#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
constant size_t &size, \
constant float &mean, \
constant float &stddev, \
device atomic_uint *seed, \
device T *out, \
uint tid [[thread_position_in_grid]] \
) { \
normal<T>(size, mean, stddev, seed, out, tid); \
} \
#define RANDOM_OPS(NAME, T) \
UNIFORM_OP(NAME, T) \
NORMAL_OP(NAME, T) \
RANDOM_OPS(f32, float)
RANDOM_OPS(f16, half)
#if __METAL_VERSION__ >= 310
RANDOM_OPS(bf16, bfloat)
#endif

View File

@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer { fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let ptr = data.as_ptr() as *const c_void; let ptr = data.as_ptr() as *const core::ffi::c_void;
let size = (data.len() * std::mem::size_of::<T>()) as u64; let size = (data.len() * std::mem::size_of::<T>()) as u64;
device.new_buffer_with_data(ptr, size, options) device.new_buffer_with_data(ptr, size, options)
} }
@ -713,6 +713,7 @@ fn softmax() {
} }
let results = run_softmax(&v, last_dim, "softmax_f32"); let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4); let results = approx(results, 4);
println!("{results:?}");
assert_eq!( assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(), results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n n
@ -926,124 +927,3 @@ fn gemm() {
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
); );
} }
fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);
let seed = device.new_buffer_with_data(
&seed as *const u32 as *const core::ffi::c_void,
std::mem::size_of::<u32>() as NSUInteger,
options,
);
if name.starts_with("rand_uniform") {
call_random_uniform(
&device,
command_buffer,
&kernels,
name,
a,
b,
length,
&seed,
&output,
)
.unwrap();
} else {
call_random_normal(
&device,
command_buffer,
&kernels,
name,
a,
b,
length,
&seed,
&output,
)
.unwrap();
}
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
#[test]
fn random() {
fn calc_mean(data: &[f32]) -> f32 {
let sum = data.iter().sum::<f32>() as f32;
let count = data.len();
assert!(count > 0);
sum / count as f32
}
fn calc_stddev(data: &[f32]) -> f32 {
let mean = calc_mean(data);
let count = data.len();
assert!(count > 0);
let variance = data
.iter()
.map(|value| {
let diff = mean - (*value as f32);
diff * diff
})
.sum::<f32>()
/ count as f32;
variance.sqrt()
}
let shape = vec![1024, 10];
let length = shape.iter().product::<usize>();
let seed = 299792458;
let min = -30.0;
let max = 30.0;
let mean = 100.0;
let stddev = 50.0;
macro_rules! validate_random {
($type:ty) => {
let results: Vec<f32> = run_random::<$type>(
concat!("rand_uniform_", stringify!($type)),
seed,
length,
min,
max,
)
.into_iter()
.map(f32::from)
.collect();
results.iter().for_each(|v| {
assert!(*v >= min && *v <= max);
});
assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);
let results: Vec<f32> = run_random::<$type>(
concat!("rand_normal_", stringify!($type)),
seed,
length,
mean,
stddev,
)
.into_iter()
.map(f32::from)
.collect();
assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
};
}
validate_random!(f32);
validate_random!(f16);
validate_random!(bf16);
}

View File

@ -262,19 +262,9 @@ impl BatchNorm {
let target_shape = target_shape.as_slice(); let target_shape = target_shape.as_slice();
let x = x let x = x
.broadcast_sub( .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
&self
.running_mean
.as_detached_tensor()
.reshape(target_shape)?,
)?
.broadcast_div( .broadcast_div(
&(self &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
.running_var
.as_detached_tensor()
.reshape(target_shape)?
+ self.eps)?
.sqrt()?,
)?; )?;
match &self.weight_and_bias { match &self.weight_and_bias {

View File

@ -124,7 +124,7 @@ fn set_at_index<D: WithDType, I: Into<i64>>(
value: I, value: I,
offset: usize, offset: usize,
depth: usize, depth: usize,
v: &mut [D], v: &mut Vec<D>,
on_value: D, on_value: D,
) -> Result<()> { ) -> Result<()> {
let value = value.into(); let value = value.into();

View File

@ -412,16 +412,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
} }
impl<'a> VarBuilder<'a> { impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` using a custom backend. fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
///
/// It is preferred to use one of the more specific constructors. This
/// constructor is provided to allow downstream users to define their own
/// backends.
pub fn from_backend(
backend: Box<dyn SimpleBackend + 'a>,
dtype: DType,
device: Device,
) -> Self {
let data = TensorData { let data = TensorData {
backend, backend,
dtype, dtype,
@ -436,13 +427,13 @@ impl<'a> VarBuilder<'a> {
/// Initializes a `VarBuilder` that uses zeros for any tensor. /// Initializes a `VarBuilder` that uses zeros for any tensor.
pub fn zeros(dtype: DType, dev: &Device) -> Self { pub fn zeros(dtype: DType, dev: &Device) -> Self {
Self::from_backend(Box::new(Zeros), dtype, dev.clone()) Self::new(Box::new(Zeros), dtype, dev.clone())
} }
/// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
/// returned if no tensor is available under the requested path or on shape mismatches. /// returned if no tensor is available under the requested path or on shape mismatches.
pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self { pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
Self::from_backend(Box::new(ts), dtype, dev.clone()) Self::new(Box::new(ts), dtype, dev.clone())
} }
/// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
@ -452,7 +443,7 @@ impl<'a> VarBuilder<'a> {
/// Note that it is possible to load the tensor values after model creation using the `load` /// Note that it is possible to load the tensor values after model creation using the `load`
/// method on `varmap`, this can be used to start model training from an existing checkpoint. /// method on `varmap`, this can be used to start model training from an existing checkpoint.
pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self { pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone()) Self::new(Box::new(varmap.clone()), dtype, dev.clone())
} }
/// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
@ -467,25 +458,25 @@ impl<'a> VarBuilder<'a> {
dev: &Device, dev: &Device,
) -> Result<Self> { ) -> Result<Self> {
let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?; let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
} }
/// Initializes a `VarBuilder` from a binary builder in the safetensor format. /// Initializes a `VarBuilder` from a binary builder in the safetensor format.
pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> { pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
let tensors = candle::safetensors::BufferedSafetensors::new(data)?; let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone())) Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
} }
/// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file. /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> { pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let npz = candle::npy::NpzTensors::new(p)?; let npz = candle::npy::NpzTensors::new(p)?;
Ok(Self::from_backend(Box::new(npz), dtype, dev.clone())) Ok(Self::new(Box::new(npz), dtype, dev.clone()))
} }
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file. /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> { pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let pth = candle::pickle::PthTensors::new(p, None)?; let pth = candle::pickle::PthTensors::new(p)?;
Ok(Self::from_backend(Box::new(pth), dtype, dev.clone())) Ok(Self::new(Box::new(pth), dtype, dev.clone()))
} }
} }

View File

@ -1,6 +1,6 @@
[package] [package]
name = "candle-onnx" name = "candle-onnx"
version = "0.4.0" version = "0.3.3"
edition = "2021" edition = "2021"
description = "ONNX support for Candle" description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
[dependencies] [dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" } candle = { path = "../candle-core", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.4.0" } candle-nn = { path = "../candle-nn" }
prost = "0.12.1" prost = "0.12.1"
[build-dependencies] [build-dependencies]

View File

@ -766,16 +766,6 @@ pub fn simple_eval(
let output = input.cumsum(axis as usize)?; let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);
} }
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
"Flatten" => {
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
let input = get(&node.input[0])?;
let first_part: usize = input.shape().dims().iter().take(axis).product();
let end_index = input.shape().dims().iter().product::<usize>();
let new_shape = (first_part, end_index / first_part);
let output = input.reshape(new_shape)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"), op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
} }
} }

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src; extern crate accelerate_src;
use candle::{Device, Result, Tensor}; use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap; use std::collections::HashMap;
const INPUT_X: &str = "x"; const INPUT_X: &str = "x";
@ -677,134 +677,6 @@ fn test_dropout_operation() -> Result<()> {
Ok(()) Ok(())
} }
// "Flatten"
#[test]
fn test_flatten_operation() -> Result<()> {
let mut att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: 0,
doc_string: "axis".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Flatten".to_string(),
domain: "".to_string(),
attribute: vec![att_axis.clone()],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
],
&[2, 2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]);
att_axis.i = 1;
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Flatten".to_string(),
domain: "".to_string(),
attribute: vec![att_axis.clone()],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
);
Ok(())
}
// Below are ops that are implemented but not tested yet // Below are ops that are implemented but not tested yet
// "MaxPool" // "MaxPool"

View File

@ -88,27 +88,23 @@ class QTensor:
Dequantizes the tensor. Dequantizes the tensor.
""" """
pass pass
@property @property
def ggml_dtype(self) -> str: def ggml_dtype(self) -> str:
""" """
Gets the tensors quantized dtype. Gets the tensors quantized dtype.
""" """
pass pass
def matmul_t(self, lhs: Tensor) -> Tensor: def matmul_t(self, lhs: Tensor) -> Tensor:
""" """
Performs a quantized matrix multiplication, with the quantized tensor as the right hand side. Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
""" """
pass pass
@property @property
def rank(self) -> int: def rank(self) -> int:
""" """
Gets the rank of the tensor. Gets the rank of the tensor.
""" """
pass pass
@property @property
def shape(self) -> Tuple[int]: def shape(self) -> Tuple[int]:
""" """
@ -123,213 +119,178 @@ class Tensor:
def __init__(self, data: _ArrayLike): def __init__(self, data: _ArrayLike):
pass pass
def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Add a scalar to a tensor or two tensors together. Add a scalar to a tensor or two tensors together.
""" """
pass pass
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor": def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
""" """
Return a slice of a tensor. Return a slice of a tensor.
""" """
pass pass
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Multiply a tensor by a scalar or one tensor by another. Multiply a tensor by a scalar or one tensor by another.
""" """
pass pass
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Add a scalar to a tensor or two tensors together. Add a scalar to a tensor or two tensors together.
""" """
pass pass
def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor": def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
""" """
Compare a tensor with a scalar or one tensor with another. Compare a tensor with a scalar or one tensor with another.
""" """
pass pass
def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Multiply a tensor by a scalar or one tensor by another. Multiply a tensor by a scalar or one tensor by another.
""" """
pass pass
def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Subtract a scalar from a tensor or one tensor from another. Subtract a scalar from a tensor or one tensor from another.
""" """
pass pass
def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor": def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
""" """
Divide a tensor by a scalar or one tensor by another. Divide a tensor by a scalar or one tensor by another.
""" """
pass pass
def abs(self) -> Tensor: def abs(self) -> Tensor:
""" """
Performs the `abs` operation on the tensor. Performs the `abs` operation on the tensor.
""" """
pass pass
def argmax_keepdim(self, dim: int) -> Tensor: def argmax_keepdim(self, dim: int) -> Tensor:
""" """
Returns the indices of the maximum value(s) across the selected dimension. Returns the indices of the maximum value(s) across the selected dimension.
""" """
pass pass
def argmin_keepdim(self, dim: int) -> Tensor: def argmin_keepdim(self, dim: int) -> Tensor:
""" """
Returns the indices of the minimum value(s) across the selected dimension. Returns the indices of the minimum value(s) across the selected dimension.
""" """
pass pass
def broadcast_add(self, rhs: Tensor) -> Tensor: def broadcast_add(self, rhs: Tensor) -> Tensor:
""" """
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
""" """
pass pass
def broadcast_as(self, *shape: Shape) -> Tensor: def broadcast_as(self, *shape: Shape) -> Tensor:
""" """
Broadcasts the tensor to the given shape. Broadcasts the tensor to the given shape.
""" """
pass pass
def broadcast_div(self, rhs: Tensor) -> Tensor: def broadcast_div(self, rhs: Tensor) -> Tensor:
""" """
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
""" """
pass pass
def broadcast_left(self, *shape: Shape) -> Tensor: def broadcast_left(self, *shape: Shape) -> Tensor:
""" """
Broadcasts the tensor to the given shape, adding new dimensions on the left. Broadcasts the tensor to the given shape, adding new dimensions on the left.
""" """
pass pass
def broadcast_mul(self, rhs: Tensor) -> Tensor: def broadcast_mul(self, rhs: Tensor) -> Tensor:
""" """
Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
""" """
pass pass
def broadcast_sub(self, rhs: Tensor) -> Tensor: def broadcast_sub(self, rhs: Tensor) -> Tensor:
""" """
Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor. Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
""" """
pass pass
def contiguous(self) -> Tensor: def contiguous(self) -> Tensor:
""" """
Makes the tensor contiguous in memory. Makes the tensor contiguous in memory.
""" """
pass pass
def copy(self) -> Tensor: def copy(self) -> Tensor:
""" """
Returns a copy of the tensor. Returns a copy of the tensor.
""" """
pass pass
def cos(self) -> Tensor: def cos(self) -> Tensor:
""" """
Performs the `cos` operation on the tensor. Performs the `cos` operation on the tensor.
""" """
pass pass
def detach(self) -> Tensor: def detach(self) -> Tensor:
""" """
Detach the tensor from the computation graph. Detach the tensor from the computation graph.
""" """
pass pass
@property @property
def device(self) -> Device: def device(self) -> Device:
""" """
Gets the tensor's device. Gets the tensor's device.
""" """
pass pass
@property @property
def dtype(self) -> DType: def dtype(self) -> DType:
""" """
Gets the tensor's dtype. Gets the tensor's dtype.
""" """
pass pass
def exp(self) -> Tensor: def exp(self) -> Tensor:
""" """
Performs the `exp` operation on the tensor. Performs the `exp` operation on the tensor.
""" """
pass pass
def flatten_all(self) -> Tensor: def flatten_all(self) -> Tensor:
""" """
Flattens the tensor into a 1D tensor. Flattens the tensor into a 1D tensor.
""" """
pass pass
def flatten_from(self, dim: int) -> Tensor: def flatten_from(self, dim: int) -> Tensor:
""" """
Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension. Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
""" """
pass pass
def flatten_to(self, dim: int) -> Tensor: def flatten_to(self, dim: int) -> Tensor:
""" """
Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive). Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
""" """
pass pass
def get(self, index: int) -> Tensor: def get(self, index: int) -> Tensor:
""" """
Gets the value at the specified index. Gets the value at the specified index.
""" """
pass pass
def index_select(self, rhs: Tensor, dim: int) -> Tensor: def index_select(self, rhs: Tensor, dim: int) -> Tensor:
""" """
Select values for the input tensor at the target indexes across the specified dimension. Select values for the input tensor at the target indexes across the specified dimension.
@ -341,192 +302,161 @@ class Tensor:
tensor. tensor.
""" """
pass pass
def is_contiguous(self) -> bool: def is_contiguous(self) -> bool:
""" """
Returns true if the tensor is contiguous in C order. Returns true if the tensor is contiguous in C order.
""" """
pass pass
def is_fortran_contiguous(self) -> bool: def is_fortran_contiguous(self) -> bool:
""" """
Returns true if the tensor is contiguous in Fortran order. Returns true if the tensor is contiguous in Fortran order.
""" """
pass pass
def log(self) -> Tensor: def log(self) -> Tensor:
""" """
Performs the `log` operation on the tensor. Performs the `log` operation on the tensor.
""" """
pass pass
def matmul(self, rhs: Tensor) -> Tensor: def matmul(self, rhs: Tensor) -> Tensor:
""" """
Performs a matrix multiplication between the two tensors. Performs a matrix multiplication between the two tensors.
""" """
pass pass
def max_keepdim(self, dim: int) -> Tensor: def max_keepdim(self, dim: int) -> Tensor:
""" """
Gathers the maximum value across the selected dimension. Gathers the maximum value across the selected dimension.
""" """
pass pass
def mean_all(self) -> Tensor: def mean_all(self) -> Tensor:
""" """
Returns the mean of the tensor. Returns the mean of the tensor.
""" """
pass pass
def min_keepdim(self, dim: int) -> Tensor: def min_keepdim(self, dim: int) -> Tensor:
""" """
Gathers the minimum value across the selected dimension. Gathers the minimum value across the selected dimension.
""" """
pass pass
def narrow(self, dim: int, start: int, len: int) -> Tensor: def narrow(self, dim: int, start: int, len: int) -> Tensor:
""" """
Returns a new tensor that is a narrowed version of the input, the dimension `dim` Returns a new tensor that is a narrowed version of the input, the dimension `dim`
ranges from `start` to `start + len`. ranges from `start` to `start + len`.
""" """
pass pass
@property @property
def nelement(self) -> int: def nelement(self) -> int:
""" """
Gets the tensor's element count. Gets the tensor's element count.
""" """
pass pass
def powf(self, p: float) -> Tensor: def powf(self, p: float) -> Tensor:
""" """
Performs the `pow` operation on the tensor with the given exponent. Performs the `pow` operation on the tensor with the given exponent.
""" """
pass pass
def quantize(self, quantized_dtype: str) -> QTensor: def quantize(self, quantized_dtype: str) -> QTensor:
""" """
Quantize the tensor. Quantize the tensor.
""" """
pass pass
@property @property
def rank(self) -> int: def rank(self) -> int:
""" """
Gets the tensor's rank. Gets the tensor's rank.
""" """
pass pass
def recip(self) -> Tensor: def recip(self) -> Tensor:
""" """
Get the `recip` of the tensor. Get the `recip` of the tensor.
""" """
pass pass
def reshape(self, *shape: Shape) -> Tensor: def reshape(self, *shape: Shape) -> Tensor:
""" """
Reshapes the tensor to the given shape. Reshapes the tensor to the given shape.
""" """
pass pass
@property @property
def shape(self) -> Tuple[int]: def shape(self) -> Tuple[int]:
""" """
Gets the tensor's shape. Gets the tensor's shape.
""" """
pass pass
def sin(self) -> Tensor: def sin(self) -> Tensor:
""" """
Performs the `sin` operation on the tensor. Performs the `sin` operation on the tensor.
""" """
pass pass
def sqr(self) -> Tensor: def sqr(self) -> Tensor:
""" """
Squares the tensor. Squares the tensor.
""" """
pass pass
def sqrt(self) -> Tensor: def sqrt(self) -> Tensor:
""" """
Calculates the square root of the tensor. Calculates the square root of the tensor.
""" """
pass pass
def squeeze(self, dim: int) -> Tensor: def squeeze(self, dim: int) -> Tensor:
""" """
Creates a new tensor with the specified dimension removed if its size was one. Creates a new tensor with the specified dimension removed if its size was one.
""" """
pass pass
@property @property
def stride(self) -> Tuple[int]: def stride(self) -> Tuple[int]:
""" """
Gets the tensor's strides. Gets the tensor's strides.
""" """
pass pass
def sum_all(self) -> Tensor: def sum_all(self) -> Tensor:
""" """
Returns the sum of the tensor. Returns the sum of the tensor.
""" """
pass pass
def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor: def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
""" """
Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions. Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions.
""" """
pass pass
def t(self) -> Tensor: def t(self) -> Tensor:
""" """
Transposes the tensor. Transposes the tensor.
""" """
pass pass
def to(self, *args, **kwargs) -> Tensor: def to(self, *args, **kwargs) -> Tensor:
""" """
Performs Tensor dtype and/or device conversion. Performs Tensor dtype and/or device conversion.
""" """
pass pass
def to_device(self, device: Union[str, Device]) -> Tensor: def to_device(self, device: Union[str, Device]) -> Tensor:
""" """
Move the tensor to a new device. Move the tensor to a new device.
""" """
pass pass
def to_dtype(self, dtype: Union[str, DType]) -> Tensor: def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
""" """
Convert the tensor to a new dtype. Convert the tensor to a new dtype.
""" """
pass pass
def to_torch(self) -> torch.Tensor: def to_torch(self) -> torch.Tensor:
""" """
Converts candle's tensor to pytorch's tensor Converts candle's tensor to pytorch's tensor
""" """
pass pass
def transpose(self, dim1: int, dim2: int) -> Tensor: def transpose(self, dim1: int, dim2: int) -> Tensor:
""" """
Returns a tensor that is a transposed version of the input, the given dimensions are swapped. Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
""" """
pass pass
def unsqueeze(self, dim: int) -> Tensor: def unsqueeze(self, dim: int) -> Tensor:
""" """
Creates a new tensor with a dimension of size one inserted at the specified position. Creates a new tensor with a dimension of size one inserted at the specified position.
""" """
pass pass
def values(self) -> _ArrayLike: def values(self) -> _ArrayLike:
""" """
Gets the tensor's data as a Python scalar or array-like object. Gets the tensor's data as a Python scalar or array-like object.
""" """
pass pass
def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor: def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
""" """
Returns a tensor with the same shape as the input tensor, the values are taken from Returns a tensor with the same shape as the input tensor, the values are taken from

View File

@ -57,10 +57,12 @@ class Sequential(Module):
_modules: Dict[str, Module] # type: ignore[assignment] _modules: Dict[str, Module] # type: ignore[assignment]
@overload @overload
def __init__(self, *args: Module) -> None: ... def __init__(self, *args: Module) -> None:
...
@overload @overload
def __init__(self, arg: "OrderedDict[str, Module]") -> None: ... def __init__(self, arg: "OrderedDict[str, Module]") -> None:
...
def __init__(self, *args): def __init__(self, *args):
super().__init__() super().__init__()

View File

@ -204,10 +204,12 @@ class Module:
T_destination = TypeVar("T_destination", bound=Dict[str, Any]) T_destination = TypeVar("T_destination", bound=Dict[str, Any])
@overload @overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ... def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
...
@overload @overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ... def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
...
def state_dict(self, *args, destination=None, prefix="", keep_vars=False): def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module. r"""Returns a dictionary containing references to the whole state of the module.
@ -584,10 +586,12 @@ class Module:
self: T, self: T,
device: str = ..., device: str = ...,
dtype: Optional[Union[DType, str]] = ..., dtype: Optional[Union[DType, str]] = ...,
) -> T: ... ) -> T:
...
@overload @overload
def to(self: T, dtype: Union[DType, str]) -> T: ... def to(self: T, dtype: Union[DType, str]) -> T:
...
def to(self, *args, **kwargs): def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers. r"""Moves and/or casts the parameters and buffers.

View File

@ -14,7 +14,6 @@ class LayerNorm(Module):
math:: math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
""" """
__constants__ = ["normalized_shape", "eps"] __constants__ = ["normalized_shape", "eps"]
normalized_shape: Tuple[int, ...] normalized_shape: Tuple[int, ...]
eps: float eps: float

View File

@ -11,69 +11,59 @@ class ONNXModel:
def __init__(self, path: str): def __init__(self, path: str):
pass pass
@property @property
def doc_string(self) -> str: def doc_string(self) -> str:
""" """
The doc string of the model. The doc string of the model.
""" """
pass pass
@property @property
def domain(self) -> str: def domain(self) -> str:
""" """
The domain of the operator set of the model. The domain of the operator set of the model.
""" """
pass pass
def initializers(self) -> Dict[str, Tensor]: def initializers(self) -> Dict[str, Tensor]:
""" """
Get the weights of the model. Get the weights of the model.
""" """
pass pass
@property @property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
""" """
The inputs of the model. The inputs of the model.
""" """
pass pass
@property @property
def ir_version(self) -> int: def ir_version(self) -> int:
""" """
The version of the IR this model targets. The version of the IR this model targets.
""" """
pass pass
@property @property
def model_version(self) -> int: def model_version(self) -> int:
""" """
The version of the model. The version of the model.
""" """
pass pass
@property @property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]: def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
""" """
The outputs of the model. The outputs of the model.
""" """
pass pass
@property @property
def producer_name(self) -> str: def producer_name(self) -> str:
""" """
The producer of the model. The producer of the model.
""" """
pass pass
@property @property
def producer_version(self) -> str: def producer_version(self) -> str:
""" """
The version of the producer of the model. The version of the producer of the model.
""" """
pass pass
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]: def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
""" """
Run the model on the given inputs. Run the model on the given inputs.
@ -91,7 +81,6 @@ class ONNXTensorDescription:
The data type of the tensor. The data type of the tensor.
""" """
pass pass
@property @property
def shape(self) -> Tuple[Union[int, str, Any]]: def shape(self) -> Tuple[Union[int, str, Any]]:
""" """

View File

@ -938,8 +938,8 @@ impl PyTensor {
/// Detach the tensor from the computation graph. /// Detach the tensor from the computation graph.
/// &RETURNS&: Tensor /// &RETURNS&: Tensor
fn detach(&self) -> Self { fn detach(&self) -> PyResult<Self> {
PyTensor(self.0.detach()) Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
} }
/// Returns a copy of the tensor. /// Returns a copy of the tensor.

View File

@ -189,6 +189,7 @@ def do_black(content, is_pyi):
line_length=119, line_length=119,
is_pyi=is_pyi, is_pyi=is_pyi,
string_normalization=True, string_normalization=True,
experimental_string_processing=False,
) )
try: try:
return black.format_file_contents(content, fast=True, mode=mode) return black.format_file_contents(content, fast=True, mode=mode)

View File

@ -23,6 +23,7 @@ serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
serde_plain = { workspace = true } serde_plain = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
wav = { workspace = true }
[features] [features]
default = [] default = []

View File

@ -1,593 +0,0 @@
use crate::models::with_tracing::Linear;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug, Clone)]
pub struct Config {
pub num_layers: usize,
pub padded_vocab_size: usize,
pub hidden_size: usize,
pub ffn_hidden_size: usize,
pub kv_channels: usize,
pub num_attention_heads: usize,
pub seq_length: usize,
pub layernorm_epsilon: f64,
pub rmsnorm: bool,
pub apply_residual_connection_post_layernorm: bool,
pub post_layer_norm: bool,
pub add_bias_linear: bool,
pub add_qkv_bias: bool,
pub bias_dropout_fusion: bool,
pub multi_query_attention: bool,
pub multi_query_group_num: usize,
pub apply_query_key_layer_scaling: bool,
pub attention_softmax_in_fp32: bool,
pub fp32_residual_connection: bool,
}
impl Config {
pub fn glm3_6b() -> Self {
Self {
num_layers: 28,
padded_vocab_size: 65024,
hidden_size: 4096,
ffn_hidden_size: 13696,
kv_channels: 128,
num_attention_heads: 32,
seq_length: 8192,
layernorm_epsilon: 1e-5,
rmsnorm: true,
apply_residual_connection_post_layernorm: false,
post_layer_norm: true,
add_bias_linear: false,
add_qkv_bias: true,
bias_dropout_fusion: true,
multi_query_attention: true,
multi_query_group_num: 2,
apply_query_key_layer_scaling: true,
attention_softmax_in_fp32: true,
fp32_residual_connection: false,
}
}
}
fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
if bias {
crate::models::with_tracing::linear(in_dim, out_dim, vb)
} else {
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
cache: Tensor,
}
impl RotaryEmbedding {
fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
let rotary_dim = cfg.kv_channels;
let n_elem = rotary_dim / 2;
let inv_freq: Vec<_> = (0..n_elem)
.step_by(2)
.map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
.to_dtype(dtype)?
.reshape((cfg.seq_length, 1))?;
let freqs = t.matmul(&inv_freq)?;
let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
Ok(Self { cache })
}
fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (seqlen, _b, np, _hn) = xs.dims4()?;
let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
let rot_dim = cache.dim(D::Minus2)? * 2;
let (xs, xs_pass) = (
xs.narrow(D::Minus1, 0, rot_dim)?,
xs.narrow(D::Minus1, rot_dim, rot_dim)?,
);
let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
let (xshaped0, xshaped1) = (
xshaped.i((.., .., .., .., 0))?,
xshaped.i((.., .., .., .., 1))?,
);
let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
let xs_out = Tensor::stack(
&[
(xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
(xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
],
D::Minus1,
)?;
let xs_out = xs_out.flatten_from(3)?;
Tensor::cat(&[xs_out, xs_pass], D::Minus1)
}
}
#[derive(Debug, Clone)]
struct CoreAttention {
coeff: Option<f64>,
norm_factor: f64,
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
impl CoreAttention {
fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
let norm_factor = (cfg.kv_channels as f64).sqrt();
let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
let coeff = f64::max(1.0, layer_number as f64);
(norm_factor * coeff, Some(coeff))
} else {
(norm_factor, None)
};
Ok(Self { coeff, norm_factor })
}
fn forward(
&self,
query_layer: &Tensor,
key_layer: &Tensor,
value_layer: &Tensor,
attention_mask: &Option<Tensor>,
) -> Result<Tensor> {
let output_size = (
query_layer.dim(1)?, // b
query_layer.dim(2)?, // np
query_layer.dim(0)?, // sq
key_layer.dim(0)?, // sk
);
let query_layer =
query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
let matmul_result = Tensor::matmul(
&query_layer.transpose(0, 1)?,
&key_layer.transpose(0, 1)?.transpose(1, 2)?,
)?;
let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
let matmul_result = match self.coeff {
None => matmul_result,
Some(coeff) => (matmul_result * coeff)?,
};
let attention_scores = match attention_mask {
Some(mask) => masked_fill(
&matmul_result,
&mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
f32::NEG_INFINITY,
)?,
None => matmul_result,
};
let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;
let output_size = (
value_layer.dim(1)?,
value_layer.dim(2)?,
query_layer.dim(0)?,
value_layer.dim(3)?,
);
let value_layer =
value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
let attention_probs =
attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
let context_layer = context_layer.reshape(output_size)?;
let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
context_layer.flatten_from(D::Minus2)
}
}
#[derive(Debug, Clone)]
struct SelfAttention {
query_key_value: Linear,
core_attention: CoreAttention,
dense: Linear,
multi_query_attention: bool,
num_attention_heads_per_partition: usize,
num_multi_query_groups_per_partition: usize,
hidden_size_per_attention_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
}
impl SelfAttention {
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let projection_size = cfg.kv_channels * cfg.num_attention_heads;
let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
let qkv_hidden_size = if cfg.multi_query_attention {
projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
} else {
3 * projection_size
};
let query_key_value = linear(
cfg.hidden_size,
qkv_hidden_size,
cfg.add_bias_linear || cfg.add_qkv_bias,
vb.pp("query_key_value"),
)?;
let core_attention = CoreAttention::new(layer_number, cfg)?;
let dense = linear(
cfg.hidden_size,
cfg.hidden_size,
cfg.add_bias_linear,
vb.pp("dense"),
)?;
Ok(Self {
query_key_value,
core_attention,
dense,
multi_query_attention: cfg.multi_query_attention,
num_attention_heads_per_partition: cfg.num_attention_heads,
num_multi_query_groups_per_partition: cfg.multi_query_group_num,
hidden_size_per_attention_head: cfg.kv_channels,
kv_cache: None,
})
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Option<Tensor>,
rotary_emb: &RotaryEmbedding,
) -> Result<Tensor> {
let mixed_x_layer = xs.apply(&self.query_key_value)?;
if !self.multi_query_attention {
candle::bail!("only multi_query_attention=true is supported")
}
let hpa = self.hidden_size_per_attention_head;
let query_layer =
mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
let key_layer = mixed_x_layer.narrow(
D::Minus1,
self.num_attention_heads_per_partition * hpa,
self.num_multi_query_groups_per_partition * hpa,
)?;
let value_layer = mixed_x_layer.narrow(
D::Minus1,
self.num_attention_heads_per_partition * hpa
+ self.num_multi_query_groups_per_partition * hpa,
self.num_multi_query_groups_per_partition * hpa,
)?;
let query_layer = query_layer.reshape((
query_layer.dim(0)?,
query_layer.dim(1)?,
self.num_attention_heads_per_partition,
hpa,
))?;
let key_layer = key_layer.reshape((
key_layer.dim(0)?,
key_layer.dim(1)?,
self.num_multi_query_groups_per_partition,
hpa,
))?;
let value_layer = value_layer.reshape((
value_layer.dim(0)?,
value_layer.dim(1)?,
self.num_multi_query_groups_per_partition,
hpa,
))?;
// Rotary embeddings.
let seqlen_offset = match &self.kv_cache {
None => 0,
Some((prev_k, _)) => prev_k.dim(0)?,
};
let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;
// KV cache.
let (key_layer, value_layer) = match &self.kv_cache {
None => (key_layer, value_layer),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
(k, v)
}
};
self.kv_cache = Some((key_layer.clone(), value_layer.clone()));
// Repeat KV.
let ratio =
self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
let key_layer = {
let (d0, d1, d2, d3) = key_layer.dims4()?;
key_layer
.unsqueeze(D::Minus2)?
.expand((d0, d1, d2, ratio, d3))?
.reshape((
d0,
d1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))?
};
let value_layer = {
let (d0, d1, d2, d3) = value_layer.dims4()?;
value_layer
.unsqueeze(D::Minus2)?
.expand((d0, d1, d2, ratio, d3))?
.reshape((
d0,
d1,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
))?
};
let context_layer =
self.core_attention
.forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
let output = context_layer.apply(&self.dense)?;
Ok(output)
}
}
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
struct MLP {
dense_h_to_4h: Linear,
dense_4h_to_h: Linear,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dense_h_to_4h = linear(
cfg.hidden_size,
cfg.ffn_hidden_size * 2,
cfg.add_bias_linear,
vb.pp("dense_h_to_4h"),
)?;
let dense_4h_to_h = linear(
cfg.ffn_hidden_size,
cfg.hidden_size,
cfg.add_bias_linear,
vb.pp("dense_4h_to_h"),
)?;
Ok(Self {
dense_4h_to_h,
dense_h_to_4h,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.dense_h_to_4h)?
.apply(&candle_nn::Activation::Swiglu)?
.apply(&self.dense_4h_to_h)
}
}
#[derive(Debug, Clone)]
struct Block {
input_layernorm: candle_nn::LayerNorm,
self_attention: SelfAttention,
post_attention_layernorm: candle_nn::LayerNorm,
mlp: MLP,
apply_residual_connection_post_layernorm: bool,
}
impl Block {
fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let input_layernorm = if cfg.rmsnorm {
candle_nn::rms_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("input_layernorm"),
)?
.into_inner()
} else {
candle_nn::layer_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("input_layernorm"),
)?
};
let post_attention_layernorm = if cfg.rmsnorm {
candle_nn::rms_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("post_attention_layernorm"),
)?
.into_inner()
} else {
candle_nn::layer_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("post_attention_layernorm"),
)?
};
let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
Ok(Self {
input_layernorm,
self_attention,
post_attention_layernorm,
mlp,
apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
})
}
fn reset_kv_cache(&mut self) {
self.self_attention.reset_kv_cache()
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: &Option<Tensor>,
rotary_emb: &RotaryEmbedding,
) -> Result<Tensor> {
let layernorm_output = xs.apply(&self.input_layernorm)?;
let attention_output =
self.self_attention
.forward(&layernorm_output, attention_mask, rotary_emb)?;
let residual = if self.apply_residual_connection_post_layernorm {
&layernorm_output
} else {
xs
};
let layernorm_input = (residual + attention_output)?;
let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
let mlp_output = layernorm_output.apply(&self.mlp)?;
let residual = if self.apply_residual_connection_post_layernorm {
&layernorm_output
} else {
&layernorm_input
};
mlp_output + residual
}
}
#[derive(Debug, Clone)]
struct Transformer {
layers: Vec<Block>,
final_layernorm: Option<candle_nn::LayerNorm>,
rotary_emb: RotaryEmbedding,
}
impl Transformer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_l = vb.pp("layers");
let mut layers = Vec::with_capacity(cfg.num_layers);
for layer_index in 0..cfg.num_layers {
let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
layers.push(block)
}
let final_layernorm = if cfg.post_layer_norm {
let ln = if cfg.rmsnorm {
candle_nn::rms_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("final_layernorm"),
)?
.into_inner()
} else {
candle_nn::layer_norm(
cfg.hidden_size,
cfg.layernorm_epsilon,
vb.pp("final_layernorm"),
)?
};
Some(ln)
} else {
None
};
let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
Ok(Self {
layers,
final_layernorm,
rotary_emb,
})
}
fn reset_kv_cache(&mut self) {
for block in self.layers.iter_mut() {
block.reset_kv_cache()
}
}
fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for block in self.layers.iter_mut() {
xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
}
match self.final_layernorm.as_ref() {
None => Ok(xs),
Some(ln) => xs.apply(ln),
}
}
}
#[derive(Debug, Clone)]
struct Embedding {
word_embeddings: candle_nn::Embedding,
fp32_residual_connection: bool,
}
impl Embedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let word_embeddings = candle_nn::embedding(
cfg.padded_vocab_size,
cfg.hidden_size,
vb.pp("word_embeddings"),
)?;
Ok(Self {
word_embeddings,
fp32_residual_connection: cfg.fp32_residual_connection,
})
}
}
impl Module for Embedding {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
if self.fp32_residual_connection {
xs.to_dtype(candle::DType::F32)
} else {
xs.contiguous()
}
}
}
#[derive(Debug, Clone)]
pub struct Model {
embedding: Embedding,
encoder: Transformer,
output_layer: Linear,
}
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device)
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("transformer");
let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
let output_layer = linear(
cfg.hidden_size,
cfg.padded_vocab_size,
false,
vb.pp("output_layer"),
)?;
Ok(Self {
embedding,
encoder,
output_layer,
})
}
pub fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache()
}
pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let (_b_size, seq_len) = xs.dims2()?;
let input_embeds = xs.apply(&self.embedding)?;
let attention_mask = if seq_len <= 1 {
None
} else {
Some(get_mask(seq_len, xs.device())?)
};
let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
Ok(lm_logits)
}
}

View File

@ -1,201 +0,0 @@
//! ConvNeXt implementation.
//!
//! See "A ConvNet for the 2020s" Liu et al. 2022
//! <https://arxiv.org/abs/2201.03545>
//! Original code: https://github.com/facebookresearch/ConvNeXt/
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py
use candle::{Result, D};
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};
#[derive(Clone)]
pub struct Config {
blocks: [usize; 4],
channels: [usize; 4],
}
impl Config {
pub fn tiny() -> Self {
Self {
blocks: [3, 3, 9, 3],
channels: [96, 192, 384, 768],
}
}
pub fn small() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [96, 192, 384, 768],
}
}
pub fn base() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [128, 256, 512, 1024],
}
}
pub fn large() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [192, 384, 768, 1536],
}
}
pub fn xlarge() -> Self {
Self {
blocks: [3, 3, 27, 3],
channels: [256, 512, 1024, 2048],
}
}
}
// Initial downsampling via a patchify layer.
fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride: 4,
..Default::default()
};
let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
Ok(Func::new(move |xs| {
// The layer norm works with channels-last format.
let xs = xs
.apply(&patchify)?
.permute((0, 2, 3, 1))?
.apply(&norm)?
.permute((0, 3, 1, 2))?;
Ok(xs)
}))
}
// Downsampling applied after the stages.
fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride: 2,
..Default::default()
};
let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
Ok(Func::new(move |xs| {
let xs = xs
.permute((0, 2, 3, 1))?
.apply(&norm)?
.permute((0, 3, 1, 2))?
.apply(&conv)?;
Ok(xs)
}))
}
// MLP equivalent of pointwise convolutions.
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
Ok(xs)
}))
}
// A block consisting of a depthwise convolution, a MLP and layer scaling.
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
groups: dim,
padding: 3,
..Default::default()
};
let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;
let gamma = vb.get(dim, "gamma")?;
let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs
.apply(&conv_dw)?
.permute((0, 2, 3, 1))?
.apply(&norm)?
.apply(&mlp)?
.broadcast_mul(&gamma)?
.permute((0, 3, 1, 2))?;
xs + residual
}))
}
// Each stage contains blocks and a downsampling layer for the previous stage.
fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nblocks = cfg.blocks[stage_idx];
let mut blocks = Vec::with_capacity(nblocks);
let dim = cfg.channels[stage_idx];
if stage_idx > 0 {
blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
}
for block_idx in 0..nblocks {
blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
}
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
for block in blocks.iter() {
xs = xs.apply(block)?
}
Ok(xs)
}))
}
fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
let linear = linear(outputs, nclasses, vb.pp("fc"))?;
Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
}
// Build a convnext model for a given configuration.
fn convnext_model(
config: &Config,
nclasses: Option<usize>,
vb: VarBuilder,
) -> Result<Func<'static>> {
let head = match nclasses {
None => None,
Some(nclasses) => {
let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
Some(head)
}
};
let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
let vb = vb.pp("stages");
let stage1 = convnext_stage(config, 0, vb.pp(0))?;
let stage2 = convnext_stage(config, 1, vb.pp(1))?;
let stage3 = convnext_stage(config, 2, vb.pp(2))?;
let stage4 = convnext_stage(config, 3, vb.pp(3))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.apply(&stage4)?
.mean(D::Minus2)?
.mean(D::Minus1)?;
match &head {
None => Ok(xs),
Some(head) => xs.apply(head),
}
}))
}
pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
convnext_model(cfg, Some(nclasses), vb)
}
pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
convnext_model(cfg, None, vb)
}

View File

@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};
pub const MAX_SEQ_LEN: usize = 4096; pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Debug, Clone, Deserialize)] #[derive(Deserialize)]
pub struct LlamaConfig { pub struct LlamaConfig {
pub hidden_size: usize, pub hidden_size: usize,
pub intermediate_size: usize, pub intermediate_size: usize,
@ -40,7 +40,6 @@ impl LlamaConfig {
} }
} }
#[derive(Debug, Clone)]
pub struct Config { pub struct Config {
pub hidden_size: usize, pub hidden_size: usize,
pub intermediate_size: usize, pub intermediate_size: usize,
@ -83,7 +82,7 @@ impl Config {
} }
} }
#[derive(Debug, Clone)] #[derive(Clone)]
pub struct Cache { pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>, masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool, pub use_kv_cache: bool,
@ -137,7 +136,6 @@ impl Cache {
} }
} }
#[derive(Debug, Clone)]
struct RmsNorm { struct RmsNorm {
inner: candle_nn::RmsNorm, inner: candle_nn::RmsNorm,
span: tracing::Span, span: tracing::Span,
@ -156,7 +154,6 @@ impl RmsNorm {
} }
} }
#[derive(Debug, Clone)]
struct CausalSelfAttention { struct CausalSelfAttention {
q_proj: Linear, q_proj: Linear,
k_proj: Linear, k_proj: Linear,
@ -317,7 +314,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
Ok(m) Ok(m)
} }
#[derive(Debug, Clone)]
struct Mlp { struct Mlp {
c_fc1: Linear, c_fc1: Linear,
c_fc2: Linear, c_fc2: Linear,
@ -348,7 +344,6 @@ impl Mlp {
} }
} }
#[derive(Debug, Clone)]
struct Block { struct Block {
rms_1: RmsNorm, rms_1: RmsNorm,
attn: CausalSelfAttention, attn: CausalSelfAttention,
@ -388,7 +383,6 @@ impl Block {
} }
} }
#[derive(Debug, Clone)]
pub struct Llama { pub struct Llama {
wte: Embedding, wte: Embedding,
blocks: Vec<Block>, blocks: Vec<Block>,

View File

@ -1,211 +0,0 @@
#![allow(unused)]
/// A fast implementation of mamba for inference only.
/// This is based on: https://github.com/LaurentMazare/mamba.rs
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{RmsNorm, VarBuilder};
const D_CONV: usize = 4;
const D_STATE: usize = 16;
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
d_model: usize,
n_layer: usize,
vocab_size: usize,
pad_vocab_size_multiple: usize,
}
impl Config {
fn vocab_size(&self) -> usize {
let pad = self.pad_vocab_size_multiple;
(self.vocab_size + pad - 1) / pad * pad
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
}
fn d_inner(&self) -> usize {
self.d_model * 2
}
}
pub struct State {
hs: Vec<Tensor>,
prev_xs: Vec<[Tensor; D_CONV]>,
pos: usize,
}
impl State {
pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
let mut hs = Vec::with_capacity(cfg.n_layer);
let mut prev_xs = Vec::with_capacity(cfg.n_layer);
for _i in 0..cfg.n_layer {
let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
hs.push(h);
prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
}
Ok(Self {
hs,
prev_xs,
pos: 0,
})
}
}
#[derive(Clone, Debug)]
pub struct MambaBlock {
in_proj: Linear,
conv1d_bias: Tensor,
conv1d_weights: [Tensor; D_CONV],
x_proj: Linear,
dt_proj: Linear,
a_log: Tensor,
d: Tensor,
out_proj: Linear,
dt_rank: usize,
layer_index: usize,
d_inner: usize,
}
impl MambaBlock {
pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_inner = cfg.d_inner();
let dt_rank = cfg.dt_rank();
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
let x_proj = linear_no_bias(d_inner, dt_rank + D_STATE * 2, vb.pp("x_proj"))?;
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
let a_log = vb.get((d_inner, D_STATE), "A_log")?;
let d = vb.get(d_inner, "D")?;
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
let conv1d_bias = vb.get(d_inner, "conv1d.bias")?;
let conv1d_weight = vb.get((d_inner, 1, D_CONV), "conv1d.weight")?;
let conv1d_weights = [
conv1d_weight.i((.., 0, 0))?,
conv1d_weight.i((.., 0, 1))?,
conv1d_weight.i((.., 0, 2))?,
conv1d_weight.i((.., 0, 3))?,
];
Ok(Self {
in_proj,
conv1d_bias,
conv1d_weights,
x_proj,
dt_proj,
a_log,
d,
out_proj,
dt_rank,
layer_index,
d_inner,
})
}
pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
let (b_sz, _dim) = xs.dims2()?;
let li = self.layer_index;
let mut xs = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
let proj_for_silu = xs.remove(1);
state.prev_xs[li][state.pos % D_CONV] = xs.remove(0);
let mut proj_for_conv = self.conv1d_bias.broadcast_as((b_sz, self.d_inner))?;
for d_c in 0..D_CONV {
proj_for_conv = (proj_for_conv
+ self.conv1d_weights[d_c]
.broadcast_mul(&state.prev_xs[li][(d_c + 1 + state.pos) % D_CONV])?)?;
}
let proj_for_conv = candle_nn::ops::silu(&proj_for_conv)?;
// SSM + Selection, we're doing inference here so only need the last step of
// the sequence.
// Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf
let x_proj = self.x_proj.forward(&proj_for_conv)?;
let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;
let delta = delta.apply(&self.dt_proj)?;
// softplus
let delta = (delta.exp()? + 1.)?.log()?;
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
let d = self.d.to_dtype(candle::DType::F32)?;
// Selective scan part
// Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
let delta = delta
.unsqueeze(D::Minus1)?
.broadcast_as((b_sz, self.d_inner, D_STATE))?;
let a = a.broadcast_as((b_sz, self.d_inner, D_STATE))?;
let b = b.broadcast_as((b_sz, self.d_inner, D_STATE))?;
let proj_for_conv_b =
proj_for_conv
.unsqueeze(D::Minus1)?
.broadcast_as((b_sz, self.d_inner, D_STATE))?;
state.hs[li] = ((&state.hs[li] * (&delta * &a)?.exp()?)? + &delta * &b * &proj_for_conv_b)?;
let ss = (state.hs[li]
.matmul(&c.unsqueeze(D::Minus1)?)?
.squeeze(D::Minus1)?
+ proj_for_conv.broadcast_mul(&d)?)?;
let ys = (ss * candle_nn::ops::silu(&proj_for_silu))?;
ys.apply(&self.out_proj)
}
}
#[derive(Clone, Debug)]
pub struct ResidualBlock {
mixer: MambaBlock,
norm: RmsNorm,
}
impl ResidualBlock {
pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?;
Ok(Self { mixer, norm })
}
fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
self.mixer.forward(&xs.apply(&self.norm)?, state)? + xs
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
#[derive(Clone, Debug)]
pub struct Model {
embedding: candle_nn::Embedding,
layers: Vec<ResidualBlock>,
norm_f: RmsNorm,
lm_head: Linear,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
Ok(Self {
embedding,
layers,
norm_f,
lm_head,
})
}
pub fn forward(&self, input_ids: &Tensor, state: &mut State) -> Result<Tensor> {
let _b_size = input_ids.dims1()?;
let mut xs = self.embedding.forward(input_ids)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs, state)?
}
state.pos += 1;
xs.apply(&self.norm_f)?.apply(&self.lm_head)
}
}

View File

@ -8,7 +8,7 @@ use serde::Deserialize;
const MAX_SEQ_LEN: usize = 4096; const MAX_SEQ_LEN: usize = 4096;
// https://huggingface.co/microsoft/phi-1_5/blob/d38e6f954ec29b96fe2cf033937dad64e279b5d9/configuration_mixformer_sequential.py // https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
#[derive(Debug, Clone, PartialEq, Deserialize)] #[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config { pub struct Config {
pub(crate) vocab_size: usize, pub(crate) vocab_size: usize,

View File

@ -1,333 +0,0 @@
//! MobileOne inference implementation based on timm and candle-repvgg
//!
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
//! https://arxiv.org/abs/2206.04040
use candle::{DType, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
Func, VarBuilder,
};
struct StageConfig {
blocks: usize,
channels: usize,
}
// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
// by concatenating the 5th stage (starts with stride 1) to the previous one.
const STAGES: [StageConfig; 5] = [
StageConfig {
blocks: 1,
channels: 64,
},
StageConfig {
blocks: 2,
channels: 64,
},
StageConfig {
blocks: 8,
channels: 128,
},
StageConfig {
blocks: 10,
channels: 256,
},
StageConfig {
blocks: 1,
channels: 512,
},
];
#[derive(Clone)]
pub struct Config {
/// overparameterization factor
k: usize,
/// per-stage channel number multipliers
alphas: [f32; 5],
}
impl Config {
pub fn s0() -> Self {
Self {
k: 4,
alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
}
}
pub fn s1() -> Self {
Self {
k: 1,
alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
}
}
pub fn s2() -> Self {
Self {
k: 1,
alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
}
}
pub fn s3() -> Self {
Self {
k: 1,
alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
}
}
pub fn s4() -> Self {
Self {
k: 1,
alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
}
}
}
// SE blocks are used in the last stages of the s4 variant.
fn squeeze_and_excitation(
in_channels: usize,
squeeze_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
..Default::default()
};
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
Ok(Func::new(move |xs| {
let residual = xs;
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
residual.broadcast_mul(&xs)
}))
}
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
// based on the _fuse_bn_tensor method in timm
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
let (gamma, beta) = bn.weight_and_bias().unwrap();
let mu = bn.running_mean();
let sigma = (bn.running_var() + bn.eps())?.sqrt();
let gps = (gamma / sigma)?;
let bias = (beta - mu * &gps)?;
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
Ok((weights, bias))
}
// A mobileone block has a different training time and inference time architecture.
// The latter is a simple and efficient equivalent transformation of the former
// realized by a structural reparameterization technique, where convolutions
// along with identity branches and batchnorm layers are fused into a single convolution.
#[allow(clippy::too_many_arguments)]
fn mobileone_block(
has_identity: bool,
k: usize,
dim: usize,
stride: usize,
padding: usize,
groups: usize,
kernel: usize,
in_channels: usize,
out_channels: usize,
vb: VarBuilder,
) -> Result<Func<'static>> {
let conv2d_cfg = Conv2dConfig {
stride,
padding,
groups,
..Default::default()
};
let mut w = Tensor::zeros(
(out_channels, in_channels / groups, kernel, kernel),
DType::F32,
vb.device(),
)?;
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
// k is the training-time overparameterization factor, larger than 1 only in the s0 variant
for i in 0..k {
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
let conv_kxk = conv2d_no_bias(
in_channels,
out_channels,
kernel,
conv2d_cfg,
vb.pp(format!("conv_kxk.{i}.conv")),
)?;
let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
w = (w + wk)?;
b = (b + bk)?;
}
if kernel > 1 {
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
let conv_scale = conv2d_no_bias(
in_channels,
out_channels,
1,
conv2d_cfg,
vb.pp("conv_scale.conv"),
)?;
let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
// resize to 3x3
ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
w = (w + ws)?;
b = (b + bs)?;
}
// Use SE blocks if present (last layers of the s4 variant)
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
// read and reparameterize the identity bn into wi and bi
if has_identity {
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
let id = in_channels / groups;
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
for i in 0..in_channels {
if kernel > 1 {
weights[i * kernel * kernel + 4] = 1.0;
} else {
weights[i * (id + 1)] = 1.0;
}
}
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
w = (w + wi)?;
b = (b + bi)?;
}
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
Ok(Func::new(move |xs| {
let mut xs = xs.apply(&reparam_conv)?;
if let Ok(f) = &se {
xs = xs.apply(f)?;
}
xs = xs.relu()?;
Ok(xs)
}))
}
// Get the number of output channels per stage taking into account the multipliers
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
let channels = STAGES[stage].channels as f32;
let alpha = cfg.alphas[stage];
match stage {
0 => std::cmp::min(64, (channels * alpha) as usize),
_ => (channels * alpha) as usize,
}
}
// Each stage is made of blocks. The first layer always downsamples with stride 2.
// All but the first block have a residual connection.
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
let nblocks = STAGES[idx].blocks;
let mut blocks = Vec::with_capacity(nblocks);
let mut in_channels = output_channels_per_stage(cfg, idx - 1);
for block_idx in 0..nblocks {
let out_channels = output_channels_per_stage(cfg, idx);
let (has_identity, stride) = if block_idx == 0 {
(false, 2)
} else {
(true, 1)
};
// depthwise convolution layer
blocks.push(mobileone_block(
has_identity,
cfg.k,
in_channels,
stride,
1,
in_channels,
3,
in_channels,
in_channels,
vb.pp(block_idx * 2),
)?);
// pointwise convolution layer
blocks.push(mobileone_block(
has_identity,
cfg.k,
out_channels,
1, // stride
0, // padding
1, // groups
1, // kernel
in_channels,
out_channels,
vb.pp(block_idx * 2 + 1),
)?);
in_channels = out_channels;
}
Ok(Func::new(move |xs| {
let mut xs = xs.clone();
for block in blocks.iter() {
xs = xs.apply(block)?
}
Ok(xs)
}))
}
// Build a mobileone model for a given configuration.
fn mobileone_model(
config: &Config,
nclasses: Option<usize>,
vb: VarBuilder,
) -> Result<Func<'static>> {
let cls = match nclasses {
None => None,
Some(nclasses) => {
let outputs = output_channels_per_stage(config, 4);
let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
Some(linear)
}
};
let stem_dim = output_channels_per_stage(config, 0);
let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
let vb = vb.pp("stages");
let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
let stage4 = mobileone_stage(config, 4, vb.pp(3))?;
Ok(Func::new(move |xs| {
let xs = xs
.apply(&stem)?
.apply(&stage1)?
.apply(&stage2)?
.apply(&stage3)?
.apply(&stage4)?
.mean(D::Minus2)?
.mean(D::Minus1)?;
match &cls {
None => Ok(xs),
Some(cls) => xs.apply(cls),
}
}))
}
pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
mobileone_model(cfg, Some(nclasses), vb)
}
pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
mobileone_model(cfg, None, vb)
}

View File

@ -2,9 +2,7 @@ pub mod bert;
pub mod bigcode; pub mod bigcode;
pub mod blip; pub mod blip;
pub mod blip_text; pub mod blip_text;
pub mod chatglm;
pub mod convmixer; pub mod convmixer;
pub mod convnext;
pub mod dinov2; pub mod dinov2;
pub mod distilbert; pub mod distilbert;
pub mod efficientnet; pub mod efficientnet;
@ -13,12 +11,10 @@ pub mod jina_bert;
pub mod llama; pub mod llama;
pub mod llama2_c; pub mod llama2_c;
pub mod llama2_c_weights; pub mod llama2_c_weights;
pub mod mamba;
pub mod marian; pub mod marian;
pub mod mistral; pub mod mistral;
pub mod mixformer; pub mod mixformer;
pub mod mixtral; pub mod mixtral;
pub mod mobileone;
pub mod mpt; pub mod mpt;
pub mod persimmon; pub mod persimmon;
pub mod phi; pub mod phi;
@ -31,7 +27,6 @@ pub mod quantized_mixformer;
pub mod quantized_mpt; pub mod quantized_mpt;
pub mod quantized_stable_lm; pub mod quantized_stable_lm;
pub mod quantized_t5; pub mod quantized_t5;
pub mod qwen2;
pub mod repvgg; pub mod repvgg;
pub mod resnet; pub mod resnet;
pub mod segment_anything; pub mod segment_anything;

View File

@ -16,7 +16,7 @@ struct RmsNorm {
impl RmsNorm { impl RmsNorm {
fn new(scale: QTensor, eps: f32) -> Result<Self> { fn new(scale: QTensor, eps: f32) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm"); let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let scale = scale.dequantize(&scale.device())?; let scale = scale.dequantize(&Device::Cpu)?;
let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64); let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
Ok(Self { inner, span }) Ok(Self { inner, span })
} }
@ -275,17 +275,13 @@ pub struct ModelWeights {
span_output: tracing::Span, span_output: tracing::Span,
} }
fn precomput_freqs_cis( fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim) let theta: Vec<_> = (0..head_dim)
.step_by(2) .step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect(); .collect();
let theta = Tensor::new(theta.as_slice(), device)?; let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
.to_dtype(DType::F32)? .to_dtype(DType::F32)?
.reshape((MAX_SEQ_LEN, 1))? .reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?; .matmul(&theta.reshape((1, theta.elem_count()))?)?;
@ -296,10 +292,11 @@ fn precomput_freqs_cis(
impl ModelWeights { impl ModelWeights {
pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> { pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
let cpu = &Device::Cpu;
let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize; let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?; let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
let tok_embeddings = ct.remove("tok_embeddings.weight")?; let tok_embeddings = ct.remove("tok_embeddings.weight")?;
let tok_embeddings = tok_embeddings.dequantize(&ct.device)?; let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?; let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
let output = ct.remove("output.weight")?; let output = ct.remove("output.weight")?;
let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize); let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@ -361,6 +358,7 @@ impl ModelWeights {
reader: &mut R, reader: &mut R,
device: &Device, device: &Device,
) -> Result<Self> { ) -> Result<Self> {
let cpu = &Device::Cpu;
let md_get = |s: &str| match ct.metadata.get(s) { let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"), None => candle::bail!("cannot find {s} in metadata"),
Some(v) => Ok(v), Some(v) => Ok(v),
@ -384,10 +382,10 @@ impl ModelWeights {
let rope_freq_base = md_get("llama.rope.freq_base") let rope_freq_base = md_get("llama.rope.freq_base")
.and_then(|m| m.to_f32()) .and_then(|m| m.to_f32())
.unwrap_or(10000f32); .unwrap_or(10000f32);
let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?; let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?; let tok_embeddings = tok_embeddings.dequantize(cpu)?;
let norm = RmsNorm::new( let norm = RmsNorm::new(
ct.tensor(reader, "output_norm.weight", device)?, ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps, rms_norm_eps,
@ -474,14 +472,14 @@ impl ModelWeights {
}) })
} }
fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> { fn mask(&mut self, t: usize) -> Result<Tensor> {
if let Some(mask) = self.masks.get(&t) { if let Some(mask) = self.masks.get(&t) {
Ok(mask.clone()) Ok(mask.clone())
} else { } else {
let mask: Vec<_> = (0..t) let mask: Vec<_> = (0..t)
.flat_map(|i| (0..t).map(move |j| u8::from(j > i))) .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
.collect(); .collect();
let mask = Tensor::from_slice(&mask, (t, t), device)?; let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
self.masks.insert(t, mask.clone()); self.masks.insert(t, mask.clone());
Ok(mask) Ok(mask)
} }
@ -489,7 +487,7 @@ impl ModelWeights {
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> { pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?; let (_b_sz, seq_len) = x.dims2()?;
let mask = self.mask(seq_len, x.device())?; let mask = self.mask(seq_len)?;
let _enter = self.span.enter(); let _enter = self.span.enter();
let mut layer_in = self.tok_embeddings.forward(x)?; let mut layer_in = self.tok_embeddings.forward(x)?;
for layer in self.layers.iter_mut() { for layer in self.layers.iter_mut() {

View File

@ -1,4 +1,4 @@
use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear}; use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
pub use crate::quantized_var_builder::VarBuilder; pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, Device, Module, Result, Tensor, D}; use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm}; use candle_nn::{Activation, LayerNorm};
@ -67,14 +67,9 @@ impl Attention {
let head_dim = cfg.head_dim(); let head_dim = cfg.head_dim();
let num_heads = cfg.num_attention_heads; let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads; let num_kv_heads = cfg.num_key_value_heads;
let linear_layer = if cfg.use_qkv_bias { let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
linear let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
} else { let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
linear_no_bias
};
let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self { Ok(Self {
q_proj, q_proj,

View File

@ -1,377 +0,0 @@
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub max_position_embeddings: usize,
pub sliding_window: usize,
pub max_window_layers: usize,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub rms_norm_eps: f64,
pub use_sliding_window: bool,
pub hidden_act: Activation,
}
#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
}
impl RmsNorm {
fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
}
impl MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
Ok(Self {
gate_proj,
up_proj,
down_proj,
act_fn: cfg.hidden_act,
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?.contiguous()?;
let value_states = self.repeat_kv(value_states)?.contiguous()?;
let attn_output = {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + self.sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}

View File

@ -1,11 +1,10 @@
use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use crate::models::with_tracing::{linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D}; use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, LayerNorm, VarBuilder}; use candle_nn::{Activation, LayerNorm, VarBuilder};
use serde::Deserialize;
use std::sync::Arc; use std::sync::Arc;
// https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
#[derive(Debug, Clone, PartialEq, Deserialize)] #[derive(Debug, Clone, PartialEq)]
pub struct Config { pub struct Config {
pub(crate) vocab_size: usize, pub(crate) vocab_size: usize,
pub(crate) intermediate_size: usize, pub(crate) intermediate_size: usize,
@ -19,10 +18,7 @@ pub struct Config {
pub(crate) max_position_embeddings: usize, pub(crate) max_position_embeddings: usize,
pub(crate) norm_eps: f64, pub(crate) norm_eps: f64,
pub(crate) use_cache: bool, pub(crate) use_cache: bool,
#[serde(default)] pub(crate) use_flash_attn: bool,
pub(crate) use_qkv_bias: bool, // Used in StableLM-2
#[serde(default)]
pub(crate) use_flash_attn: bool, // Not in config.json
} }
impl Config { impl Config {
@ -39,7 +35,6 @@ impl Config {
rope_theta: 10_000., rope_theta: 10_000.,
max_position_embeddings: 4096, max_position_embeddings: 4096,
norm_eps: 1e-5, norm_eps: 1e-5,
use_qkv_bias: false,
use_cache: true, use_cache: true,
use_flash_attn, use_flash_attn,
} }
@ -56,10 +51,6 @@ impl Config {
pub fn num_kv_groups(&self) -> usize { pub fn num_kv_groups(&self) -> usize {
self.num_attention_heads / self.num_key_value_heads self.num_attention_heads / self.num_key_value_heads
} }
pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
self.use_flash_attn = use_flash_attn
}
} }
#[derive(Debug)] #[derive(Debug)]
@ -188,15 +179,9 @@ impl Attention {
let head_dim = cfg.head_dim(); let head_dim = cfg.head_dim();
let num_heads = cfg.num_attention_heads; let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads; let num_kv_heads = cfg.num_key_value_heads;
let linear_layer = if cfg.use_qkv_bias { let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
linear let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
} else { let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
linear_no_bias
};
let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self { Ok(Self {
q_proj, q_proj,

View File

@ -1,21 +1,15 @@
use crate::models::vit::{Config, Embeddings, Encoder}; use crate::models::vit::{Config, Embeddings, Encoder};
use candle::{DType, Result, Tensor}; use candle::{Result, Tensor};
use candle_nn::{ use candle_nn::{
embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder, embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
}; };
use serde::Deserialize;
fn default_tie_word_embeddings() -> bool { #[derive(Debug, Clone, PartialEq, Deserialize)]
true
}
fn default_use_learned_position_embeddings() -> bool {
true
}
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct TrOCRConfig { pub struct TrOCRConfig {
pub vocab_size: usize, pub vocab_size: usize,
pub d_model: usize, pub d_model: usize,
pub cross_attention_hidden_size: usize, pub hidden_size: usize,
pub decoder_layers: usize, pub decoder_layers: usize,
pub decoder_attention_heads: usize, pub decoder_attention_heads: usize,
pub decoder_ffn_dim: usize, pub decoder_ffn_dim: usize,
@ -29,14 +23,13 @@ pub struct TrOCRConfig {
pub decoder_layerdrop: f64, pub decoder_layerdrop: f64,
pub use_cache: bool, pub use_cache: bool,
pub scale_embedding: bool, pub scale_embedding: bool,
pub use_learned_position_embeddings: bool,
pub layernorm_embedding: bool,
pub pad_token_id: usize, pub pad_token_id: usize,
pub bos_token_id: usize, pub bos_token_id: usize,
pub eos_token_id: u32, pub eos_token_id: u32,
pub num_attention_heads: usize,
pub decoder_vocab_size: Option<usize>, pub decoder_vocab_size: Option<usize>,
#[serde(default = "default_use_learned_position_embeddings")]
pub use_learned_position_embeddings: bool,
#[serde(default = "default_tie_word_embeddings")]
pub tie_word_embeddings: bool,
} }
impl Default for TrOCRConfig { impl Default for TrOCRConfig {
@ -44,7 +37,7 @@ impl Default for TrOCRConfig {
Self { Self {
vocab_size: 50265, vocab_size: 50265,
d_model: 1024, d_model: 1024,
cross_attention_hidden_size: 768, hidden_size: 768,
decoder_layers: 12, decoder_layers: 12,
decoder_attention_heads: 16, decoder_attention_heads: 16,
decoder_ffn_dim: 4096, decoder_ffn_dim: 4096,
@ -58,12 +51,13 @@ impl Default for TrOCRConfig {
decoder_layerdrop: 0.0, decoder_layerdrop: 0.0,
use_cache: true, use_cache: true,
scale_embedding: false, scale_embedding: false,
use_learned_position_embeddings: true,
layernorm_embedding: true,
pad_token_id: 1, pad_token_id: 1,
bos_token_id: 0, bos_token_id: 0,
eos_token_id: 2, eos_token_id: 2,
num_attention_heads: 12,
decoder_vocab_size: Some(50265), decoder_vocab_size: Some(50265),
use_learned_position_embeddings: true,
tie_word_embeddings: true,
} }
} }
} }
@ -84,49 +78,17 @@ impl TrOCRLearnedPositionalEmbedding {
Ok(Self { offset, weights }) Ok(Self { offset, weights })
} }
fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
// https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
let embedding_dim = cfg.d_model;
let half_dim = embedding_dim / 2;
let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
let dev = vb.device();
let inv_freq: Vec<_> = (0..half_dim)
.map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
let t = Tensor::arange(0u32, num_positions as u32, dev)?
.to_dtype(DType::F32)?
.reshape((num_positions, 1))?;
let freqs = t.matmul(&inv_freq)?;
let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
let emb = Tensor::cat(
&[
emb.narrow(0, 0, cfg.pad_token_id)?,
Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
],
0,
)?
.contiguous()?;
let emb = Embedding::new(emb, embedding_dim);
Ok(Self {
offset: cfg.pad_token_id + 1,
weights: emb,
})
}
fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> { fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
let (b_sz, seq_len) = input_ids.dims2()?; let (b_sz, seq_len) = input_ids.dims2()?;
let positions = Tensor::arange( let mut positions = Tensor::arange(
past_key_values_length, past_key_values_length,
seq_len as u32 + past_key_values_length, seq_len as u32 + past_key_values_length,
input_ids.device(), input_ids.device(),
)? )?
.expand((b_sz, seq_len))?; .expand((b_sz, seq_len))?;
let positions = positions =
positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?; positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
self.weights.forward(&positions) self.weights.forward(&positions)
} }
@ -259,17 +221,19 @@ impl TrOCRDecoderLayer {
let encoder_attn = TrOCRAttention::load( let encoder_attn = TrOCRAttention::load(
vb.pp("encoder_attn"), vb.pp("encoder_attn"),
cfg, cfg,
Some(cfg.cross_attention_hidden_size), Some(cfg.hidden_size),
Some(cfg.cross_attention_hidden_size), Some(cfg.hidden_size),
)?; )?;
let encoder_attn_layer_norm = let encoder_attn_layer_norm =
layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?; layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?; let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?; let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?; let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
let activation_fn = candle_nn::Activation::Gelu;
Ok(Self { Ok(Self {
self_attn, self_attn,
activation_fn: cfg.activation_function, activation_fn,
self_attn_layer_norm, self_attn_layer_norm,
encoder_attn, encoder_attn,
encoder_attn_layer_norm, encoder_attn_layer_norm,
@ -330,11 +294,7 @@ impl TrOCRDecoder {
let vb = vb.pp("decoder.model.decoder"); let vb = vb.pp("decoder.model.decoder");
let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?; let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
let embed_positions = if cfg.use_learned_position_embeddings { let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
} else {
TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
};
let mut layers = Vec::with_capacity(cfg.decoder_layers); let mut layers = Vec::with_capacity(cfg.decoder_layers);
let vb_l = vb.pp("layers"); let vb_l = vb.pp("layers");
for idx in 0..cfg.decoder_layers { for idx in 0..cfg.decoder_layers {
@ -423,15 +383,8 @@ pub struct TrOCRForCausalLM {
impl TrOCRForCausalLM { impl TrOCRForCausalLM {
pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> { pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?; let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
let output_projection = if decoder_cfg.tie_word_embeddings { let output_projection =
candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None) candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
} else {
candle_nn::linear_no_bias(
decoder_cfg.d_model,
decoder_cfg.vocab_size,
vb.pp("decoder.output_projection"),
)?
};
Ok(Self { Ok(Self {
decoder, decoder,
output_projection, output_projection,

Some files were not shown because too many files have changed in this diff Show More