mirror of https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00

Compare commits: bf16_metal ... metal5 (11 commits). In the hunks below, "-" lines are from the base branch (bf16_metal) and "+" lines are from the head branch (metal5).

Commits:
67d93b4f42
c35d7d50db
9694671bbf
3dbf65ef20
b2db5adf82
9ef040338d
3aefc709c7
c8c603ce96
61ad8d91cc
2cd1e59c9e
9c4b4f0da0
.github/workflows/ci_cuda.yaml (74 changed lines, vendored)

@@ -5,15 +5,49 @@ on:
   pull_request:
 
 jobs:
+  start-runner:
+    name: Start self-hosted EC2 runner
+    runs-on: ubuntu-latest
+    # Don't run on forks, they won't have access to secrets anyway.
+    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
+    env:
+      AWS_REGION: us-east-1
+      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      EC2_INSTANCE_TYPE: g5.xlarge
+      EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
+      EC2_SECURITY_GROUP: sg-030175c435ac141d6
+    outputs:
+      label: ${{ steps.start-ec2-runner.outputs.label }}
+      ec2-instance-id: ${{ steps.start-ec2-runner.outputs.ec2-instance-id }}
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Start EC2 runner
+        id: start-ec2-runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: start
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          ec2-image-id: ${{ env.EC2_AMI_ID }}
+          ec2-instance-type: ${{ env.EC2_INSTANCE_TYPE }}
+          subnet-id: ${{ env.EC2_SUBNET_ID }}
+          security-group-id: ${{ env.EC2_SECURITY_GROUP }}
+          aws-resource-tags: > # optional, requires additional permissions
+            [
+              {"Key": "Name", "Value": "ec2-tgi-github-runner"},
+              {"Key": "GitHubRepository", "Value": "${{ github.repository }}"}
+            ]
+
   test-cuda:
     concurrency:
       group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    runs-on: [single-gpu, nvidia-gpu, t4, ci]
-    container:
-      image: nvidia/cuda:12.3.1-devel-ubuntu22.04
-      options: --gpus 0
-    if: ${{ github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }}
+    needs: start-runner # required to start the main job when the runner is ready
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
     permissions:
       contents: write
       packages: write
@@ -24,10 +58,32 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v3
-      - name: Install dependencies
-        run: apt-get update && apt install curl build-essential libssl-dev protobuf-compiler pkg-config -y
       - name: Install Rust Stable
-        uses: actions-rust-lang/setup-rust-toolchain@v1
+        run: curl https://sh.rustup.rs -sSf | sh -s -- -y
       - uses: Swatinem/rust-cache@v2
+      - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
       - name: Test (cuda)
-        run: cargo test --features cuda
+        run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
+
+  stop-runner:
+    name: Stop self-hosted EC2 runner
+    needs:
+      - start-runner
+      - test-cuda
+    runs-on: ubuntu-latest
+    env:
+      AWS_REGION: us-east-1
+    if: ${{ (success() || failure()) && github.event.pull_request.head.repo.full_name == github.event.pull_request.base.repo.full_name }} # required to stop the runner even if the error happened in the previous jobs
+    steps:
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@v1
+        with:
+          aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
+          aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+          aws-region: ${{ env.AWS_REGION }}
+      - name: Stop EC2 runner
+        uses: philschmid/philschmid-ec2-github-runner@main
+        with:
+          mode: stop
+          github-token: ${{ secrets.GH_PERSONAL_ACCESS_TOKEN }}
+          label: ${{ needs.start-runner.outputs.label }}
+          ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
Cargo.toml (22 changed lines)

@@ -19,7 +19,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.4.0"
+version = "0.3.3"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -31,14 +31,14 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
-candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
-candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
-candle-nn = { path = "./candle-nn", version = "0.4.0" }
-candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
-candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
+candle = { path = "./candle-core", package = "candle-core" }
+candle-datasets = { path = "./candle-datasets" }
+candle-flash-attn = { path = "./candle-flash-attn" }
+candle-kernels = { path = "./candle-kernels" }
+candle-metal-kernels = { path = "./candle-metal-kernels" }
+candle-nn = { path = "./candle-nn" }
+candle-onnx = { path = "./candle-onnx" }
+candle-transformers = { path = "./candle-transformers" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.10.0", features = ["f16"] }
@@ -53,12 +53,12 @@ log = "0.4"
 memmap2 = { version = "0.9.3", features = ["stable_deref_trait"] }
 num_cpus = "1.15.0"
 num-traits = "0.2.15"
-parquet = { version = "50.0.0" }
+parquet = { version = "45.0.0" }
 rand = "0.8.5"
 rand_distr = "0.4.3"
 rayon = "1.7.0"
 rusttype = { version = "0.9", default-features = false }
-safetensors = "0.4.1"
+safetensors = "0.3.1"
 serde = { version = "1.0.171", features = ["derive"] }
 serde_plain = "1.0.2"
 serde_json = "1.0.99"
README.md (15 changed lines)

@@ -65,9 +65,8 @@ We also provide a some command line based examples using state of the art models
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
 - [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
-  pre-trained on 1T tokens of English and code datasets. Also supports
-  StableLM-2, a 1.6b LLM trained on 2T tokens, as well as the code variants.
-- [Mamba](./candle-examples/examples/mamba/): an inference only
+  pre-trained on 1T tokens of English and code datasets.
+- [Minimal Mamba](./candle-examples/examples/mamba-minimal/): a minimal
   implementation of the Mamba state space model.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
   better performance than all publicly available 13b models as of 2023-09-28.
@@ -112,10 +111,9 @@ We also provide a some command line based examples using state of the art models
   evaluation, segmentation).
 - [VGG](./candle-examples/examples/vgg/),
   [RepVGG](./candle-examples/examples/repvgg): computer vision models.
+- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
 - [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
   generate captions for an image.
-- [TrOCR](./candle-examples/examples/trocr/): a transformer OCR model, with
-  dedicated submodels for hand-writing and printed recognition.
 - [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
   model, generates the translated text from the input text.
 
@@ -186,10 +184,10 @@ If you have an addition to this list, please submit a pull request.
 - Falcon.
 - StarCoder.
 - Phi 1, 1.5, and 2.
-- Mamba, Minimal Mamba
+- Minimal Mamba
 - Mistral 7b v0.1.
 - Mixtral 8x7b v0.1.
-- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
+- StableLM-3B-4E1T.
 - Replit-code-v1.5-3B.
 - Bert.
 - Yi-6B and Yi-34B.
@@ -208,9 +206,8 @@ If you have an addition to this list, please submit a pull request.
   - Wurstchen v2.
 - Image to text.
   - BLIP.
-  - TrOCR.
 - Computer Vision Models.
-  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT.
+  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG.
   - yolo-v3, yolo-v8.
   - Segment-Anything Model (SAM).
 - File formats: load models from safetensors, npz, ggml, or PyTorch files.
@@ -2,8 +2,7 @@ mod benchmarks;
 
 use criterion::criterion_main;
 criterion_main!(
-    benchmarks::affine::benches,
     benchmarks::matmul::benches,
-    benchmarks::random::benches,
+    benchmarks::affine::benches,
     benchmarks::where_cond::benches
 );
@@ -1,6 +1,5 @@
 pub(crate) mod affine;
 pub(crate) mod matmul;
-pub(crate) mod random;
 pub(crate) mod where_cond;
 
 use candle_core::{Device, Result};
@@ -1,63 +0,0 @@
-use crate::benchmarks::{BenchDevice, BenchDeviceHandler};
-use candle_core::{DType, Device, Tensor};
-use criterion::{black_box, criterion_group, Criterion, Throughput};
-use std::time::Instant;
-
-fn rand_uniform(a: &Tensor) {
-    a.rand_like(-1.0, 123.0).unwrap();
-}
-
-fn rand_normal(a: &Tensor) {
-    a.randn_like(100.0, 15.0).unwrap();
-}
-
-fn run_random_bench(c: &mut Criterion, device: &Device) {
-    let b = 1;
-
-    let rows = 2048;
-    let cols = 2048;
-
-    let dtype = DType::F32;
-    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
-
-    let flops = b * rows * cols * dtype.size_in_bytes();
-
-    let mut group = c.benchmark_group(device.bench_name("random_uniform"));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |benches| {
-        benches.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                rand_uniform(black_box(&tensor));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-
-    let tensor = Tensor::zeros((b, rows, cols), dtype, device).unwrap();
-
-    let mut group = c.benchmark_group(device.bench_name("random_normal"));
-    group.throughput(Throughput::Bytes(flops as u64));
-    group.bench_function("iter", move |benches| {
-        benches.iter_custom(|iters| {
-            let start = Instant::now();
-            for _i in 0..iters {
-                rand_normal(black_box(&tensor));
-            }
-            device.sync().unwrap();
-            start.elapsed()
-        })
-    });
-    group.finish();
-}
-
-fn criterion_benchmark(c: &mut Criterion) {
-    let handler = BenchDeviceHandler::new().unwrap();
-    for device in handler.devices {
-        run_random_bench(c, &device);
-    }
-}
-
-criterion_group!(benches, criterion_benchmark);
@@ -196,7 +196,7 @@ fn run_ls(
             }
         }
         Format::Pth => {
-            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
+            let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
             tensors.sort_by(|a, b| a.name.cmp(&b.name));
             for tensor_info in tensors.iter() {
                 println!(
@@ -175,7 +175,7 @@ impl Tensor {
         // the backprop graph of the backprop itself. This would be an issue for second order
         // derivatives but these are out of scope at the moment.
         let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
-        let grad = if do_not_detach { grad } else { grad.detach() };
+        let grad = if do_not_detach { grad } else { grad.detach()? };
         if let Some(op) = node.op() {
             match op {
                 Op::Binary(lhs, rhs, BinaryOp::Add) => {
@@ -1149,55 +1149,6 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }
 
-struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
-impl<'a> Map2 for ConvTranspose1D<'a> {
-    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
-        &self,
-        inp: &CudaSlice<T>,
-        inp_l: &Layout,
-        k: &CudaSlice<T>,
-        k_l: &Layout,
-        dev: &CudaDevice,
-    ) -> Result<CudaSlice<T>> {
-        // Kernel shape: (c_in_k, c_out, l_k)
-        // Input shape: (b_size, c_in, l_in)
-        let p = &self.0;
-        let l_out = p.l_out();
-        let dst_el = p.c_out * l_out * p.b_size;
-        let inp = &inp.slice(inp_l.start_offset()..);
-        let k = &k.slice(k_l.start_offset()..);
-        let shape = inp_l.shape();
-        let dims = shape.dims();
-        let el = shape.elem_count();
-
-        // SAFETY: Set later by running the kernel.
-        let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let func = dev.get_or_load_func(&kernel_name::<T>("conv_transpose1d"), kernels::CONV)?;
-        let ds = if dims.len() == 3 {
-            [dims, inp_l.stride(), k_l.dims(), k_l.stride()].concat()
-        } else {
-            crate::bail!("unexpected input shape for conv_transpose1d {dims:?}")
-        };
-        let ds = dev.htod_copy(ds).w()?;
-        let params = (
-            el,
-            l_out,
-            p.stride,
-            p.padding,
-            p.output_padding,
-            p.dilation,
-            &ds,
-            inp,
-            k,
-            &out,
-        );
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(out)
-    }
-}
-
 struct ConvTranspose2D<'a>(&'a crate::conv::ParamsConvTranspose2D);
 impl<'a> Map2 for ConvTranspose2D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
@@ -1859,15 +1810,12 @@ impl BackendStorage for CudaStorage {
 
     fn conv_transpose1d(
         &self,
-        l: &Layout,
-        kernel: &Self,
-        kernel_l: &Layout,
-        params: &crate::conv::ParamsConvTranspose1D,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
-        let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-        Ok(Self { slice, device })
+        todo!()
     }
 
     #[cfg(not(feature = "cudnn"))]
@@ -7,9 +7,8 @@ use candle_metal_kernels::Kernels;
 use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
-use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, TryLockError};
+use std::sync::{Arc, RwLock, TryLockError};
 
 /// Simple way to catch lock error without
 /// depending on T
@@ -102,8 +101,6 @@ pub struct MetalDevice {
     /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
     /// (strong_count = 1).
     buffers: AllocatedBuffers,
-    /// Seed for random number generation.
-    seed: Arc<Mutex<Buffer>>,
 }
 
 impl std::fmt::Debug for MetalDevice {
@@ -228,7 +225,7 @@ impl MetalDevice {
         // The slice might not live long enough for metal
         // To actually fill the GPU buffer.
         // Putting this wait forces the GPU buffer to be filled
-        // with the actual data allowing the CPU storage to do
+        // with the actual data allowing the CPU storage todo
         // deallocate properly.
         self.wait_until_completed()?;
         Ok(real)
@@ -1557,11 +1554,6 @@ impl BackendDevice for MetalDevice {
             Ok(val) => val.parse()?,
             _ => 10,
         };
-        let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
-            [299792458].as_ptr() as *const c_void,
-            4,
-            MTLResourceOptions::StorageModeManaged,
-        )));
         Ok(Self {
             device,
             command_queue,
@@ -1570,10 +1562,13 @@ impl BackendDevice for MetalDevice {
             compute_per_buffer,
             buffers,
             kernels,
-            seed,
         })
     }
 
+    fn set_seed(&self, _seed: u64) -> Result<()> {
+        crate::bail!("Metal set_seed not implemented")
+    }
+
     fn location(&self) -> crate::DeviceLocation {
         crate::DeviceLocation::Metal {
             gpu_id: self.registry_id() as usize,
|
|||||||
&self,
|
&self,
|
||||||
shape: &Shape,
|
shape: &Shape,
|
||||||
dtype: DType,
|
dtype: DType,
|
||||||
min: f64,
|
mean: f64,
|
||||||
max: f64,
|
stddev: f64,
|
||||||
) -> Result<Self::Storage> {
|
) -> Result<Self::Storage> {
|
||||||
let name = match dtype {
|
// TODO is there a better way ?
|
||||||
DType::F32 => "rand_uniform_f32",
|
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||||
DType::F16 => "rand_uniform_f16",
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
DType::BF16 => "rand_uniform_bf16",
|
|
||||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
|
||||||
};
|
|
||||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_uniform")?;
|
|
||||||
let command_buffer = self.command_buffer()?;
|
|
||||||
candle_metal_kernels::call_random_uniform(
|
|
||||||
&self.device,
|
|
||||||
&command_buffer,
|
|
||||||
&self.kernels,
|
|
||||||
name,
|
|
||||||
min as f32,
|
|
||||||
max as f32,
|
|
||||||
shape.elem_count(),
|
|
||||||
&*self.seed.lock().unwrap(),
|
|
||||||
&buffer,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
|
|
||||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand_normal(
|
fn rand_normal(
|
||||||
@ -1647,43 +1623,9 @@ impl BackendDevice for MetalDevice {
|
|||||||
mean: f64,
|
mean: f64,
|
||||||
stddev: f64,
|
stddev: f64,
|
||||||
) -> Result<Self::Storage> {
|
) -> Result<Self::Storage> {
|
||||||
let name = match dtype {
|
// TODO is there a better way ?
|
||||||
DType::F32 => "rand_normal_f32",
|
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||||
DType::F16 => "rand_normal_f16",
|
self.storage_from_cpu_storage(&cpu_storage)
|
||||||
DType::BF16 => "rand_normal_bf16",
|
|
||||||
dtype => crate::bail!("rand_uniform not implemented for {dtype:?}"),
|
|
||||||
};
|
|
||||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "rand_normal")?;
|
|
||||||
let command_buffer = self.command_buffer()?;
|
|
||||||
candle_metal_kernels::call_random_normal(
|
|
||||||
&self.device,
|
|
||||||
&command_buffer,
|
|
||||||
&self.kernels,
|
|
||||||
name,
|
|
||||||
mean as f32,
|
|
||||||
stddev as f32,
|
|
||||||
shape.elem_count(),
|
|
||||||
&*self.seed.lock().unwrap(),
|
|
||||||
&buffer,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
|
|
||||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
|
||||||
let seed: u32 = seed.try_into().map_err(|_| {
|
|
||||||
MetalError::Message("Metal seed must be less than or equal to u32::MAX".to_string())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
|
|
||||||
let contents = seed_buffer.contents();
|
|
||||||
unsafe {
|
|
||||||
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
|
|
||||||
}
|
|
||||||
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
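Note on the random-number hunks above: the "-" side keeps a seed buffer on the MetalDevice and dispatches the rand_uniform/rand_normal Metal kernels, while the "+" side generates on the CPU backend and copies the storage over, leaving set_seed unimplemented. A minimal caller-side sketch of what this affects, assuming candle's public Device/Tensor API (not itself part of this diff):

use candle_core::{Device, Result, Tensor};

fn random_on_metal() -> Result<Tensor> {
    // Pick the first Metal GPU (requires building with the `metal` feature).
    let device = Device::new_metal(0)?;
    // Only the "-" side of this diff can honour the seed; the "+" side bails here.
    device.set_seed(299792458)?;
    // Uniform values with the same 2048x2048 shape used by the removed benchmark.
    Tensor::rand(0f32, 1f32, (2048, 2048), &device)
}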
@@ -217,13 +217,6 @@ impl Object {
                 let args = args.remove(1);
                 (callable, args)
             }
-            Object::Class {
-                module_name,
-                class_name,
-            } if module_name == "torch._utils" && class_name == "_rebuild_parameter" => {
-                let mut args = args.tuple()?;
-                args.remove(0).reduce()?
-            }
             _ => (callable, args),
         };
         match callable {
@@ -234,11 +227,13 @@ impl Object {
             _ => return Ok(None),
         };
         let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
+        let mut path = dir_name.to_path_buf();
+        path.push(file_path);
         Ok(Some(TensorInfo {
             name,
             dtype,
             layout,
-            path: format!("{}/{}", dir_name.to_string_lossy(), file_path),
+            path: path.to_string_lossy().into_owned(),
             storage_size,
         }))
     }
@@ -350,10 +345,8 @@ impl Stack {
                 module_name,
                 class_name,
             } => {
-                if module_name == "collections"
-                    && (class_name == "OrderedDict" || class_name == "defaultdict")
-                {
-                    // TODO: have a separate ordered dict and a separate default dict.
+                if module_name == "collections" && class_name == "OrderedDict" {
+                    // TODO: have a separate ordered dict.
                     Some(Object::Dict(vec![]))
                 } else {
                     None
@@ -634,16 +627,9 @@ pub struct TensorInfo {
     pub storage_size: usize,
 }
 
-/// Read the tensor info from a .pth file.
-///
-/// # Arguments
-/// * `file` - The path to the .pth file.
-/// * `verbose` - Whether to print debug information.
-/// * `key` - Optional key to retrieve `state_dict` from the pth file.
 pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
     file: P,
     verbose: bool,
-    key: Option<&str>,
 ) -> Result<Vec<TensorInfo>> {
     let file = std::fs::File::open(file)?;
     let zip_reader = std::io::BufReader::new(file);
@@ -665,9 +651,8 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
     stack.read_loop(&mut reader)?;
     let obj = stack.finalize()?;
     if VERBOSE || verbose {
-        println!("{obj:#?}");
+        println!("{obj:?}");
     }
 
     let obj = match obj {
         Object::Build { callable, args } => match *callable {
             Object::Reduce { callable, args: _ } => match *callable {
@@ -681,24 +666,6 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
         },
         obj => obj,
     };
-
-    // If key is provided, then we need to extract the state_dict from the object.
-    let obj = if let Some(key) = key {
-        if let Object::Dict(key_values) = obj {
-            key_values
-                .into_iter()
-                .find(|(k, _)| *k == Object::Unicode(key.to_owned()))
-                .map(|(_, v)| v)
-                .ok_or_else(|| E::Msg(format!("key {key} not found")))?
-        } else {
-            obj
-        }
-    } else {
-        obj
-    };
-
-    // If the object is a dict, then we can extract the tensor info from it.
-    // NOTE: We are assuming that the `obj` is state_dict by this stage.
     if let Object::Dict(key_values) = obj {
         for (name, value) in key_values.into_iter() {
             match value.into_tensor_info(name, &dir_name) {
@@ -721,8 +688,8 @@ pub struct PthTensors {
 }
 
 impl PthTensors {
-    pub fn new<P: AsRef<std::path::Path>>(path: P, key: Option<&str>) -> Result<Self> {
-        let tensor_infos = read_pth_tensor_info(path.as_ref(), false, key)?;
+    pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
+        let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
         let tensor_infos = tensor_infos
             .into_iter()
             .map(|ti| (ti.name.to_string(), ti))
@@ -745,12 +712,10 @@ impl PthTensors {
         let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;
-        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
-        let rank = tensor_info.layout.shape().rank();
 
         // Reading the data is a bit tricky as it can be strided, for now only support the basic
-        // case and when the tensor is fortran contiguous.
-        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
+        // case.
+        if !tensor_info.layout.is_contiguous() {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
|
|||||||
tensor_info.dtype,
|
tensor_info.dtype,
|
||||||
&mut reader,
|
&mut reader,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
if rank > 1 && is_fortran_contiguous {
|
|
||||||
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
|
||||||
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
|
||||||
let tensor = tensor.reshape(shape_reversed)?;
|
|
||||||
|
|
||||||
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
|
||||||
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
|
||||||
let tensor = tensor.permute(dim_indeces_reversed)?;
|
|
||||||
Ok(Some(tensor))
|
Ok(Some(tensor))
|
||||||
} else {
|
|
||||||
Ok(Some(tensor))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read all the tensors from a PyTorch pth file with a given key.
|
/// Read all the tensors from a PyTorch pth file.
|
||||||
///
|
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
||||||
/// # Arguments
|
let pth = PthTensors::new(path)?;
|
||||||
/// * `path` - Path to the pth file.
|
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
|
||||||
path: P,
|
|
||||||
key: Option<&str>,
|
|
||||||
) -> Result<Vec<(String, Tensor)>> {
|
|
||||||
let pth = PthTensors::new(path, key)?;
|
|
||||||
let tensor_names = pth.tensor_infos.keys();
|
let tensor_names = pth.tensor_infos.keys();
|
||||||
let mut tensors = Vec::with_capacity(tensor_names.len());
|
let mut tensors = Vec::with_capacity(tensor_names.len());
|
||||||
for name in tensor_names {
|
for name in tensor_names {
|
||||||
@ -804,11 +749,3 @@ pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
|||||||
}
|
}
|
||||||
Ok(tensors)
|
Ok(tensors)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read all the tensors from a PyTorch pth file.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * `path` - Path to the pth file.
|
|
||||||
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
|
|
||||||
read_all_with_key(path, None)
|
|
||||||
}
|
|
||||||
|
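Note on the pickle hunks above: the "-" side threads an optional `key` through read_pth_tensor_info, PthTensors::new and read_all_with_key so that a specific state_dict can be pulled out of a checkpoint that stores several objects. A short usage sketch against that side's API; the file name mirrors the test data removed later in this diff and is illustrative only:

use candle_core::pickle;
use candle_core::{Result, Tensor};

fn load_state_dict() -> Result<Vec<(String, Tensor)>> {
    // A checkpoint saved as torch.save({"model_state_dict": state_dict}, ...) is
    // opened by naming the key; passing None keeps the single-object behaviour.
    pickle::read_all_with_key("test_with_key.pt", Some("model_state_dict"))
}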
@@ -1,43 +0,0 @@
-#![allow(unused)]
-use super::GgmlDType;
-use crate::{Error, MetalDevice, MetalStorage, Result};
-
-pub struct QMetalStorage {
-    dtype: GgmlDType,
-    device: MetalDevice,
-}
-
-impl QMetalStorage {
-    pub fn zeros(_: &MetalDevice, _: usize, _: GgmlDType) -> Result<Self> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    pub fn dtype(&self) -> GgmlDType {
-        self.dtype
-    }
-
-    pub fn device(&self) -> &MetalDevice {
-        &self.device
-    }
-
-    pub fn dequantize(&self, _elem_count: usize) -> Result<MetalStorage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    pub fn quantize(&mut self, _src: &MetalStorage) -> Result<()> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
-    pub fn storage_size_in_bytes(&self) -> usize {
-        0
-    }
-
-    pub fn fwd(
-        &self,
-        _self_shape: &crate::Shape,
-        _storage: &MetalStorage,
-        _layout: &crate::Layout,
-    ) -> Result<(MetalStorage, crate::Shape)> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-}
@@ -233,7 +233,6 @@ pub struct Content {
     pub hparams: HParams,
     pub vocab: Vocab,
     pub tensors: HashMap<String, super::QTensor>,
-    pub device: Device,
 }
 
 impl Content {
@@ -253,13 +252,11 @@ impl Content {
             let (name, tensor) = read_one_tensor(reader, magic, device)?;
             tensors.insert(name, tensor);
         }
-        let device = device.clone();
         Ok(Self {
             magic,
             hparams,
             vocab,
             tensors,
-            device,
         })
     }
 
@@ -1,6 +1,5 @@
 use super::{GgmlDType, QStorage};
-use crate::backend::BackendStorage;
-use crate::{DType, MetalDevice, MetalStorage, Result, Shape};
+use crate::{DType, MetalDevice, MetalStorage, Result};
 use metal::Buffer;
 use std::sync::Arc;
 
@@ -11,28 +10,22 @@ pub struct QMetalStorage {
 }
 
 impl QMetalStorage {
-    pub fn zeros(device: &MetalDevice, elem_count: usize, dtype: GgmlDType) -> Result<Self> {
-        let size = elem_count * dtype.type_size() / dtype.block_size();
-        let buffer = device.allocate_zeros(size)?;
-        Ok(Self {
-            buffer,
-            device: device.clone(),
-            dtype,
-        })
-    }
-
     pub fn dtype(&self) -> GgmlDType {
         self.dtype
     }
 
-    pub fn device(&self) -> &MetalDevice {
-        &self.device
-    }
-
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
     }
 
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: GgmlDType) -> Self {
+        Self {
+            device,
+            buffer,
+            dtype,
+        }
+    }
+
     pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
         let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         let command_buffer = self.device.command_buffer()?;
@@ -137,59 +130,6 @@ impl QMetalStorage {
         self.buffer = buffer;
         Ok(())
     }
-
-    pub fn storage_size_in_bytes(&self) -> usize {
-        self.buffer.length() as usize
-    }
-
-    pub fn fwd(
-        &self,
-        self_shape: &Shape,
-        storage: &MetalStorage,
-        layout: &crate::Layout,
-    ) -> Result<(MetalStorage, Shape)> {
-        use crate::MetalError;
-
-        if !layout.is_contiguous() {
-            crate::bail!("input tensor is not contiguous {layout:?}")
-        }
-        let src_shape = layout.shape();
-        // self is transposed so n is first then k.
-        if src_shape.rank() < 2 {
-            crate::bail!("input tensor has only one dimension {layout:?}")
-        }
-        let (n, k) = self_shape.dims2()?;
-        let mut dst_shape = src_shape.dims().to_vec();
-
-        let (b, m) = match dst_shape.len() {
-            3 => (dst_shape[0], dst_shape[1]),
-            2 => (1, dst_shape[0]),
-            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
-        };
-        let last_k = dst_shape.pop().unwrap();
-        if last_k != k {
-            crate::bail!("input tensor {layout:?} incompatible with {:?}", self_shape)
-        }
-        dst_shape.push(n);
-        let dst_shape = Shape::from(dst_shape);
-        let device = storage.device().clone();
-        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
-        let command_buffer = device.command_buffer()?;
-        candle_metal_kernels::call_quantized_matmul_t(
-            device.device(),
-            &command_buffer,
-            device.kernels(),
-            self.dtype.into(),
-            (b, m, n, k),
-            storage.buffer(),
-            layout.start_offset() * storage.dtype().size_in_bytes(),
-            &self.buffer,
-            &dst,
-        )
-        .map_err(MetalError::from)?;
-        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
-        Ok((dst_storage, dst_shape))
-    }
 }
 
 pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
@@ -211,24 +151,3 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
     slice.to_vec()
 }
-
-impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
-    fn from(value: GgmlDType) -> Self {
-        match value {
-            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
-            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
-            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
-            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
-            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
-            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
-            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
-            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
-            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
-            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
-            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
-            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
-            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
-            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
-        }
-    }
-}
@@ -1,19 +1,16 @@
+#[cfg(feature = "metal")]
+use crate::{backend::BackendStorage, DType};
 use crate::{CpuStorage, Device, Result, Shape, Storage, Tensor};
 use k_quants::*;
 use std::borrow::Cow;
 
 #[cfg(target_feature = "avx")]
 pub mod avx;
-mod dummy_metal;
 pub mod ggml_file;
 pub mod gguf_file;
 pub mod k_quants;
 #[cfg(feature = "metal")]
 pub mod metal;
-#[cfg(not(feature = "metal"))]
-mod metal {
-    pub use super::dummy_metal::*;
-}
 #[cfg(target_feature = "neon")]
 pub mod neon;
 #[cfg(target_feature = "simd128")]
@@ -35,9 +32,19 @@ impl Device {
                 let storage = dtype.cpu_zeros(elem_count);
                 Ok(QStorage::Cpu(storage))
             }
+            #[cfg(feature = "metal")]
            Device::Metal(metal) => {
-                let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
-                Ok(QStorage::Metal(storage))
+                let size = elem_count * dtype.type_size() / dtype.block_size();
+                let buffer = metal.allocate_zeros(size)?;
+                Ok(QStorage::Metal(metal::QMetalStorage::new(
+                    buffer,
+                    metal.clone(),
+                    dtype,
+                )))
+            }
+            #[cfg(not(feature = "metal"))]
+            Device::Metal(_metal) => {
+                crate::bail!("Metal feature not activated");
             }
             Device::Cuda(_cuda) => {
                 crate::bail!("Cuda ggml quantization not supported");
@@ -48,6 +55,7 @@ impl Device {
 
 pub enum QStorage {
     Cpu(Box<dyn QuantizedType>),
+    #[cfg(feature = "metal")]
     Metal(metal::QMetalStorage),
 }
 
@@ -55,6 +63,7 @@ impl QStorage {
     fn block_size(&self) -> usize {
         match self {
             QStorage::Cpu(storage) => storage.block_size(),
+            #[cfg(feature = "metal")]
             QStorage::Metal(storage) => storage.dtype().block_size(),
         }
     }
@@ -62,21 +71,16 @@ impl QStorage {
     fn dtype(&self) -> GgmlDType {
         match self {
             QStorage::Cpu(storage) => storage.dtype(),
+            #[cfg(feature = "metal")]
             QStorage::Metal(storage) => storage.dtype(),
         }
     }
 
-    fn device(&self) -> Device {
-        match self {
-            QStorage::Cpu(_storage) => Device::Cpu,
-            QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
-        }
-    }
-
     fn size_in_bytes(&self) -> usize {
         match self {
             QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
-            QStorage::Metal(storage) => storage.storage_size_in_bytes(),
+            #[cfg(feature = "metal")]
+            QStorage::Metal(storage) => storage.buffer().length() as usize,
         }
     }
 
@@ -85,6 +89,7 @@ impl QStorage {
             (QStorage::Cpu(storage), Storage::Cpu(src)) => {
                 storage.from_float(src.as_slice::<f32>()?)?;
             }
+            #[cfg(feature = "metal")]
             (QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
             _ => crate::bail!("Invalid dequantize storage locations do not match"),
         }
@@ -94,6 +99,7 @@ impl QStorage {
     fn dequantize(&self, elem_count: usize) -> Result<Storage> {
         match self {
             QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
+            #[cfg(feature = "metal")]
             QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
         }
     }
@@ -106,6 +112,7 @@ impl QStorage {
                 let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
                 Ok(Cow::from(data))
             }
+            #[cfg(feature = "metal")]
            QStorage::Metal(_storage) => {
                crate::bail!("not implemented");
            }
@@ -329,10 +336,6 @@ impl QTensor {
         self.storage.dtype()
     }
 
-    pub fn device(&self) -> Device {
-        self.storage.device()
-    }
-
     pub fn rank(&self) -> usize {
         self.shape.rank()
     }
@@ -424,7 +427,8 @@ impl crate::CustomOp1 for QTensor {
         #[allow(clippy::infallible_destructuring_match)]
         let self_storage = match &self.storage {
             QStorage::Cpu(storage) => storage,
-            QStorage::Metal(_) => crate::bail!("Invalid storage"),
+            #[cfg(feature = "metal")]
+            _ => crate::bail!("Invalid storage"),
         };
         let slice = storage.as_slice::<f32>()?;
         let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
@@ -433,16 +437,79 @@ impl crate::CustomOp1 for QTensor {
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
 
+    #[cfg(feature = "metal")]
     fn metal_fwd(
         &self,
         storage: &crate::MetalStorage,
         layout: &crate::Layout,
     ) -> Result<(crate::MetalStorage, Shape)> {
-        let self_storage = match &self.storage {
-            QStorage::Metal(metal) => metal,
+        use crate::MetalError;
+
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let (n, k) = self.shape.dims2()?;
+        let mut dst_shape = src_shape.dims().to_vec();
+
+        let (b, m) = match dst_shape.len() {
+            3 => (dst_shape[0], dst_shape[1]),
+            2 => (1, dst_shape[0]),
+            n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
+        };
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        let device = storage.device().clone();
+        let dst = device.new_buffer(dst_shape.elem_count(), DType::F32, "qmatmul")?;
+        let (buffer, dtype) = match &self.storage {
+            QStorage::Metal(metal) => (metal.buffer(), metal.dtype()),
             _ => unreachable!("Cannot call metal matmul on non metal QTensor"),
         };
-        self_storage.fwd(&self.shape, storage, layout)
+        let command_buffer = device.command_buffer()?;
+        candle_metal_kernels::call_quantized_matmul_t(
+            device.device(),
+            &command_buffer,
+            device.kernels(),
+            dtype.into(),
+            (b, m, n, k),
+            storage.buffer(),
+            layout.start_offset() * storage.dtype().size_in_bytes(),
+            buffer,
+            &dst,
+        )
+        .map_err(MetalError::from)?;
+        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
+        Ok((dst_storage, dst_shape))
+    }
+}
+
+#[cfg(feature = "metal")]
+impl From<GgmlDType> for candle_metal_kernels::GgmlDType {
+    fn from(value: GgmlDType) -> Self {
+        match value {
+            GgmlDType::Q4_0 => candle_metal_kernels::GgmlDType::Q4_0,
+            GgmlDType::Q4_1 => candle_metal_kernels::GgmlDType::Q4_1,
+            GgmlDType::Q5_0 => candle_metal_kernels::GgmlDType::Q5_0,
+            GgmlDType::Q5_1 => candle_metal_kernels::GgmlDType::Q5_1,
+            GgmlDType::Q8_0 => candle_metal_kernels::GgmlDType::Q8_0,
+            GgmlDType::Q8_1 => candle_metal_kernels::GgmlDType::Q8_1,
+            GgmlDType::Q2K => candle_metal_kernels::GgmlDType::Q2K,
+            GgmlDType::Q3K => candle_metal_kernels::GgmlDType::Q3K,
+            GgmlDType::Q4K => candle_metal_kernels::GgmlDType::Q4K,
+            GgmlDType::Q5K => candle_metal_kernels::GgmlDType::Q5K,
+            GgmlDType::Q6K => candle_metal_kernels::GgmlDType::Q6K,
+            GgmlDType::Q8K => candle_metal_kernels::GgmlDType::Q8K,
+            GgmlDType::F16 => candle_metal_kernels::GgmlDType::F16,
+            GgmlDType::F32 => candle_metal_kernels::GgmlDType::F32,
+        }
     }
 }
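Note on the quantized matmul hunk above: both sides implement the same shape bookkeeping — the QTensor stores a transposed (n, k) weight, the input is (b, m, k) or (m, k), and the output keeps the leading dimensions but replaces the trailing k with n. A small, hypothetical helper (not part of either branch) that captures just that rule:

// Returns the output dims of the transposed quantized matmul, or None when the
// trailing input dimension does not match the weight's k.
fn qmatmul_output_dims(weight_nk: (usize, usize), src_dims: &[usize]) -> Option<Vec<usize>> {
    let (n, k) = weight_nk;
    let mut dst = src_dims.to_vec();
    if dst.pop()? != k {
        return None;
    }
    dst.push(n);
    Some(dst)
}

For example, qmatmul_output_dims((32, 64), &[1, 8, 64]) yields Some(vec![1, 8, 32]).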
@@ -804,35 +804,6 @@ impl Tensor {
         }
     }
 
-    /// Roll the tensor input along the given dimension.
-    /// Elements that are shifted beyond the last position are re-introduced at the first position.
-    ///
-    /// ```rust
-    /// # use candle_core::{Tensor, Device};
-    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
-    /// let tensor = tensor.roll(1, 0)?;
-    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[4., 5.], [0., 1.], [2., 3.]]);
-    /// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
-    /// let tensor = tensor.roll(-1, 0)?;
-    /// assert_eq!(tensor.to_vec2::<f32>()?, &[[2., 3.], [4., 5.], [0., 1.]]);
-    /// # Ok::<(), candle_core::Error>(())
-    /// ```
-    pub fn roll<D>(&self, shift: i32, dim: D) -> Result<Self>
-    where
-        D: Dim + Clone,
-    {
-        let dim = dim.to_index(self.shape(), "roll")?;
-        let dim_size = self.dim(dim)?;
-        let shift = shift.rem_euclid(dim_size as i32) as usize;
-        if shift == 0 {
-            Ok(self.clone())
-        } else {
-            let a = self.narrow(dim, 0, dim_size - shift)?;
-            let b = self.narrow(dim, dim_size - shift, shift)?;
-            Tensor::cat(&[&b, &a], dim)
-        }
-    }
-
     /// Returns the sum of all elements in the input tensor. The sum is performed over all the
     /// input dimensions.
     ///
@@ -1882,9 +1853,9 @@ impl Tensor {
     /// this new node. The storage of this tensor is shared with the initial tensor.
     ///
     /// If the tensor is already detached from the computation graph, the same tensor is returned.
-    pub fn detach(&self) -> Tensor {
+    pub fn detach(&self) -> Result<Tensor> {
         if self.op.is_none() && !self.is_variable {
-            self.clone()
+            Ok(self.clone())
         } else {
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
@@ -1895,7 +1866,7 @@ impl Tensor {
                 dtype: self.dtype,
                 device: self.device.clone(),
             };
-            Tensor(Arc::new(tensor_))
+            Ok(Tensor(Arc::new(tensor_)))
         }
     }
 
@@ -107,10 +107,6 @@ impl Var {
         Ok(Self(inner))
     }
 
-    pub fn as_detached_tensor(&self) -> Tensor {
-        self.0.detach()
-    }
-
     pub fn as_tensor(&self) -> &Tensor {
         &self.0
     }
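Note on the detach hunks above: the two branches disagree on the signature — `detach(&self) -> Tensor` on the "-" side versus `-> Result<Tensor>` on the "+" side — so call sites such as the backprop change earlier in this diff gain or drop a `?`, and `Var::as_detached_tensor` only exists on the "-" side. A hedged sketch of a call site written against the Result-returning variant:

use candle_core::{Result, Tensor};

// Detach a tensor from the autograd graph; with the Result-returning signature the
// error is propagated, with the infallible one the `?` would simply be dropped.
fn detached(t: &Tensor) -> Result<Tensor> {
    let d = t.detach()?;
    Ok(d)
}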
@ -50,6 +50,7 @@ fn conv1d(dev: &Device) -> Result<()> {
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );
    if dev.is_cpu() {
    let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 2, 7]);
    assert_eq!(
@ -59,6 +60,7 @@ fn conv1d(dev: &Device) -> Result<()> {
            4.7076, -5.9745, -0.8276, 1.621
        ],
    );
    }
    Ok(())
}

Binary file not shown.
@ -1,37 +0,0 @@
import torch
from collections import OrderedDict

# Write a trivial tensor to a pt file
a = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
o = OrderedDict()
o["test"] = a

# Write a trivial tensor to a pt file
torch.save(o, "test.pt")

############################################################################################################
# Write a trivial tensor to a pt file with a key
torch.save({"model_state_dict": o}, "test_with_key.pt")

############################################################################################################
# Create a tensor with fortran contiguous memory layout
import numpy as np

# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
# For example, creating a 2x3x4 array
array_fortran = np.asfortranarray(np.arange(1, 2 * 3 * 4 + 1).reshape(2, 3, 4))

# Verify the memory order
print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS'])  # Should be True
print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS'])  # Should be False

# Step 2: Convert the NumPy array to a PyTorch tensor
tensor_fortran = torch.from_numpy(array_fortran)

# Verify the tensor layout
print("Tensor stride:", tensor_fortran.stride())  # Stride will reflect the Fortran memory layout

# Step 3: Save the PyTorch tensor to a .pth file
torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')

print("3D Tensor saved with Fortran layout.")
@ -1,31 +0,0 @@
/// Regression test for pth files not loading on Windows.
#[test]
fn test_pth() {
    let tensors = candle_core::pickle::PthTensors::new("tests/test.pt", None).unwrap();
    tensors.get("test").unwrap().unwrap();
}

#[test]
fn test_pth_with_key() {
    let tensors =
        candle_core::pickle::PthTensors::new("tests/test_with_key.pt", Some("model_state_dict"))
            .unwrap();
    tensors.get("test").unwrap().unwrap();
}

#[test]
fn test_pth_fortran_contiguous() {
    let tensors =
        candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
    let tensor = tensors.get("tensor_fortran").unwrap().unwrap();

    assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));

    assert_eq!(
        tensor.to_vec3::<i64>().unwrap(),
        [
            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
        ]
    );
}
Binary file not shown.
Binary file not shown.
@ -21,7 +21,7 @@ candle-onnx = { workspace = true, optional = true }
csv = "1.3.0"
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
hf-hub = { workspace = true, features = ["tokio"] }
hf-hub = { workspace = true, features=["tokio"]}
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
@ -30,9 +30,7 @@ rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"] }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }

[dev-dependencies]
anyhow = { workspace = true }
@ -45,6 +43,7 @@ rusttype = { workspace = true }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"

@ -62,7 +61,6 @@ mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/
nccl = ["cuda", "cudarc/nccl", "dep:half"]
onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]

[[example]]
name = "llama_multiprocess"
@ -79,7 +77,3 @@ required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

[[example]]
name = "whisper-microphone"
required-features = ["microphone"]
@ -1,237 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle_transformers::models::chatglm::{Config, Model};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
verbose_prompt: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
verbose_prompt: bool,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
verbose_prompt,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
println!("starting the inference loop");
|
|
||||||
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
|
|
||||||
if tokens.is_empty() {
|
|
||||||
anyhow::bail!("Empty prompts are not supported in the chatglm model.")
|
|
||||||
}
|
|
||||||
if self.verbose_prompt {
|
|
||||||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
|
|
||||||
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
|
|
||||||
println!("{id:7} -> '{token}'");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mut tokens = tokens.get_ids().to_vec();
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_vocab(true).get("</s>") {
|
|
||||||
Some(token) => *token,
|
|
||||||
None => anyhow::bail!("cannot find the endoftext token"),
|
|
||||||
};
|
|
||||||
print!("{prompt}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input)?;
|
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
|
|
||||||
print!("{token}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
/// Display the token for the specified prompt.
|
|
||||||
#[arg(long)]
|
|
||||||
verbose_prompt: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id.to_string(),
|
|
||||||
None => "THUDM/chatglm3-6b".to_string(),
|
|
||||||
};
|
|
||||||
let revision = match args.revision {
|
|
||||||
Some(rev) => rev.to_string(),
|
|
||||||
None => "main".to_string(),
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
|
||||||
let tokenizer_filename = match args.tokenizer {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => api
|
|
||||||
.model("lmz/candle-chatglm".to_string())
|
|
||||||
.get("chatglm-tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_file {
|
|
||||||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
|
|
||||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config = Config::glm3_6b();
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
args.verbose_prompt,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,22 +0,0 @@
# candle-convnext

[A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545).

This candle implementation uses a pre-trained ConvNeXt network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example convnext --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which tiny

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 84.09%
bicycle-built-for-two, tandem bicycle, tandem: 4.15%
maillot : 0.74%
crash helmet : 0.54%
unicycle, monocycle : 0.44%

```
@ -1,102 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, D};
|
|
||||||
use candle_nn::{Module, VarBuilder};
|
|
||||||
use candle_transformers::models::convnext;
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
Tiny,
|
|
||||||
Small,
|
|
||||||
Base,
|
|
||||||
Large,
|
|
||||||
XLarge,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_filename(&self) -> String {
|
|
||||||
let name = match self {
|
|
||||||
Self::Tiny => "tiny",
|
|
||||||
Self::Small => "small",
|
|
||||||
Self::Base => "base",
|
|
||||||
Self::Large => "large",
|
|
||||||
Self::XLarge => "xlarge",
|
|
||||||
};
|
|
||||||
// The XLarge model only has an ImageNet-22K variant
|
|
||||||
let variant = match self {
|
|
||||||
Self::XLarge => "fb_in22k_ft_in1k",
|
|
||||||
_ => "fb_in1k",
|
|
||||||
};
|
|
||||||
|
|
||||||
format!("timm/convnext_{name}.{variant}")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> convnext::Config {
|
|
||||||
match self {
|
|
||||||
Self::Tiny => convnext::Config::tiny(),
|
|
||||||
Self::Small => convnext::Config::small(),
|
|
||||||
Self::Base => convnext::Config::base(),
|
|
||||||
Self::Large => convnext::Config::large(),
|
|
||||||
Self::XLarge => convnext::Config::xlarge(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
image: String,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(value_enum, long, default_value_t=Which::Tiny)]
|
|
||||||
which: Which,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
|
||||||
println!("loaded image {image:?}");
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let model_name = args.which.model_filename();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.model(model_name);
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = convnext::convnext(&args.which.config(), 1000, vb)?;
|
|
||||||
println!("model built");
|
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
|
||||||
.i(0)?
|
|
||||||
.to_vec1::<f32>()?;
|
|
||||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
|
||||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
|
||||||
for &(category_idx, pr) in prs.iter().take(5) {
|
|
||||||
println!(
|
|
||||||
"{:24}: {:.2}%",
|
|
||||||
candle_examples::imagenet::CLASSES[category_idx],
|
|
||||||
100. * pr
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -2,9 +2,6 @@

This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).

Compared to the mamba example, this version can handle training but is much
slower.

## Running the example

```bash
@ -1,17 +0,0 @@
# candle-mamba: Mamba implementation

Candle implementation of *Mamba* [1] inference only. Mamba is an alternative to
the transformer architecture. It leverages State Space Models (SSMs) with the
goal of being computationally efficient on long sequences. The implementation is
based on [mamba.rs](https://github.com/LaurentMazare/mamba.rs).

- [1]. [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752).

Compared to the mamba-minimal example, this version is far more efficient but
only works for inference.

## Running the example

```bash
$ cargo run --example mamba --release -- --prompt "Mamba is the"
```
@ -1,299 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle_transformers::models::mamba::{Config, Model, State};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
config: Config,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
config: Config,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
config,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the </s> token"),
|
|
||||||
};
|
|
||||||
let mut state = State::new(1, &self.config, &self.device)?;
|
|
||||||
let mut next_logits = None;
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
let input = Tensor::new(&[t], &self.device)?;
|
|
||||||
let logits = self.model.forward(&input, &mut state)?;
|
|
||||||
next_logits = Some(logits);
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for _ in 0..sample_len {
|
|
||||||
let logits = match next_logits.as_ref() {
|
|
||||||
Some(logits) => logits,
|
|
||||||
None => anyhow::bail!("cannot work on an empty prompt"),
|
|
||||||
};
|
|
||||||
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let input = Tensor::new(&[next_token], &self.device)?;
|
|
||||||
next_logits = Some(self.model.forward(&input, &mut state)?)
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
|
|
||||||
enum Which {
|
|
||||||
Mamba130m,
|
|
||||||
Mamba370m,
|
|
||||||
Mamba790m,
|
|
||||||
Mamba1_4b,
|
|
||||||
Mamba2_8b,
|
|
||||||
Mamba2_8bSlimPj,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Display for Which {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
write!(f, "{:?}", self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_id(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Mamba130m => "state-spaces/mamba-130m",
|
|
||||||
Self::Mamba370m => "state-spaces/mamba-370m",
|
|
||||||
Self::Mamba790m => "state-spaces/mamba-790m",
|
|
||||||
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
|
|
||||||
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
|
|
||||||
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn revision(&self) -> &'static str {
|
|
||||||
match self {
|
|
||||||
Self::Mamba130m
|
|
||||||
| Self::Mamba370m
|
|
||||||
| Self::Mamba790m
|
|
||||||
| Self::Mamba1_4b
|
|
||||||
| Self::Mamba2_8bSlimPj => "refs/pr/1",
|
|
||||||
Self::Mamba2_8b => "refs/pr/4",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 5000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "mamba130m")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
config_file: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
args.model_id
|
|
||||||
.unwrap_or_else(|| args.which.model_id().to_string()),
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision
|
|
||||||
.unwrap_or_else(|| args.which.revision().to_string()),
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => api
|
|
||||||
.model("EleutherAI/gpt-neox-20b".to_string())
|
|
||||||
.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let config_filename = match args.config_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("config.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => {
|
|
||||||
vec![repo.get("model.safetensors")?]
|
|
||||||
}
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
|
||||||
let model = Model::new(&config, vb.pp("backbone"))?;
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
config,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,22 +0,0 @@
# candle-mobileone

[MobileOne: An Improved One millisecond Mobile Backbone](https://arxiv.org/abs/2206.04040).

This candle implementation uses a pre-trained MobileOne network for inference. The
classification head has been trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example mobileone --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which s2

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 79.33%
bicycle-built-for-two, tandem bicycle, tandem: 15.32%
crash helmet : 2.58%
unicycle, monocycle : 1.70%
alp : 0.21%

```
@ -1,96 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
|
|
||||||
use candle::{DType, IndexOp, D};
|
|
||||||
use candle_nn::{Module, VarBuilder};
|
|
||||||
use candle_transformers::models::mobileone;
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
|
||||||
enum Which {
|
|
||||||
S0,
|
|
||||||
S1,
|
|
||||||
S2,
|
|
||||||
S3,
|
|
||||||
S4,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Which {
|
|
||||||
fn model_filename(&self) -> String {
|
|
||||||
let name = match self {
|
|
||||||
Self::S0 => "s0",
|
|
||||||
Self::S1 => "s1",
|
|
||||||
Self::S2 => "s2",
|
|
||||||
Self::S3 => "s3",
|
|
||||||
Self::S4 => "s4",
|
|
||||||
};
|
|
||||||
format!("timm/mobileone_{}.apple_in1k", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn config(&self) -> mobileone::Config {
|
|
||||||
match self {
|
|
||||||
Self::S0 => mobileone::Config::s0(),
|
|
||||||
Self::S1 => mobileone::Config::s1(),
|
|
||||||
Self::S2 => mobileone::Config::s2(),
|
|
||||||
Self::S3 => mobileone::Config::s3(),
|
|
||||||
Self::S4 => mobileone::Config::s4(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser)]
|
|
||||||
struct Args {
|
|
||||||
#[arg(long)]
|
|
||||||
model: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
image: String,
|
|
||||||
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(value_enum, long, default_value_t=Which::S0)]
|
|
||||||
which: Which,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
|
||||||
let args = Args::parse();
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
|
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
|
||||||
println!("loaded image {image:?}");
|
|
||||||
|
|
||||||
let model_file = match args.model {
|
|
||||||
None => {
|
|
||||||
let model_name = args.which.model_filename();
|
|
||||||
let api = hf_hub::api::sync::Api::new()?;
|
|
||||||
let api = api.model(model_name);
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
}
|
|
||||||
Some(model) => model.into(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
|
||||||
let model = mobileone::mobileone(&args.which.config(), 1000, vb)?;
|
|
||||||
println!("model built");
|
|
||||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
|
||||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
|
||||||
.i(0)?
|
|
||||||
.to_vec1::<f32>()?;
|
|
||||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
|
||||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
|
||||||
for &(category_idx, pr) in prs.iter().take(5) {
|
|
||||||
println!(
|
|
||||||
"{:24}: {:.2}%",
|
|
||||||
candle_examples::imagenet::CLASSES[category_idx],
|
|
||||||
100. * pr
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -1,39 +1,10 @@
## Using ONNX models in Candle

This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based models in Candle.
This example demonstrates how to run ONNX based models in Candle, the model
being used here is a small squeezenet variant.

It contains small variants of two models, [SqueezeNet](https://arxiv.org/pdf/1602.07360.pdf) (default) and [EfficientNet](https://arxiv.org/pdf/1905.11946.pdf).
You can run the example with the following command:

You can run the examples with the following commands:

```bash
cargo run --example onnx --features=onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
cargo run --example squeezenet-onnx --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

Use the `--which` flag to specify explicitly which network to use, i.e.

```bash
$ cargo run --example onnx --features=onnx --release -- --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg

Finished release [optimized] target(s) in 0.21s
Running `target/release/examples/onnx --which squeeze-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 3, 224, 224; f32]
unicycle, monocycle : 83.23%
ballplayer, baseball player : 3.68%
bearskin, busby, shako : 1.54%
military uniform : 0.78%
cowboy hat, ten-gallon hat : 0.76%
```

```bash
$ cargo run --example onnx --features=onnx --release -- --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg

Finished release [optimized] target(s) in 0.20s
Running `target/release/examples/onnx --which efficient-net --image candle-examples/examples/yolo-v8/assets/bike.jpg`
loaded image Tensor[dims 224, 224, 3; f32]
bicycle-built-for-two, tandem bicycle, tandem : 99.16%
mountain bike, all-terrain bike, off-roader : 0.60%
unicycle, monocycle : 0.17%
crash helmet : 0.02%
alp : 0.02%
```
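For orientation, a rough sketch (not part of this diff) of the load-and-eval flow this example builds on; the `candle_onnx::read_file` and `simple_eval` calls are assumed from the candle-onnx crate's public API, and `model.onnx` is a placeholder path:

```rust
use std::collections::HashMap;

use candle_core::Tensor;

fn run_onnx(image: Tensor) -> anyhow::Result<Tensor> {
    // Parse the ONNX protobuf into a ModelProto.
    let model = candle_onnx::read_file("model.onnx")?;
    let graph = model.graph.as_ref().expect("no graph in model");
    // Feed the preprocessed image under the graph's declared input name.
    let mut inputs = HashMap::new();
    inputs.insert(graph.input[0].name.clone(), image.unsqueeze(0)?);
    // Evaluate the graph and pull out the first declared output (the logits).
    let mut outputs = candle_onnx::simple_eval(&model, inputs)?;
    Ok(outputs.remove(&graph.output[0].name).expect("missing output"))
}
```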
@ -1,281 +0,0 @@
|
|||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
use candle_transformers::models::qwen2::{Config, Model};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
|
||||||
use candle_nn::VarBuilder;
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
struct TextGeneration {
|
|
||||||
model: Model,
|
|
||||||
device: Device,
|
|
||||||
tokenizer: TokenOutputStream,
|
|
||||||
logits_processor: LogitsProcessor,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TextGeneration {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
temp: Option<f64>,
|
|
||||||
top_p: Option<f64>,
|
|
||||||
repeat_penalty: f32,
|
|
||||||
repeat_last_n: usize,
|
|
||||||
device: &Device,
|
|
||||||
) -> Self {
|
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
|
||||||
Self {
|
|
||||||
model,
|
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
|
||||||
logits_processor,
|
|
||||||
repeat_penalty,
|
|
||||||
repeat_last_n,
|
|
||||||
device: device.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
|
||||||
use std::io::Write;
|
|
||||||
self.tokenizer.clear();
|
|
||||||
let mut tokens = self
|
|
||||||
.tokenizer
|
|
||||||
.tokenizer()
|
|
||||||
.encode(prompt, true)
|
|
||||||
.map_err(E::msg)?
|
|
||||||
.get_ids()
|
|
||||||
.to_vec();
|
|
||||||
for &t in tokens.iter() {
|
|
||||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
|
||||||
print!("{t}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
|
|
||||||
let mut generated_tokens = 0usize;
|
|
||||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
|
||||||
Some(token) => token,
|
|
||||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
|
||||||
};
|
|
||||||
let start_gen = std::time::Instant::now();
|
|
||||||
for index in 0..sample_len {
|
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
|
||||||
let start_pos = tokens.len().saturating_sub(context_size);
|
|
||||||
let ctxt = &tokens[start_pos..];
|
|
||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
|
||||||
let logits = self.model.forward(&input, start_pos)?;
|
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
|
||||||
let logits = if self.repeat_penalty == 1. {
|
|
||||||
logits
|
|
||||||
} else {
|
|
||||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
|
||||||
candle_transformers::utils::apply_repeat_penalty(
|
|
||||||
&logits,
|
|
||||||
self.repeat_penalty,
|
|
||||||
&tokens[start_at..],
|
|
||||||
)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let next_token = self.logits_processor.sample(&logits)?;
|
|
||||||
tokens.push(next_token);
|
|
||||||
generated_tokens += 1;
|
|
||||||
if next_token == eos_token {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
|
||||||
print!("{t}");
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let dt = start_gen.elapsed();
|
|
||||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
|
||||||
print!("{rest}");
|
|
||||||
}
|
|
||||||
std::io::stdout().flush()?;
|
|
||||||
println!(
|
|
||||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
|
||||||
generated_tokens as f64 / dt.as_secs_f64(),
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)]
|
|
||||||
enum WhichModel {
|
|
||||||
#[value(name = "0.5b")]
|
|
||||||
W0_5b,
|
|
||||||
#[value(name = "1.8b")]
|
|
||||||
W1_8b,
|
|
||||||
#[value(name = "4b")]
|
|
||||||
W4b,
|
|
||||||
#[value(name = "7b")]
|
|
||||||
W7b,
|
|
||||||
#[value(name = "14b")]
|
|
||||||
W14b,
|
|
||||||
#[value(name = "72b")]
|
|
||||||
W72b,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
use_flash_attn: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
prompt: String,
|
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
|
||||||
#[arg(long)]
|
|
||||||
temperature: Option<f64>,
|
|
||||||
|
|
||||||
/// Nucleus sampling probability cutoff.
|
|
||||||
#[arg(long)]
|
|
||||||
top_p: Option<f64>,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
|
||||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
|
||||||
sample_len: usize,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
|
||||||
revision: String,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
tokenizer_file: Option<String>,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
weight_files: Option<String>,
|
|
||||||
|
|
||||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
|
||||||
#[arg(long, default_value_t = 1.1)]
|
|
||||||
repeat_penalty: f32,
|
|
||||||
|
|
||||||
/// The context size to consider for the repeat penalty.
|
|
||||||
#[arg(long, default_value_t = 64)]
|
|
||||||
repeat_last_n: usize,
|
|
||||||
|
|
||||||
#[arg(long, default_value = "0.5b")]
|
|
||||||
model: WhichModel,
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
println!(
|
|
||||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
|
||||||
candle::utils::with_avx(),
|
|
||||||
candle::utils::with_neon(),
|
|
||||||
candle::utils::with_simd128(),
|
|
||||||
candle::utils::with_f16c()
|
|
||||||
);
|
|
||||||
println!(
|
|
||||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
|
||||||
args.temperature.unwrap_or(0.),
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n
|
|
||||||
);
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let api = Api::new()?;
|
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id,
|
|
||||||
None => {
|
|
||||||
let size = match args.model {
|
|
||||||
WhichModel::W0_5b => "0.5B",
|
|
||||||
WhichModel::W1_8b => "1.8B",
|
|
||||||
WhichModel::W4b => "4B",
|
|
||||||
WhichModel::W7b => "7B",
|
|
||||||
WhichModel::W14b => "14B",
|
|
||||||
WhichModel::W72b => "72B",
|
|
||||||
};
|
|
||||||
format!("Qwen/Qwen1.5-{size}")
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
|
||||||
model_id,
|
|
||||||
RepoType::Model,
|
|
||||||
args.revision,
|
|
||||||
));
|
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
|
||||||
None => repo.get("tokenizer.json")?,
|
|
||||||
};
|
|
||||||
let filenames = match args.weight_files {
|
|
||||||
Some(files) => files
|
|
||||||
.split(',')
|
|
||||||
.map(std::path::PathBuf::from)
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
None => match args.model {
|
|
||||||
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
|
|
||||||
WhichModel::W4b | WhichModel::W7b | WhichModel::W14b | WhichModel::W72b => {
|
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
};
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let config_file = repo.get("config.json")?;
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let dtype = if device.is_cuda() {
|
|
||||||
DType::BF16
|
|
||||||
} else {
|
|
||||||
DType::F32
|
|
||||||
};
|
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
|
||||||
let model = Model::new(&config, vb)?;
|
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
|
||||||
|
|
||||||
let mut pipeline = TextGeneration::new(
|
|
||||||
model,
|
|
||||||
tokenizer,
|
|
||||||
args.seed,
|
|
||||||
args.temperature,
|
|
||||||
args.top_p,
|
|
||||||
args.repeat_penalty,
|
|
||||||
args.repeat_last_n,
|
|
||||||
&device,
|
|
||||||
);
|
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
@ -411,7 +411,7 @@ impl DDPG<'_> {
    pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
        let actions = self
            .actor
            .forward(&state.detach().unsqueeze(0)?)?
            .forward(&state.detach()?.unsqueeze(0)?)?
            .squeeze(0)?;
        let actions = if self.train {
            (actions + self.ou_noise.sample()?)?
@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
    loop {
        let action = {
            let action_probs: Vec<f32> =
                softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
                softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
                    .squeeze(0)?
                    .to_vec1()?;
            weighted_sample(action_probs, &mut rng)? as i64
@ -109,7 +109,7 @@ pub fn run() -> Result<()> {

        let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
            .to_dtype(DType::F32)?
            .detach();
            .detach()?;

        let actions_mask = {
            let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
                .unwrap()
            })
            .collect();
            Tensor::stack(&actions_mask, 0)?.detach()
            Tensor::stack(&actions_mask, 0)?.detach()?
        };

        let states = {
            let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
            Tensor::stack(&states, 0)?.detach()
            Tensor::stack(&states, 0)?.detach()?
        };

        let log_probs = actions_mask
@ -8,13 +8,6 @@ Card](https://huggingface.co/stabilityai/stablelm-3b-4e1t).
Note that this model is gated so you will have to request access on the Hub in
order to be able to use it.

Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.

StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.

## Running some examples

```bash
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::Parser;
|
||||||
|
|
||||||
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
|
use candle_transformers::models::quantized_stable_lm::Model as QStableLM;
|
||||||
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
|
use candle_transformers::models::stable_lm::{Config, Model as StableLM};
|
||||||
@ -122,16 +122,6 @@ impl TextGeneration {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)]
|
|
||||||
enum Which {
|
|
||||||
V1Orig,
|
|
||||||
V1,
|
|
||||||
V1Zephyr,
|
|
||||||
V2,
|
|
||||||
V2Zephyr,
|
|
||||||
Code,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -162,18 +152,15 @@ struct Args {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, short = 'n', default_value_t = 1000)]
|
#[arg(long, short = 'n', default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long, default_value = "lmz/candle-stablelm-3b-4e1t")]
|
||||||
model_id: Option<String>,
|
model_id: String,
|
||||||
|
|
||||||
#[arg(long, default_value = "main")]
|
#[arg(long, default_value = "main")]
|
||||||
revision: String,
|
revision: String,
|
||||||
|
|
||||||
#[arg(long, default_value = "v2")]
|
|
||||||
which: Which,
|
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tokenizer_file: Option<String>,
|
tokenizer_file: Option<String>,
|
||||||
|
|
||||||
@ -220,80 +207,33 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let model_id = match args.model_id {
|
|
||||||
Some(model_id) => model_id,
|
|
||||||
None => match args.which {
|
|
||||||
Which::V1Orig => "lmz/candle-stablelm-3b-4e1t".to_string(),
|
|
||||||
Which::V1 => "stabilityai/stablelm-3b-4e1t".to_string(),
|
|
||||||
Which::V1Zephyr => "stabilityai/stablelm-zephyr-3b".to_string(),
|
|
||||||
Which::Code => "stabilityai/stable-code-3b".to_string(),
|
|
||||||
Which::V2 => "stabilityai/stablelm-2-1_6b".to_string(),
|
|
||||||
Which::V2Zephyr => "stabilityai/stablelm-2-zephyr-1_6b".to_string(),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
model_id,
|
args.model_id,
|
||||||
RepoType::Model,
|
RepoType::Model,
|
||||||
args.revision,
|
args.revision,
|
||||||
));
|
));
|
||||||
let tokenizer_filename = match args.tokenizer_file {
|
let tokenizer_filename = match args.tokenizer_file {
|
||||||
Some(file) => std::path::PathBuf::from(file),
|
Some(file) => std::path::PathBuf::from(file),
|
||||||
None => match args.which {
|
None => repo.get("tokenizer.json")?,
|
||||||
Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
|
|
||||||
repo.get("tokenizer.json")?
|
|
||||||
}
|
|
||||||
Which::V2 | Which::V2Zephyr => api
|
|
||||||
.model("lmz/candle-stablelm".to_string())
|
|
||||||
.get("tokenizer-gpt4.json")?,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
let filenames = match args.weight_files {
|
let filenames = match args.weight_files {
|
||||||
Some(files) => files
|
Some(files) => files
|
||||||
.split(',')
|
.split(',')
|
||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match (args.which, args.quantized) {
|
None => {
|
||||||
(Which::V1Orig | Which::V1, true) => vec![repo.get("model-q4k.gguf")?],
|
if args.quantized {
|
||||||
(Which::V2, true) => {
|
vec![repo.get("model-q4k.gguf")?]
|
||||||
let gguf = api
|
} else {
|
||||||
.model("lmz/candle-stablelm".to_string())
|
|
||||||
.get("stablelm-2-1_6b-q4k.gguf")?;
|
|
||||||
vec![gguf]
|
|
||||||
}
|
|
||||||
(Which::V2Zephyr, true) => {
|
|
||||||
let gguf = api
|
|
||||||
.model("lmz/candle-stablelm".to_string())
|
|
||||||
.get("stablelm-2-zephyr-1_6b-q4k.gguf")?;
|
|
||||||
vec![gguf]
|
|
||||||
}
|
|
||||||
(Which::V1Zephyr | Which::Code, true) => {
|
|
||||||
anyhow::bail!("Quantized {:?} variant not supported.", args.which)
|
|
||||||
}
|
|
||||||
(Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr, false) => {
|
|
||||||
vec![repo.get("model.safetensors")?]
|
vec![repo.get("model.safetensors")?]
|
||||||
}
|
}
|
||||||
(Which::Code, false) => {
|
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
|
||||||
}
|
}
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = match args.which {
|
let config = Config::stablelm_3b_4e1t(args.use_flash_attn);
|
||||||
Which::V1Orig => Config::stablelm_3b_4e1t(args.use_flash_attn),
|
|
||||||
Which::V1 | Which::V1Zephyr | Which::V2 | Which::V2Zephyr | Which::Code => {
|
|
||||||
let config_filename = repo.get("config.json")?;
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let mut config: Config = serde_json::from_str(&config)?;
|
|
||||||
config.set_use_flash_attn(args.use_flash_attn);
|
|
||||||
config
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (model, device) = if args.quantized {
|
let (model, device) = if args.quantized {
|
||||||
let filename = &filenames[0];
|
let filename = &filenames[0];
|
||||||
|
Binary file not shown.
Before Width: | Height: | Size: 7.5 KiB |
@ -10,36 +10,15 @@ use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
-use candle_transformers::models::{trocr, vit};
+use candle_transformers::models::trocr;

use tokenizers::Tokenizer;
mod image_processor;

#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
-    #[value(name = "base")]
-    BaseHandwritten,
-    #[value(name = "large")]
-    LargeHandwritten,
-    BasePrinted,
-    LargePrinted,
-}
-
-impl Which {
-    fn repo_and_branch_name(&self) -> (&str, &str) {
-        match self {
-            Self::BaseHandwritten => ("microsoft/trocr-base-handwritten", "refs/pr/3"),
-            Self::LargeHandwritten => ("microsoft/trocr-large-handwritten", "refs/pr/6"),
-            Self::BasePrinted => ("microsoft/trocr-base-printed", "refs/pr/7"),
-            Self::LargePrinted => ("microsoft/trocr-large-printed", "main"),
-        }
-    }
-}
-
-#[derive(Debug, Clone, serde::Deserialize)]
-struct Config {
-    encoder: vit::Config,
-    decoder: trocr::TrOCRConfig,
+    Base,
+    Large,
}

#[derive(Parser, Debug)]
@ -55,64 +34,63 @@ struct Args {
    #[arg(long)]
    cpu: bool,

-    /// The image file to be processed.
+    /// Text to be translated
    #[arg(long)]
    image: String,

-    /// Tokenization config.
-    #[arg(long)]
-    tokenizer: Option<String>,
}

pub fn main() -> anyhow::Result<()> {
+    use hf_hub::api::sync::Api;
    let args = Args::parse();
-    let api = hf_hub::api::sync::Api::new()?;

-    let mut tokenizer_dec = {
-        let tokenizer_file = match args.tokenizer {
-            None => api
+    let tokenizer_dec = {
+        let tokenizer = Api::new()?
            .model(String::from("ToluClassics/candle-trocr-tokenizer"))
-                .get("tokenizer.json")?,
-            Some(tokenizer) => std::path::PathBuf::from(tokenizer),
-        };
-        let tokenizer = Tokenizer::from_file(&tokenizer_file).map_err(E::msg)?;
-        TokenOutputStream::new(tokenizer)
+            .get("tokenizer.json")?;
+        Tokenizer::from_file(&tokenizer).map_err(E::msg)?
    };

+    let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);

    let device = candle_examples::device(args.cpu)?;

    let vb = {
        let model = match args.model {
            Some(model) => std::path::PathBuf::from(model),
-            None => {
-                let (repo, branch) = args.which.repo_and_branch_name();
-                api.repo(hf_hub::Repo::with_revision(
-                    repo.to_string(),
+            None => match args.which {
+                Which::Base => Api::new()?
+                    .repo(hf_hub::Repo::with_revision(
+                        "microsoft/trocr-base-handwritten".to_string(),
                    hf_hub::RepoType::Model,
-                    branch.to_string(),
+                        "refs/pr/3".to_string(),
                ))
-                .get("model.safetensors")?
-            }
+                    .get("model.safetensors")?,
+                Which::Large => Api::new()?
+                    .repo(hf_hub::Repo::with_revision(
+                        "microsoft/trocr-large-handwritten".to_string(),
+                        hf_hub::RepoType::Model,
+                        "refs/pr/6".to_string(),
+                    ))
+                    .get("model.safetensors")?,
+            },
        };
        println!("model: {:?}", model);
        unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? }
    };

-    let (encoder_config, decoder_config) = {
-        let (repo, branch) = args.which.repo_and_branch_name();
-        let config_filename = api
-            .repo(hf_hub::Repo::with_revision(
-                repo.to_string(),
-                hf_hub::RepoType::Model,
-                branch.to_string(),
-            ))
-            .get("config.json")?;
-        let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
-        (config.encoder, config.decoder)
+    let encoder_config = match args.which {
+        Which::Base => candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten(),
+        Which::Large => {
+            candle_transformers::models::vit::Config::microsoft_trocr_base_handwritten()
+        }
    };

+    let decoder_config = trocr::TrOCRConfig::default();
    let mut model = trocr::TrOCRModel::new(&encoder_config, &decoder_config, vb)?;

-    let processor_config = image_processor::ProcessorConfig::default();
-    let processor = image_processor::ViTImageProcessor::new(&processor_config);
+    let config = image_processor::ProcessorConfig::default();
+    let processor = image_processor::ViTImageProcessor::new(&config);

    let image = vec![args.image.as_str()];
    let image = processor.preprocess(image)?;
@ -5,27 +5,12 @@ transcribe image text. See the associated [model
card](https://huggingface.co/microsoft/trocr-base-printed) for details on
the model itself.

-Supported models include:
-
-- `--which base`: small handwritten OCR model.
-- `--which large`: large handwritten OCR model.
-- `--which base-printed`: small printed OCR model.
-- `--which large-printed`: large printed OCR model.

## Running an example

```bash
-cargo run --example trocr --release -- --image candle-examples/examples/trocr/assets/trocr.png
-cargo run --example trocr --release -- --which large --image candle-examples/examples/trocr/assets/trocr.png
-cargo run --example trocr --release -- --which base-printed --image candle-examples/examples/trocr/assets/noto.png
-cargo run --example trocr --release -- --which large-printed --image candle-examples/examples/trocr/assets/noto.png
+cargo run --example trocr --release -- --which base --cpu --image candle-examples/examples/trocr/assets/trocr.png
```

-### Outputs
-
```
-industry , Mr. Brown commented icily . " Let us have a
-industry , " Mr. Brown commented icily . " Let us have a
-THE QUICK BROWN FOR JUMPS OVER THE LAY DOG
-THE QUICK BROWN FOX JUMPS OVER THE LAZY DOG
+<s> industry , Mr. Brown commented icily . " Let us have a</s>
```
|
@ -1,673 +0,0 @@
|
|||||||
#[cfg(feature = "accelerate")]
|
|
||||||
extern crate accelerate_src;
|
|
||||||
|
|
||||||
#[cfg(feature = "mkl")]
|
|
||||||
extern crate intel_mkl_src;
|
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
|
||||||
use candle::{Device, IndexOp, Tensor};
|
|
||||||
use candle_nn::{ops::softmax, VarBuilder};
|
|
||||||
use clap::{Parser, ValueEnum};
|
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
|
||||||
use std::iter;
|
|
||||||
use tokenizers::Tokenizer;
|
|
||||||
|
|
||||||
mod multilingual;
|
|
||||||
|
|
||||||
use candle_transformers::models::whisper::{self as m, audio, Config};
|
|
||||||
|
|
||||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
|
||||||
use std::sync::{Arc, Mutex};
|
|
||||||
|
|
||||||
pub enum Model {
|
|
||||||
Normal(m::model::Whisper),
|
|
||||||
Quantized(m::quantized_model::Whisper),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Maybe we should use some traits rather than doing the dispatch for all these.
|
|
||||||
impl Model {
|
|
||||||
pub fn config(&self) -> &Config {
|
|
||||||
match self {
|
|
||||||
Self::Normal(m) => &m.config,
|
|
||||||
Self::Quantized(m) => &m.config,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn encoder_forward(&mut self, x: &Tensor, flush: bool) -> candle::Result<Tensor> {
|
|
||||||
match self {
|
|
||||||
Self::Normal(m) => m.encoder.forward(x, flush),
|
|
||||||
Self::Quantized(m) => m.encoder.forward(x, flush),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn decoder_forward(
|
|
||||||
&mut self,
|
|
||||||
x: &Tensor,
|
|
||||||
xa: &Tensor,
|
|
||||||
flush: bool,
|
|
||||||
) -> candle::Result<Tensor> {
|
|
||||||
match self {
|
|
||||||
Self::Normal(m) => m.decoder.forward(x, xa, flush),
|
|
||||||
Self::Quantized(m) => m.decoder.forward(x, xa, flush),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn decoder_final_linear(&self, x: &Tensor) -> candle::Result<Tensor> {
|
|
||||||
match self {
|
|
||||||
Self::Normal(m) => m.decoder.final_linear(x),
|
|
||||||
Self::Quantized(m) => m.decoder.final_linear(x),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct DecodingResult {
|
|
||||||
tokens: Vec<u32>,
|
|
||||||
text: String,
|
|
||||||
avg_logprob: f64,
|
|
||||||
no_speech_prob: f64,
|
|
||||||
temperature: f64,
|
|
||||||
compression_ratio: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
struct Segment {
|
|
||||||
start: f64,
|
|
||||||
duration: f64,
|
|
||||||
dr: DecodingResult,
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Decoder {
|
|
||||||
model: Model,
|
|
||||||
rng: rand::rngs::StdRng,
|
|
||||||
task: Option<Task>,
|
|
||||||
timestamps: bool,
|
|
||||||
verbose: bool,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
suppress_tokens: Tensor,
|
|
||||||
sot_token: u32,
|
|
||||||
transcribe_token: u32,
|
|
||||||
translate_token: u32,
|
|
||||||
eot_token: u32,
|
|
||||||
no_speech_token: u32,
|
|
||||||
no_timestamps_token: u32,
|
|
||||||
language_token: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Decoder {
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn new(
|
|
||||||
model: Model,
|
|
||||||
tokenizer: Tokenizer,
|
|
||||||
seed: u64,
|
|
||||||
device: &Device,
|
|
||||||
language_token: Option<u32>,
|
|
||||||
task: Option<Task>,
|
|
||||||
timestamps: bool,
|
|
||||||
verbose: bool,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
|
||||||
// Suppress the notimestamps token when in timestamps mode.
|
|
||||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
|
||||||
let suppress_tokens: Vec<f32> = (0..model.config().vocab_size as u32)
|
|
||||||
.map(|i| {
|
|
||||||
if model.config().suppress_tokens.contains(&i)
|
|
||||||
|| timestamps && i == no_timestamps_token
|
|
||||||
{
|
|
||||||
f32::NEG_INFINITY
|
|
||||||
} else {
|
|
||||||
0f32
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
|
||||||
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
|
|
||||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
|
||||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
|
||||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
|
||||||
let no_speech_token = m::NO_SPEECH_TOKENS
|
|
||||||
.iter()
|
|
||||||
.find_map(|token| token_id(&tokenizer, token).ok());
|
|
||||||
let no_speech_token = match no_speech_token {
|
|
||||||
None => anyhow::bail!("unable to find any non-speech token"),
|
|
||||||
Some(n) => n,
|
|
||||||
};
|
|
||||||
Ok(Self {
|
|
||||||
model,
|
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
|
||||||
tokenizer,
|
|
||||||
task,
|
|
||||||
timestamps,
|
|
||||||
verbose,
|
|
||||||
suppress_tokens,
|
|
||||||
sot_token,
|
|
||||||
transcribe_token,
|
|
||||||
translate_token,
|
|
||||||
eot_token,
|
|
||||||
no_speech_token,
|
|
||||||
language_token,
|
|
||||||
no_timestamps_token,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
|
||||||
let model = &mut self.model;
|
|
||||||
let audio_features = model.encoder_forward(mel, true)?;
|
|
||||||
if self.verbose {
|
|
||||||
println!("audio features: {:?}", audio_features.dims());
|
|
||||||
}
|
|
||||||
let sample_len = model.config().max_target_positions / 2;
|
|
||||||
let mut sum_logprob = 0f64;
|
|
||||||
let mut no_speech_prob = f64::NAN;
|
|
||||||
let mut tokens = vec![self.sot_token];
|
|
||||||
if let Some(language_token) = self.language_token {
|
|
||||||
tokens.push(language_token);
|
|
||||||
}
|
|
||||||
match self.task {
|
|
||||||
None | Some(Task::Transcribe) => tokens.push(self.transcribe_token),
|
|
||||||
Some(Task::Translate) => tokens.push(self.translate_token),
|
|
||||||
}
|
|
||||||
if !self.timestamps {
|
|
||||||
tokens.push(self.no_timestamps_token);
|
|
||||||
}
|
|
||||||
for i in 0..sample_len {
|
|
||||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
|
||||||
|
|
||||||
// The model expects a batch dim but this inference loop does not handle
|
|
||||||
// it so we add it at this point.
|
|
||||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
|
||||||
let ys = model.decoder_forward(&tokens_t, &audio_features, i == 0)?;
|
|
||||||
|
|
||||||
// Extract the no speech probability on the first iteration by looking at the first
|
|
||||||
// token logits and the probability for the according token.
|
|
||||||
if i == 0 {
|
|
||||||
let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
|
|
||||||
no_speech_prob = softmax(&logits, 0)?
|
|
||||||
.i(self.no_speech_token as usize)?
|
|
||||||
.to_scalar::<f32>()? as f64;
|
|
||||||
}
|
|
||||||
|
|
||||||
let (_, seq_len, _) = ys.dims3()?;
|
|
||||||
let logits = model
|
|
||||||
.decoder_final_linear(&ys.i((..1, seq_len - 1..))?)?
|
|
||||||
.i(0)?
|
|
||||||
.i(0)?;
|
|
||||||
// TODO: Besides suppress tokens, we should apply the heuristics from
|
|
||||||
// ApplyTimestampRules, i.e.:
|
|
||||||
// - Timestamps come in pairs, except before EOT.
|
|
||||||
// - Timestamps should be non-decreasing.
|
|
||||||
// - If the sum of the probabilities of timestamps is higher than any other tokens,
|
|
||||||
// only consider timestamps when sampling.
|
|
||||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L439
|
|
||||||
let logits = logits.broadcast_add(&self.suppress_tokens)?;
|
|
||||||
let next_token = if t > 0f64 {
|
|
||||||
let prs = softmax(&(&logits / t)?, 0)?;
|
|
||||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
|
||||||
let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
|
|
||||||
distr.sample(&mut self.rng) as u32
|
|
||||||
} else {
|
|
||||||
let logits_v: Vec<f32> = logits.to_vec1()?;
|
|
||||||
logits_v
|
|
||||||
.iter()
|
|
||||||
.enumerate()
|
|
||||||
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
|
||||||
.map(|(i, _)| i as u32)
|
|
||||||
.unwrap()
|
|
||||||
};
|
|
||||||
tokens.push(next_token);
|
|
||||||
let prob = softmax(&logits, candle::D::Minus1)?
|
|
||||||
.i(next_token as usize)?
|
|
||||||
.to_scalar::<f32>()? as f64;
|
|
||||||
if next_token == self.eot_token || tokens.len() > model.config().max_target_positions {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
sum_logprob += prob.ln();
|
|
||||||
}
|
|
||||||
let text = self.tokenizer.decode(&tokens, true).map_err(E::msg)?;
|
|
||||||
let avg_logprob = sum_logprob / tokens.len() as f64;
|
|
||||||
|
|
||||||
Ok(DecodingResult {
|
|
||||||
tokens,
|
|
||||||
text,
|
|
||||||
avg_logprob,
|
|
||||||
no_speech_prob,
|
|
||||||
temperature: t,
|
|
||||||
compression_ratio: f64::NAN,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
|
|
||||||
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
|
|
||||||
let dr: Result<DecodingResult> = self.decode(segment, t);
|
|
||||||
if i == m::TEMPERATURES.len() - 1 {
|
|
||||||
return dr;
|
|
||||||
}
|
|
||||||
// On errors, we try again with a different temperature.
|
|
||||||
match dr {
|
|
||||||
Ok(dr) => {
|
|
||||||
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|
|
||||||
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
|
|
||||||
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
|
|
||||||
return Ok(dr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
println!("Error running at {t}: {err}")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
unreachable!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn run(&mut self, mel: &Tensor, times: Option<(f64, f64)>) -> Result<Vec<Segment>> {
|
|
||||||
let (_, _, content_frames) = mel.dims3()?;
|
|
||||||
let mut seek = 0;
|
|
||||||
let mut segments = vec![];
|
|
||||||
while seek < content_frames {
|
|
||||||
let start = std::time::Instant::now();
|
|
||||||
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
|
||||||
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
|
|
||||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
|
||||||
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
|
||||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
|
||||||
seek += segment_size;
|
|
||||||
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
|
|
||||||
println!("no speech detected, skipping {seek} {dr:?}");
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let segment = Segment {
|
|
||||||
start: time_offset,
|
|
||||||
duration: segment_duration,
|
|
||||||
dr,
|
|
||||||
};
|
|
||||||
if self.timestamps {
|
|
||||||
println!(
|
|
||||||
"{:.1}s -- {:.1}s",
|
|
||||||
segment.start,
|
|
||||||
segment.start + segment.duration,
|
|
||||||
);
|
|
||||||
let mut tokens_to_decode = vec![];
|
|
||||||
let mut prev_timestamp_s = 0f32;
|
|
||||||
for &token in segment.dr.tokens.iter() {
|
|
||||||
if token == self.sot_token || token == self.eot_token {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// The no_timestamp_token is the last before the timestamp ones.
|
|
||||||
if token > self.no_timestamps_token {
|
|
||||||
let timestamp_s = (token - self.no_timestamps_token + 1) as f32 / 50.;
|
|
||||||
if !tokens_to_decode.is_empty() {
|
|
||||||
let text = self
|
|
||||||
.tokenizer
|
|
||||||
.decode(&tokens_to_decode, true)
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
println!(" {:.1}s-{:.1}s: {}", prev_timestamp_s, timestamp_s, text);
|
|
||||||
tokens_to_decode.clear()
|
|
||||||
}
|
|
||||||
prev_timestamp_s = timestamp_s;
|
|
||||||
} else {
|
|
||||||
tokens_to_decode.push(token)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !tokens_to_decode.is_empty() {
|
|
||||||
let text = self
|
|
||||||
.tokenizer
|
|
||||||
.decode(&tokens_to_decode, true)
|
|
||||||
.map_err(E::msg)?;
|
|
||||||
if !text.is_empty() {
|
|
||||||
println!(" {:.1}s-...: {}", prev_timestamp_s, text);
|
|
||||||
}
|
|
||||||
tokens_to_decode.clear()
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
match times {
|
|
||||||
Some((start, end)) => {
|
|
||||||
println!("{:.1}s -- {:.1}s: {}", start, end, segment.dr.text)
|
|
||||||
}
|
|
||||||
None => {
|
|
||||||
println!(
|
|
||||||
"{:.1}s -- {:.1}s: {}",
|
|
||||||
segment.start,
|
|
||||||
segment.start + segment.duration,
|
|
||||||
segment.dr.text,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if self.verbose {
|
|
||||||
println!("{seek}: {segment:?}, in {:?}", start.elapsed());
|
|
||||||
}
|
|
||||||
segments.push(segment)
|
|
||||||
}
|
|
||||||
Ok(segments)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_language_token(&mut self, language_token: Option<u32>) {
|
|
||||||
self.language_token = language_token;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
|
||||||
fn reset_kv_cache(&mut self) {
|
|
||||||
match &mut self.model {
|
|
||||||
Model::Normal(m) => m.reset_kv_cache(),
|
|
||||||
Model::Quantized(m) => m.reset_kv_cache(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model(&mut self) -> &mut Model {
|
|
||||||
&mut self.model
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
|
|
||||||
match tokenizer.token_to_id(token) {
|
|
||||||
None => candle::bail!("no token-id for {token}"),
|
|
||||||
Some(id) => Ok(id),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
|
||||||
enum Task {
|
|
||||||
Transcribe,
|
|
||||||
Translate,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)]
|
|
||||||
enum WhichModel {
|
|
||||||
Tiny,
|
|
||||||
#[value(name = "tiny.en")]
|
|
||||||
TinyEn,
|
|
||||||
Base,
|
|
||||||
#[value(name = "base.en")]
|
|
||||||
BaseEn,
|
|
||||||
Small,
|
|
||||||
#[value(name = "small.en")]
|
|
||||||
SmallEn,
|
|
||||||
Medium,
|
|
||||||
#[value(name = "medium.en")]
|
|
||||||
MediumEn,
|
|
||||||
Large,
|
|
||||||
LargeV2,
|
|
||||||
LargeV3,
|
|
||||||
#[value(name = "distil-medium.en")]
|
|
||||||
DistilMediumEn,
|
|
||||||
#[value(name = "distil-large-v2")]
|
|
||||||
DistilLargeV2,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl WhichModel {
|
|
||||||
fn is_multilingual(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
Self::Tiny
|
|
||||||
| Self::Base
|
|
||||||
| Self::Small
|
|
||||||
| Self::Medium
|
|
||||||
| Self::Large
|
|
||||||
| Self::LargeV2
|
|
||||||
| Self::LargeV3
|
|
||||||
| Self::DistilLargeV2 => true,
|
|
||||||
Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn model_and_revision(&self) -> (&'static str, &'static str) {
|
|
||||||
match self {
|
|
||||||
Self::Tiny => ("openai/whisper-tiny", "main"),
|
|
||||||
Self::TinyEn => ("openai/whisper-tiny.en", "refs/pr/15"),
|
|
||||||
Self::Base => ("openai/whisper-base", "refs/pr/22"),
|
|
||||||
Self::BaseEn => ("openai/whisper-base.en", "refs/pr/13"),
|
|
||||||
Self::Small => ("openai/whisper-small", "main"),
|
|
||||||
Self::SmallEn => ("openai/whisper-small.en", "refs/pr/10"),
|
|
||||||
Self::Medium => ("openai/whisper-medium", "main"),
|
|
||||||
Self::MediumEn => ("openai/whisper-medium.en", "main"),
|
|
||||||
Self::Large => ("openai/whisper-large", "refs/pr/36"),
|
|
||||||
Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"),
|
|
||||||
Self::LargeV3 => ("openai/whisper-large-v3", "main"),
|
|
||||||
Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"),
|
|
||||||
Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
|
||||||
#[command(author, version, about, long_about = None)]
|
|
||||||
struct Args {
|
|
||||||
/// Run on CPU rather than on GPU.
|
|
||||||
#[arg(long)]
|
|
||||||
cpu: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
model_id: Option<String>,
|
|
||||||
|
|
||||||
/// The model to use, check out available models:
|
|
||||||
/// https://huggingface.co/models?search=whisper
|
|
||||||
#[arg(long)]
|
|
||||||
revision: Option<String>,
|
|
||||||
|
|
||||||
/// The model to be used, can be tiny, small, medium.
|
|
||||||
#[arg(long, default_value = "tiny.en")]
|
|
||||||
model: WhichModel,
|
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
|
||||||
#[arg(long, default_value_t = 299792458)]
|
|
||||||
seed: u64,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
|
||||||
#[arg(long)]
|
|
||||||
tracing: bool,
|
|
||||||
|
|
||||||
#[arg(long)]
|
|
||||||
quantized: bool,
|
|
||||||
|
|
||||||
/// Language.
|
|
||||||
#[arg(long)]
|
|
||||||
language: Option<String>,
|
|
||||||
|
|
||||||
/// Task, when no task is specified, the input tokens contain only the sot token which can
|
|
||||||
/// improve things when in no-timestamp mode.
|
|
||||||
#[arg(long)]
|
|
||||||
task: Option<Task>,
|
|
||||||
|
|
||||||
/// Timestamps mode, this is not fully implemented yet.
|
|
||||||
#[arg(long)]
|
|
||||||
timestamps: bool,
|
|
||||||
|
|
||||||
/// Print the full DecodingResult structure rather than just the text.
|
|
||||||
#[arg(long)]
|
|
||||||
verbose: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn main() -> Result<()> {
|
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
|
||||||
use tracing_subscriber::prelude::*;
|
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
|
||||||
Some(guard)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
|
||||||
let (default_model, default_revision) = if args.quantized {
|
|
||||||
("lmz/candle-whisper", "main")
|
|
||||||
} else {
|
|
||||||
args.model.model_and_revision()
|
|
||||||
};
|
|
||||||
let default_model = default_model.to_string();
|
|
||||||
let default_revision = default_revision.to_string();
|
|
||||||
let (model_id, revision) = match (args.model_id, args.revision) {
|
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
|
||||||
(None, Some(revision)) => (default_model, revision),
|
|
||||||
(None, None) => (default_model, default_revision),
|
|
||||||
};
|
|
||||||
|
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
|
||||||
let api = Api::new()?;
|
|
||||||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
|
|
||||||
let (config, tokenizer, model) = if args.quantized {
|
|
||||||
let ext = match args.model {
|
|
||||||
WhichModel::TinyEn => "tiny-en",
|
|
||||||
WhichModel::Tiny => "tiny",
|
|
||||||
_ => unimplemented!("no quantized support for {:?}", args.model),
|
|
||||||
};
|
|
||||||
(
|
|
||||||
repo.get(&format!("config-{ext}.json"))?,
|
|
||||||
repo.get(&format!("tokenizer-{ext}.json"))?,
|
|
||||||
repo.get(&format!("model-{ext}-q80.gguf"))?,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
let config = repo.get("config.json")?;
|
|
||||||
let tokenizer = repo.get("tokenizer.json")?;
|
|
||||||
let model = repo.get("model.safetensors")?;
|
|
||||||
(config, tokenizer, model)
|
|
||||||
};
|
|
||||||
(config, tokenizer, model)
|
|
||||||
};
|
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
let model = if args.quantized {
|
|
||||||
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(
|
|
||||||
&weights_filename,
|
|
||||||
&device,
|
|
||||||
)?;
|
|
||||||
Model::Quantized(m::quantized_model::Whisper::load(&vb, config.clone())?)
|
|
||||||
} else {
|
|
||||||
let vb =
|
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
|
|
||||||
Model::Normal(m::model::Whisper::load(&vb, config.clone())?)
|
|
||||||
};
|
|
||||||
let language_token = None;
|
|
||||||
let mut dc = Decoder::new(
|
|
||||||
model,
|
|
||||||
tokenizer.clone(),
|
|
||||||
args.seed,
|
|
||||||
&device,
|
|
||||||
language_token,
|
|
||||||
args.task,
|
|
||||||
args.timestamps,
|
|
||||||
args.verbose,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let mel_bytes = match config.num_mel_bins {
|
|
||||||
80 => include_bytes!("../whisper/melfilters.bytes").as_slice(),
|
|
||||||
128 => include_bytes!("../whisper/melfilters128.bytes").as_slice(),
|
|
||||||
nmel => anyhow::bail!("unexpected num_mel_bins {nmel}"),
|
|
||||||
};
|
|
||||||
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
|
||||||
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
|
||||||
|
|
||||||
// Set up the input device and stream with the default input config.
|
|
||||||
let host = cpal::default_host();
|
|
||||||
let _device = "default";
|
|
||||||
let _device = if _device == "default" {
|
|
||||||
host.default_input_device()
|
|
||||||
} else {
|
|
||||||
host.input_devices()?
|
|
||||||
.find(|x| x.name().map(|y| y == _device).unwrap_or(false))
|
|
||||||
}
|
|
||||||
.expect("failed to find input device");
|
|
||||||
|
|
||||||
let _config = _device
|
|
||||||
.default_input_config()
|
|
||||||
.expect("Failed to get default input config");
|
|
||||||
|
|
||||||
let channel_count = _config.channels() as usize;
|
|
||||||
|
|
||||||
let audio_ring_buffer = Arc::new(Mutex::new(Vec::new()));
|
|
||||||
let audio_ring_buffer_2 = audio_ring_buffer.clone();
|
|
||||||
|
|
||||||
std::thread::spawn(move || loop {
|
|
||||||
let data = record_audio(&_device, &_config, 300).unwrap();
|
|
||||||
audio_ring_buffer.lock().unwrap().extend_from_slice(&data);
|
|
||||||
let max_len = data.len() * 16;
|
|
||||||
let data_len = data.len();
|
|
||||||
let len = audio_ring_buffer.lock().unwrap().len();
|
|
||||||
if len > max_len {
|
|
||||||
let mut data = audio_ring_buffer.lock().unwrap();
|
|
||||||
let new_data = data[data_len..].to_vec();
|
|
||||||
*data = new_data;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// loop to process the audio data forever (until the user stops the program)
|
|
||||||
println!("Transcribing audio...");
|
|
||||||
for (i, _) in iter::repeat(()).enumerate() {
|
|
||||||
std::thread::sleep(std::time::Duration::from_millis(1000));
|
|
||||||
let data = audio_ring_buffer_2.lock().unwrap().clone();
|
|
||||||
let pcm_data: Vec<_> = data[..data.len() / channel_count as usize]
|
|
||||||
.iter()
|
|
||||||
.map(|v| *v as f32 / 32768.)
|
|
||||||
.collect();
|
|
||||||
let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
|
|
||||||
let mel_len = mel.len();
|
|
||||||
let mel = Tensor::from_vec(
|
|
||||||
mel,
|
|
||||||
(1, config.num_mel_bins, mel_len / config.num_mel_bins),
|
|
||||||
&device,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// on the first iteration, we detect the language and set the language token.
|
|
||||||
if i == 0 {
|
|
||||||
let language_token = match (args.model.is_multilingual(), args.language.clone()) {
|
|
||||||
(true, None) => Some(multilingual::detect_language(dc.model(), &tokenizer, &mel)?),
|
|
||||||
(false, None) => None,
|
|
||||||
(true, Some(language)) => match token_id(&tokenizer, &format!("<|{language}|>")) {
|
|
||||||
Ok(token_id) => Some(token_id),
|
|
||||||
Err(_) => anyhow::bail!("language {language} is not supported"),
|
|
||||||
},
|
|
||||||
(false, Some(_)) => {
|
|
||||||
anyhow::bail!("a language cannot be set for non-multilingual models")
|
|
||||||
}
|
|
||||||
};
|
|
||||||
println!("language_token: {:?}", language_token);
|
|
||||||
dc.set_language_token(language_token);
|
|
||||||
}
|
|
||||||
dc.run(
|
|
||||||
&mel,
|
|
||||||
Some((
|
|
||||||
i as f64,
|
|
||||||
i as f64 + data.len() as f64 / m::SAMPLE_RATE as f64,
|
|
||||||
)),
|
|
||||||
)?;
|
|
||||||
dc.reset_kv_cache();
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn record_audio(
|
|
||||||
device: &cpal::Device,
|
|
||||||
config: &cpal::SupportedStreamConfig,
|
|
||||||
milliseconds: u64,
|
|
||||||
) -> Result<Vec<i16>> {
|
|
||||||
let writer = Arc::new(Mutex::new(Vec::new()));
|
|
||||||
let writer_2 = writer.clone();
|
|
||||||
let stream = device.build_input_stream(
|
|
||||||
&config.config(),
|
|
||||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
|
||||||
let processed = data
|
|
||||||
.iter()
|
|
||||||
.map(|v| (v * 32768.0) as i16)
|
|
||||||
.collect::<Vec<i16>>();
|
|
||||||
writer_2.lock().unwrap().extend_from_slice(&processed);
|
|
||||||
},
|
|
||||||
move |err| {
|
|
||||||
eprintln!("an error occurred on stream: {}", err);
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
)?;
|
|
||||||
stream.play()?;
|
|
||||||
std::thread::sleep(std::time::Duration::from_millis(milliseconds));
|
|
||||||
drop(stream);
|
|
||||||
let data = writer.lock().unwrap().clone();
|
|
||||||
let step = 3;
|
|
||||||
let data: Vec<i16> = data.iter().step_by(step).copied().collect();
|
|
||||||
Ok(data)
|
|
||||||
}
|
|
@ -1,137 +0,0 @@
-use crate::{token_id, Model};
-use candle::{IndexOp, Result, Tensor, D};
-use candle_transformers::models::whisper::{self as m};
-use tokenizers::Tokenizer;
-
-const LANGUAGES: [(&str, &str); 99] = [
-    ("en", "english"),
-    ("zh", "chinese"),
-    ("de", "german"),
-    ("es", "spanish"),
-    ("ru", "russian"),
-    ("ko", "korean"),
-    ("fr", "french"),
-    ("ja", "japanese"),
-    ("pt", "portuguese"),
-    ("tr", "turkish"),
-    ("pl", "polish"),
-    ("ca", "catalan"),
-    ("nl", "dutch"),
-    ("ar", "arabic"),
-    ("sv", "swedish"),
-    ("it", "italian"),
-    ("id", "indonesian"),
-    ("hi", "hindi"),
-    ("fi", "finnish"),
-    ("vi", "vietnamese"),
-    ("he", "hebrew"),
-    ("uk", "ukrainian"),
-    ("el", "greek"),
-    ("ms", "malay"),
-    ("cs", "czech"),
-    ("ro", "romanian"),
-    ("da", "danish"),
-    ("hu", "hungarian"),
-    ("ta", "tamil"),
-    ("no", "norwegian"),
-    ("th", "thai"),
-    ("ur", "urdu"),
-    ("hr", "croatian"),
-    ("bg", "bulgarian"),
-    ("lt", "lithuanian"),
-    ("la", "latin"),
-    ("mi", "maori"),
-    ("ml", "malayalam"),
-    ("cy", "welsh"),
-    ("sk", "slovak"),
-    ("te", "telugu"),
-    ("fa", "persian"),
-    ("lv", "latvian"),
-    ("bn", "bengali"),
-    ("sr", "serbian"),
-    ("az", "azerbaijani"),
-    ("sl", "slovenian"),
-    ("kn", "kannada"),
-    ("et", "estonian"),
-    ("mk", "macedonian"),
-    ("br", "breton"),
-    ("eu", "basque"),
-    ("is", "icelandic"),
-    ("hy", "armenian"),
-    ("ne", "nepali"),
-    ("mn", "mongolian"),
-    ("bs", "bosnian"),
-    ("kk", "kazakh"),
-    ("sq", "albanian"),
-    ("sw", "swahili"),
-    ("gl", "galician"),
-    ("mr", "marathi"),
-    ("pa", "punjabi"),
-    ("si", "sinhala"),
-    ("km", "khmer"),
-    ("sn", "shona"),
-    ("yo", "yoruba"),
-    ("so", "somali"),
-    ("af", "afrikaans"),
-    ("oc", "occitan"),
-    ("ka", "georgian"),
-    ("be", "belarusian"),
-    ("tg", "tajik"),
-    ("sd", "sindhi"),
-    ("gu", "gujarati"),
-    ("am", "amharic"),
-    ("yi", "yiddish"),
-    ("lo", "lao"),
-    ("uz", "uzbek"),
-    ("fo", "faroese"),
-    ("ht", "haitian creole"),
-    ("ps", "pashto"),
-    ("tk", "turkmen"),
-    ("nn", "nynorsk"),
-    ("mt", "maltese"),
-    ("sa", "sanskrit"),
-    ("lb", "luxembourgish"),
-    ("my", "myanmar"),
-    ("bo", "tibetan"),
-    ("tl", "tagalog"),
-    ("mg", "malagasy"),
-    ("as", "assamese"),
-    ("tt", "tatar"),
-    ("haw", "hawaiian"),
-    ("ln", "lingala"),
-    ("ha", "hausa"),
-    ("ba", "bashkir"),
-    ("jw", "javanese"),
-    ("su", "sundanese"),
-];
-
-/// Returns the token id for the selected language.
-pub fn detect_language(model: &mut Model, tokenizer: &Tokenizer, mel: &Tensor) -> Result<u32> {
-    let (_bsize, _, seq_len) = mel.dims3()?;
-    let mel = mel.narrow(
-        2,
-        0,
-        usize::min(seq_len, model.config().max_source_positions),
-    )?;
-    let device = mel.device();
-    let language_token_ids = LANGUAGES
-        .iter()
-        .map(|(t, _)| token_id(tokenizer, &format!("<|{t}|>")))
-        .collect::<Result<Vec<_>>>()?;
-    let sot_token = token_id(tokenizer, m::SOT_TOKEN)?;
-    let audio_features = model.encoder_forward(&mel, true)?;
-    let tokens = Tensor::new(&[[sot_token]], device)?;
-    let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
-    let ys = model.decoder_forward(&tokens, &audio_features, true)?;
-    let logits = model.decoder_final_linear(&ys.i(..1)?)?.i(0)?.i(0)?;
-    let logits = logits.index_select(&language_token_ids, 0)?;
-    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
-    let probs = probs.to_vec1::<f32>()?;
-    let mut probs = LANGUAGES.iter().zip(probs.iter()).collect::<Vec<_>>();
-    probs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
-    for ((_, language), p) in probs.iter().take(5) {
-        println!("{language}: {p}")
-    }
-    let language = token_id(tokenizer, &format!("<|{}|>", probs[0].0 .0))?;
-    Ok(language)
-}
@ -18,8 +18,6 @@ use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;

mod multilingual;
-mod pcm_decode;

use candle_transformers::models::whisper::{self as m, audio, Config};

pub enum Model {
@ -537,10 +535,17 @@ fn main() -> Result<()> {
    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

-    let (pcm_data, sample_rate) = pcm_decode::pcm_decode(input)?;
-    if sample_rate != m::SAMPLE_RATE as u32 {
-        anyhow::bail!("input file must have a {} sampling rate", m::SAMPLE_RATE)
+    let mut input = std::fs::File::open(input)?;
+    let (header, data) = wav::read(&mut input)?;
+    println!("loaded wav data: {header:?}");
+    if header.sampling_rate != m::SAMPLE_RATE as u32 {
+        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
    }
+    let data = data.as_sixteen().expect("expected 16 bit wav file");
+    let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
+        .iter()
+        .map(|v| *v as f32 / 32768.)
+        .collect();
    println!("pcm data loaded {}", pcm_data.len());
    let mel = audio::pcm_to_mel(&config, &pcm_data, &mel_filters);
    let mel_len = mel.len();
@ -1,74 +0,0 @@
-use symphonia::core::audio::{AudioBufferRef, Signal};
-use symphonia::core::codecs::{DecoderOptions, CODEC_TYPE_NULL};
-use symphonia::core::conv::FromSample;
-
-fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
-where
-    T: symphonia::core::sample::Sample,
-    f32: symphonia::core::conv::FromSample<T>,
-{
-    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
-}
-
-pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
-    // Open the media source.
-    let src = std::fs::File::open(path)?;
-
-    // Create the media source stream.
-    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
-
-    // Create a probe hint using the file's extension. [Optional]
-    let hint = symphonia::core::probe::Hint::new();
-
-    // Use the default options for metadata and format readers.
-    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
-    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
-
-    // Probe the media source.
-    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
-    // Get the instantiated format reader.
-    let mut format = probed.format;
-
-    // Find the first audio track with a known (decodeable) codec.
-    let track = format
-        .tracks()
-        .iter()
-        .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
-        .expect("no supported audio tracks");
-
-    // Use the default options for the decoder.
-    let dec_opts: DecoderOptions = Default::default();
-
-    // Create a decoder for the track.
-    let mut decoder = symphonia::default::get_codecs()
-        .make(&track.codec_params, &dec_opts)
-        .expect("unsupported codec");
-    let track_id = track.id;
-    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
-    let mut pcm_data = Vec::new();
-    // The decode loop.
-    while let Ok(packet) = format.next_packet() {
-        // Consume any new metadata that has been read since the last packet.
-        while !format.metadata().is_latest() {
-            format.metadata().pop();
-        }
-
-        // If the packet does not belong to the selected track, skip over it.
-        if packet.track_id() != track_id {
-            continue;
-        }
-        match decoder.decode(&packet)? {
-            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
-            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
-            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
-        }
-    }
-    Ok((pcm_data, sample_rate))
-}
@ -104,7 +104,6 @@ impl TextGeneration {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
-                let t = t.replace("<|im_end|>", "\n");
                print!("{t}");
                std::io::stdout().flush()?;
            }
@ -216,7 +216,7 @@ fn detect(
    xs: &Tensor,
    image_height: usize,
    classes: usize,
-    anchors: &[(usize, usize)],
+    anchors: &Vec<(usize, usize)>,
) -> Result<Tensor> {
    let (bsize, _channels, height, _width) = xs.dims4()?;
    let stride = image_height / height;
@ -40,7 +40,7 @@ impl TokenOutputStream {
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
-        if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
+        if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
            let text = text.split_at(prev_text.len());
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
-version = "0.4.0"
+version = "0.3.3"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
-candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
+candle = { path = "../candle-core", features = ["cuda"], package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
-version = "0.4.0"
+version = "0.3.3"
edition = "2021"

description = "CUDA kernels for Candle"
@ -71,6 +71,7 @@ __device__ void im2col1d(
  }
  const size_t *src_dims = info;
  const size_t *src_s = info + 3;
+  const size_t b_in = src_dims[0];
  const size_t c_in = src_dims[1];
  const size_t l_in = src_dims[2];

@ -119,6 +120,7 @@ __device__ void im2col(
  }
  const size_t *src_dims = info;
  const size_t *src_s = info + 4;
+  const size_t b_in = src_dims[0];
  const size_t c_in = src_dims[1];
  const size_t h_in = src_dims[2];
  const size_t w_in = src_dims[3];
@ -223,60 +225,6 @@ __device__ void conv2d(
  dst[dst_i] = static_cast<T>(d);
}

-// Naive implementation of conv_transpose1d.
-template <typename T, typename A>
-__device__ void conv_transpose1d(
-  const size_t src_numel,
-  const size_t l_out,
-  const size_t stride,
-  const size_t padding,
-  const size_t out_padding,
-  const size_t dilation,
-  const size_t *info,
-  const T *src,
-  const T *kernel,
-  T *dst
-) {
-  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
-  // src: (b_size, c_in, l_in)
-  // k: (c_in, c_out, l_k)
-  const size_t *src_dims = info;
-  const size_t *src_s = info + 3;
-  const size_t *k_dims = info + 6;
-  const size_t *k_s = info + 9;
-  const size_t l_k = k_dims[2];
-  const size_t c_out = k_dims[1];
-  const size_t c_in = src_dims[1];
-  const size_t l_in = src_dims[2];
-  if (dst_i >= src_dims[0] * c_out * l_out) {
-    return;
-  }
-
-  // TODO
-  const size_t b_idx = dst_i / (l_out * c_out);
-  const size_t dst_c_idx = (dst_i / l_out) % c_out;
-  // NCL layout.
-  const size_t out_x = dst_i % l_out;
-
-  const size_t src_idx0 = b_idx * src_s[0];
-  A d = 0;
-  for (int k_x = 0; k_x < (int)l_k; ++k_x) {
-    // let out_x = inp_x * p.stride + k_x * p.dilation - p.padding;
-    int inp_x_stride = (int)(out_x + padding) - k_x * dilation;
-    if (inp_x_stride < 0 || inp_x_stride % stride) {
-      continue;
-    }
-    int inp_x = inp_x_stride / stride;
-    if (inp_x >= l_in) continue;
-    for (size_t src_c_idx = 0; src_c_idx < c_in; ++src_c_idx) {
-      const size_t src_idx = src_idx0 + src_c_idx * src_s[1] + inp_x * src_s[2];
-      const size_t k_idx = src_c_idx * k_s[0] + dst_c_idx * k_s[1] + k_x * k_s[2];
-      d += static_cast<A>(src[src_idx]) * static_cast<A>(kernel[k_idx]);
-    }
-  }
-  dst[dst_i] = static_cast<T>(d);
-}
-
// Naive implementation of conv_transpose2d.
template <typename T, typename A>
__device__ void conv_transpose2d(
@ -559,22 +507,6 @@ extern "C" __global__ void FN_NAME( \
  im2col<TYPENAME>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, info, src, dst); \
} \

-#define CONVT1D_OP(TYPENAME, TYPEACC, FN_NAME) \
-extern "C" __global__ void FN_NAME( \
-    const size_t src_numel, \
-    const size_t l_out, \
-    const size_t stride, \
-    const size_t padding, \
-    const size_t out_padding, \
-    const size_t dilation, \
-    const size_t *info, \
-    const TYPENAME *src, \
-    const TYPENAME *kernel, \
-    TYPENAME *dst \
-) { \
-  conv_transpose1d<TYPENAME, TYPEACC>(src_numel, l_out, stride, padding, out_padding, dilation, info, src, kernel, dst); \
-} \
-
#define CONVT2D_OP(TYPENAME, TYPEACC, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t src_numel, \
@ -636,7 +568,6 @@ extern "C" __global__ void FN_NAME( \
#if __CUDA_ARCH__ >= 800
CONV1D_OP(__nv_bfloat16, float, conv1d_bf16)
CONV2D_OP(__nv_bfloat16, float, conv2d_bf16)
-CONVT1D_OP(__nv_bfloat16, float, conv_transpose1d_bf16)
CONVT2D_OP(__nv_bfloat16, float, conv_transpose2d_bf16)
AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
@ -648,7 +579,6 @@ IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#if __CUDA_ARCH__ >= 530
CONV1D_OP(__half, float, conv1d_f16)
CONV2D_OP(__half, float, conv2d_f16)
-CONVT1D_OP(__half, float, conv_transpose1d_f16)
CONVT2D_OP(__half, float, conv_transpose2d_f16)
AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
@ -667,11 +597,6 @@ CONV2D_OP(double, double, conv2d_f64)
CONV2D_OP(uint8_t, uint8_t, conv2d_u8)
CONV2D_OP(uint32_t, uint32_t, conv2d_u32)

-CONVT1D_OP(float, float, conv_transpose1d_f32)
-CONVT1D_OP(double, double, conv_transpose1d_f64)
-CONVT1D_OP(uint8_t, uint8_t, conv_transpose1d_u8)
-CONVT1D_OP(uint32_t, uint32_t, conv_transpose1d_u32)
-
CONVT2D_OP(float, float, conv_transpose2d_f32)
CONVT2D_OP(double, double, conv_transpose2d_f64)
CONVT2D_OP(uint8_t, uint8_t, conv_transpose2d_u8)
@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
-version = "0.4.0"
+version = "0.3.3"
edition = "2021"

description = "Metal kernels for Candle"
@ -1,2 +0,0 @@
-xcrun metal -c src/gemm/kernels/steel_gemm.metal -I src/
-xcrun metallib steel_gemm.air -o src/gemm/steel_gemm.metallib
@ -73,7 +73,7 @@ BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \
BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided);

#define INT64_BINARY_OP_OUT(NAME, FN) \
-BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided);
+BINARY(FN, int64_t, int8_t, NAME##_i64, NAME##_i64_strided);

BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
|
@ -1,317 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include <metal_stdlib>
using namespace metal;
// When __HAVE_BFLOAT__ is defined this header simply typedefs the native
// bfloat type to bfloat16_t. Otherwise it provides a software fallback,
// struct _MLX_BFloat16, holding the raw uint16_t bits:
//   - float_to_bfloat_bits / bfloat_bits_to_float helpers (NaN check, then
//     round to nearest even by adding ((bits >> 16) & 1) + 0x7FFF before
//     keeping the upper 16 bits),
//   - address-space qualified constructors and implicit conversions guarded
//     by can_convert_to_bfloat / can_convert_from_bfloat,
//   - unary minus plus arithmetic, comparison and in-place operators generated
//     by the bfloat_binop / bfloat_compop / bfloat_inplace_op macros (mixed
//     with float, half and the integer types, all routed through float),
//   - a metal::_numeric_limits_impl<bfloat16_t> specialization (min, lowest,
//     max, epsilon, round_error, infinity, quiet_NaN, signaling_NaN,
//     denorm_min) and an isnan overload,
//   - typedef struct _MLX_BFloat16 bfloat16_t;
#include "gemm/bf16_math.h"
@ -1,394 +0,0 @@
// Copyright © 2023 Apple Inc.
#pragma once
#include "gemm/bf16.h"
// Metal math for bfloat16. Per the Metal Shading Language Specification
// (Metal 3.1), bfloat is an extended floating point type that only allows
// implicit conversion to a type of greater floating point rank: it converts
// to float but not to half, and neither float nor half converts implicitly to
// bfloat. The stdlib math/simd functions are not defined for bfloat, so a
// call promotes the argument to float and returns a float that cannot be
// assigned back to a bfloat without an explicit cast:
//   bfloat a = 5.0bf;
//   bfloat b = metal::abs(a); // error: abs returns float
//   bfloat c = static_cast<bfloat>(metal::abs(a)); // fine
// The header therefore instantiates overloads that cast through float:
//   - instantiate_metal_math_funcs(...) defines abs, acos, acosh, asin, asinh,
//     atan, atan2, atanh, ceil, cos, cosh, cospi, divide, exp, exp10, exp2,
//     fabs, fdim, floor, fma, fmax, fmax3, fmedian3, fmin, fmin3, fmod, fract,
//     frexp, ldexp, log, log10, log2, max, max3, median3, min, min3,
//     nextafter, pow, powr, rint, round, rsqrt, sin, sinh, sinpi, sqrt, tan,
//     tanh, tanpi and trunc for bfloat16_t in the metal::, metal::fast:: and
//     metal::precise:: namespaces (via __METAL_MAYBE_FAST_MATH__,
//     __METAL_FAST_MATH__ and __METAL_PRECISE_MATH__),
//   - instantiate_metal_simd_comm_funcs(...) and
//     instantiate_metal_simd_reduction_funcs(...) define simd_broadcast, the
//     simd_shuffle / shuffle_up / shuffle_down / rotate / fill / xor variants
//     and simd_max, simd_min, simd_sum, simd_product plus the prefix scans,
//     reinterpreting through uint16_t when __HAVE_BFLOAT__ is available and
//     through the _MLX_BFloat16 bit pattern otherwise.
@ -1,131 +0,0 @@
// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

struct complex64_t;

template <typename T>
static constexpr constant bool can_convert_to_complex64 =
    !is_same_v<T, complex64_t> && is_convertible_v<T, float>;

template <typename T>
static constexpr constant bool can_convert_from_complex64 =
    !is_same_v<T, complex64_t> &&
    (is_convertible_v<float, T> || is_convertible_v<bfloat16_t, T>);

struct complex64_t {
  float real;
  float imag;

  // Constructors
  constexpr complex64_t(float real, float imag) : real(real), imag(imag){};

  // Conversions to complex64_t
  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) thread : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) threadgroup : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) device : real(x), imag(0) {}

  template <
      typename T,
      typename = typename enable_if<can_convert_to_complex64<T>>::type>
  constexpr complex64_t(T x) constant : real(x), imag(0) {}

  // Conversions from complex64_t
  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const thread {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const threadgroup {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const device {
    return static_cast<T>(real);
  }

  template <
      typename T,
      typename = typename enable_if<can_convert_from_complex64<T>>::type>
  constexpr operator T() const constant {
    return static_cast<T>(real);
  }
};

constexpr complex64_t operator-(complex64_t x) {
  return {-x.real, -x.imag};
}

constexpr bool operator>=(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag);
}

constexpr bool operator>(complex64_t a, complex64_t b) {
  return (a.real > b.real) || (a.real == b.real && a.imag > b.imag);
}

constexpr bool operator<=(complex64_t a, complex64_t b) {
  return operator>=(b, a);
}

constexpr bool operator<(complex64_t a, complex64_t b) {
  return operator>(b, a);
}

constexpr bool operator==(complex64_t a, complex64_t b) {
  return a.real == b.real && a.imag == b.imag;
}

constexpr complex64_t operator+(complex64_t a, complex64_t b) {
  return {a.real + b.real, a.imag + b.imag};
}

constexpr complex64_t operator-(complex64_t a, complex64_t b) {
  return {a.real - b.real, a.imag - b.imag};
}

constexpr complex64_t operator*(complex64_t a, complex64_t b) {
  return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}

constexpr complex64_t operator/(complex64_t a, complex64_t b) {
  auto denom = b.real * b.real + b.imag * b.imag;
  auto x = a.real * b.real + a.imag * b.imag;
  auto y = a.imag * b.real - a.real * b.imag;
  return {x / denom, y / denom};
}

constexpr complex64_t operator%(complex64_t a, complex64_t b) {
  auto real = a.real - (b.real * static_cast<int64_t>(a.real / b.real));
  auto imag = a.imag - (b.imag * static_cast<int64_t>(a.imag / b.imag));
  if (real != 0 && (real < 0 != b.real < 0)) {
    real += b.real;
  }
  if (imag != 0 && (imag < 0 != b.imag < 0)) {
    imag += b.imag;
  }
  return {real, imag};
}
@ -1,292 +0,0 @@
// Copyright © 2024 Apple Inc.
#pragma once
#include "gemm/loader.h"
#include "gemm/mma.h"
#include "gemm/transforms.h"
#include "utils.h"
// Steel GEMM kernel class (namespace mlx::steel): a LoopAlignment<M, N, K>
// tag plus GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b,
// MN_aligned, K_aligned, AccumType, Epilogue>, which
//   - sizes the padded threadgroup tiles (tgp_padding_a/b, tgp_mem_size_a/b,
//     tgp_size = WM * WN * 32) and wires up the BlockLoader types loader_a_t /
//     loader_b_t and the BlockMMA type mma_t for A and B,
//   - gemm_loop(): for gemm_k_iterations, loads tiles with load_unsafe or
//     load_safe depending on M/N alignment, barriers, runs mma(As, Bs),
//     advances the loaders, and finishes with a load_safe tail pass over the
//     leftover K block when K is not aligned,
//   - run(): swizzles the threadgroup id with params->swizzle_log, returns
//     early for out-of-range tiles, offsets A, B, C to the current block, then
//     either runs the fully aligned MNK loop or dispatches gemm_loop with the
//     appropriate LoopAlignment for partial edge tiles, and writes the result
//     with mma_op.store_result / store_result_safe.
@ -1,5 +0,0 @@
// Copyright © 2024 Apple Inc.

#pragma once

#include "params.h"
@ -1,89 +0,0 @@
// Copyright © 2024 Apple Inc.

#include "gemm/bf16.h"
#include "gemm/gemm.h"

using namespace metal;
using namespace mlx::steel;

///////////////////////////////////////////////////////////////////////////////
// GEMM kernels
///////////////////////////////////////////////////////////////////////////////

template <typename T,
          int BM,
          int BN,
          int BK,
          int WM,
          int WN,
          bool transpose_a,
          bool transpose_b,
          bool MN_aligned,
          bool K_aligned>
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm(
    const device T *A [[buffer(0)]],
    const device T *B [[buffer(1)]],
    device T *C [[buffer(2)]],
    const constant GEMMParams* params [[buffer(3)]],
    uint simd_lane_id [[thread_index_in_simdgroup]],
    uint simd_group_id [[simdgroup_index_in_threadgroup]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint3 lid [[thread_position_in_threadgroup]]) {

  using gemm_kernel = GEMMKernel<T, T, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;

  threadgroup T As[gemm_kernel::tgp_mem_size_a];
  threadgroup T Bs[gemm_kernel::tgp_mem_size_b];

  // Adjust for batch
  A += params->batch_stride_a * tid.z;
  B += params->batch_stride_b * tid.z;
  C += params->batch_stride_c * tid.z;

  gemm_kernel::run(
    A, B, C,
    params,
    As, Bs,
    simd_lane_id, simd_group_id, tid, lid
  );
}

///////////////////////////////////////////////////////////////////////////////
// GEMM kernel initializations
///////////////////////////////////////////////////////////////////////////////

#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
  template [[host_name("steel_gemm_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
  [[kernel]] void gemm<itype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
      const device itype *A [[buffer(0)]], \
      const device itype *B [[buffer(1)]], \
      device itype *C [[buffer(2)]], \
      const constant GEMMParams* params [[buffer(3)]], \
      uint simd_lane_id [[thread_index_in_simdgroup]], \
      uint simd_group_id [[simdgroup_index_in_threadgroup]], \
      uint3 tid [[threadgroup_position_in_grid]], \
      uint3 lid [[thread_position_in_threadgroup]]);

#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
  instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)

#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
  instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
  instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

instantiate_gemm_shapes_helper(float32, float, float32, float);
@ -1,254 +0,0 @@
// Copyright © 2024 Apple Inc.
#include "mlx/backend/metal/kernels/bf16.h"
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
using namespace metal;
using namespace mlx::steel;
// addmm kernel: templated like gemm but with an extra C input, a D output,
// GEMMAddMMParams and an Epilogue type (TransformAdd by default, or
// TransformAxpby) constructed from params->alpha and params->beta. The body
// offsets A, B, C, D for the batch and for the swizzled tile, bails out on
// out-of-range tiles, builds the loaders and mma_op from GEMMKernel's
// loader_a_t / loader_b_t / mma_t, runs the fully aligned MNK loop (with a
// load_safe tail when K is unaligned) or gemm_kernel::gemm_loop with the
// matching LoopAlignment for partial edge tiles, then stores through
// mma_op.store_result / store_result_safe against C using ldc and fdc with
// the epilogue applied.
// The instantiate_gemm / instantiate_gemm_bias_helper /
// instantiate_gemm_aligned_helper / instantiate_gemm_transpose_helper /
// instantiate_gemm_shapes_helper macros then emit host_name("steel_addmm_...")
// kernels for the add and axpby epilogues, all four transpose combinations,
// the (MN, K) alignment variants and tile shapes 32x32x16, 64x64x16, 64x32x32,
// 64x32x16 and 32x64x16, for float16, bfloat16 and float32.
@ -1,280 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#include "mlx/backend/metal/kernels/bf16.h"
|
|
||||||
#include "mlx/backend/metal/kernels/steel/gemm/gemm.h"
|
|
||||||
|
|
||||||
using namespace metal;
|
|
||||||
using namespace mlx::steel;
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernels
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename T,
|
|
||||||
typename U,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
bool MN_aligned,
|
|
||||||
bool K_aligned>
|
|
||||||
[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk(
|
|
||||||
const device T *A [[buffer(0)]],
|
|
||||||
const device T *B [[buffer(1)]],
|
|
||||||
device U *C [[buffer(2)]],
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]],
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]],
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]) {
|
|
||||||
|
|
||||||
(void)lid;
|
|
||||||
|
|
||||||
using gemm_kernel = GEMMKernel<T, U, BM, BN, BK, WM, WN, transpose_a, transpose_b, MN_aligned, K_aligned>;
|
|
||||||
using loader_a_t = typename gemm_kernel::loader_a_t;
|
|
||||||
using loader_b_t = typename gemm_kernel::loader_b_t;
|
|
||||||
using mma_t = typename gemm_kernel::mma_t;
|
|
||||||
|
|
||||||
threadgroup T As[gemm_kernel::tgp_mem_size_a];
|
|
||||||
threadgroup T Bs[gemm_kernel::tgp_mem_size_b];
|
|
||||||
|
|
||||||
const int tid_x = tid.x;
|
|
||||||
const int tid_y = tid.y;
|
|
||||||
const int tid_z = tid.z;
|
|
||||||
|
|
||||||
if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find block in A, B, C
|
|
||||||
const int c_row = tid_y * BM;
|
|
||||||
const int c_col = tid_x * BN;
|
|
||||||
const int k_start = params->split_k_partition_size * tid_z;
|
|
||||||
|
|
||||||
A += transpose_a ? (c_row + k_start * params->lda) : (k_start + c_row * params->lda);
|
|
||||||
B += transpose_b ? (k_start + c_col * params->ldb) : (c_col + k_start * params->ldb);
|
|
||||||
C += (params->split_k_partition_stride * tid_z) + (c_row * params->ldc + c_col);
|
|
||||||
|
|
||||||
// Prepare threadgroup loading operations
|
|
||||||
thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id);
|
|
||||||
thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
// Prepare threadgroup mma operation
|
|
||||||
thread mma_t mma_op(simd_group_id, simd_lane_id);
|
|
||||||
|
|
||||||
int gemm_k_iterations = params->gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
short tgp_bm = min(BM, params->M - c_row);
|
|
||||||
short tgp_bn = min(BN, params->N - c_col);
|
|
||||||
short leftover_bk = params->K % BK;
|
|
||||||
|
|
||||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, true, true>{});
|
|
||||||
} else if (tgp_bn == BN) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, true, true>{});
|
|
||||||
} else if (tgp_bm == BM) {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<true, false, true>{});
|
|
||||||
} else {
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iterations,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, true>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
||||||
|
|
||||||
if ((tid_z + 1) == (params->split_k_partitions)) {
|
|
||||||
int gemm_k_iter_remaining = (params->K - (k_start + params->split_k_partition_size)) / BK;
|
|
||||||
if(!K_aligned || gemm_k_iter_remaining > 0)
|
|
||||||
gemm_kernel::gemm_loop(
|
|
||||||
As,
|
|
||||||
Bs,
|
|
||||||
gemm_k_iter_remaining,
|
|
||||||
loader_a,
|
|
||||||
loader_b,
|
|
||||||
mma_op,
|
|
||||||
tgp_bm,
|
|
||||||
tgp_bn,
|
|
||||||
leftover_bk,
|
|
||||||
LoopAlignment<false, false, K_aligned>{});
|
|
||||||
}
|
|
||||||
|
|
||||||
if(MN_aligned || (tgp_bm == BM && tgp_bn == BN)) {
|
|
||||||
mma_op.store_result(C, params->ldc);
|
|
||||||
} else {
|
|
||||||
mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM kernel initializations
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
#define instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, aname, mn_aligned, kname, k_aligned) \
|
|
||||||
template [[host_name("steel_gemm_splitk_" #tname "_" #iname "_" #oname "_bm" #bm "_bn" #bn "_bk" #bk "_wm" #wm "_wn" #wn "_MN_" #aname "_K_" #kname)]] \
|
|
||||||
[[kernel]] void gemm_splitk<itype, otype, bm, bn, bk, wm, wn, trans_a, trans_b, mn_aligned, k_aligned>( \
|
|
||||||
const device itype *A [[buffer(0)]], \
|
|
||||||
const device itype *B [[buffer(1)]], \
|
|
||||||
device otype *C [[buffer(2)]], \
|
|
||||||
const constant GEMMSpiltKParams* params [[buffer(3)]], \
|
|
||||||
uint simd_lane_id [[thread_index_in_simdgroup]], \
|
|
||||||
uint simd_group_id [[simdgroup_index_in_threadgroup]], \
|
|
||||||
uint3 tid [[threadgroup_position_in_grid]], \
|
|
||||||
uint3 lid [[thread_position_in_threadgroup]]);
|
|
||||||
|
|
||||||
#define instantiate_gemm_aligned_helper(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, taligned, true) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, taligned, true, naligned, false) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, taligned, true) \
|
|
||||||
instantiate_gemm(tname, trans_a, trans_b, iname, itype, oname, otype, bm, bn, bk, wm, wn, naligned, false, naligned, false)
|
|
||||||
|
|
||||||
#define instantiate_gemm_transpose_helper(iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nn, false, false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(nt, false, true , iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tn, true , false, iname, itype, oname, otype, bm, bn, bk, wm, wn) \
|
|
||||||
instantiate_gemm_aligned_helper(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
|
|
||||||
|
|
||||||
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 16, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 16, 32, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 16, 16, 2, 2) \
|
|
||||||
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float16, half, float32, float);
|
|
||||||
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, float32, float);
|
|
||||||
|
|
||||||
instantiate_gemm_shapes_helper(float32, float, float32, float);
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Split k accumulation kernel
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename AccT,
|
|
||||||
typename OutT,
|
|
||||||
typename Epilogue = TransformNone<OutT, AccT>>
|
|
||||||
[[kernel]] void gemm_splitk_accum(
|
|
||||||
const device AccT *C_split [[buffer(0)]],
|
|
||||||
device OutT *D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
// Adjust D and C_split
|
|
||||||
D += gid.x + gid.y * ldd;
|
|
||||||
C_split += gid.x + gid.y * ldd;
|
|
||||||
|
|
||||||
int offset = 0;
|
|
||||||
AccT out = 0;
|
|
||||||
|
|
||||||
for(int i = 0; i < k_partitions; i++) {
|
|
||||||
out += C_split[offset];
|
|
||||||
offset += partition_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
D[0] = Epilogue::apply(out);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename AccT,
|
|
||||||
typename OutT,
|
|
||||||
typename Epilogue = TransformAxpby<OutT, AccT>>
|
|
||||||
[[kernel]] void gemm_splitk_accum_axpby(
|
|
||||||
const device AccT *C_split [[buffer(0)]],
|
|
||||||
device OutT *D [[buffer(1)]],
|
|
||||||
const constant int& k_partitions [[buffer(2)]],
|
|
||||||
const constant int& partition_stride [[buffer(3)]],
|
|
||||||
const constant int& ldd [[buffer(4)]],
|
|
||||||
const device OutT *C [[buffer(5)]],
|
|
||||||
const constant int& ldc [[buffer(6)]],
|
|
||||||
const constant int& fdc [[buffer(7)]],
|
|
||||||
const constant float& alpha [[buffer(8)]],
|
|
||||||
const constant float& beta [[buffer(9)]],
|
|
||||||
uint2 gid [[thread_position_in_grid]]) {
|
|
||||||
|
|
||||||
// Adjust C, D and C_split
|
|
||||||
C += gid.x * fdc + gid.y * ldc;
|
|
||||||
D += gid.x + gid.y * ldd;
|
|
||||||
C_split += gid.x + gid.y * ldd;
|
|
||||||
|
|
||||||
int offset = 0;
|
|
||||||
AccT out = 0;
|
|
||||||
|
|
||||||
for(int i = 0; i < k_partitions; i++) {
|
|
||||||
out += C_split[offset];
|
|
||||||
offset += partition_stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write output
|
|
||||||
Epilogue op(alpha, beta);
|
|
||||||
D[0] = op.apply(out, *C);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
#define instantiate_accum(oname, otype, aname, atype) \
|
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname)]] \
|
|
||||||
[[kernel]] void gemm_splitk_accum<atype, otype>( \
|
|
||||||
const device atype *C_split [[buffer(0)]], \
|
|
||||||
device otype *D [[buffer(1)]], \
|
|
||||||
const constant int& k_partitions [[buffer(2)]], \
|
|
||||||
const constant int& partition_stride [[buffer(3)]], \
|
|
||||||
const constant int& ldd [[buffer(4)]], \
|
|
||||||
uint2 gid [[thread_position_in_grid]]); \
|
|
||||||
template [[host_name("steel_gemm_splitk_accum_" #oname "_" #aname "_axpby")]] \
|
|
||||||
[[kernel]] void gemm_splitk_accum_axpby<atype, otype>( \
|
|
||||||
const device atype *C_split [[buffer(0)]], \
|
|
||||||
device otype *D [[buffer(1)]], \
|
|
||||||
const constant int& k_partitions [[buffer(2)]], \
|
|
||||||
const constant int& partition_stride [[buffer(3)]], \
|
|
||||||
const constant int& ldd [[buffer(4)]], \
|
|
||||||
const device otype *C [[buffer(5)]], \
|
|
||||||
const constant int& ldc [[buffer(6)]], \
|
|
||||||
const constant int& fdc [[buffer(7)]], \
|
|
||||||
const constant float& alpha [[buffer(8)]], \
|
|
||||||
const constant float& beta [[buffer(9)]], \
|
|
||||||
uint2 gid [[thread_position_in_grid]]);
|
|
||||||
|
|
||||||
instantiate_accum(bfloat16, bfloat16_t, float32, float);
|
|
||||||
instantiate_accum(float16, half, float32, float);
|
|
||||||
instantiate_accum(float32, float, float32, float);
|
|
@ -1,125 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "utils2.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Loading helper
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
short BROWS,
|
|
||||||
short BCOLS,
|
|
||||||
short dst_ld,
|
|
||||||
short reduction_dim,
|
|
||||||
short tgp_size,
|
|
||||||
short alignment = 1,
|
|
||||||
short n_reads = (BCOLS * BROWS) / (tgp_size),
|
|
||||||
short TCOLS = BCOLS / n_reads,
|
|
||||||
short TROWS = tgp_size / TCOLS>
|
|
||||||
struct BlockLoader {
|
|
||||||
STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS;
|
|
||||||
STEEL_CONST short vec_size = n_reads;
|
|
||||||
|
|
||||||
// Leading dimension for src
|
|
||||||
const int src_ld;
|
|
||||||
const int tile_stride;
|
|
||||||
|
|
||||||
// Thread location indices
|
|
||||||
const short thread_idx;
|
|
||||||
const short bi;
|
|
||||||
const short bj;
|
|
||||||
|
|
||||||
// threadgroup and device memory
|
|
||||||
threadgroup T* dst;
|
|
||||||
const device T* src;
|
|
||||||
|
|
||||||
struct alignas(alignment * sizeof(T)) ReadVector {
|
|
||||||
uint8_t v[sizeof(T) * vec_size];
|
|
||||||
};
|
|
||||||
|
|
||||||
/* Constructor */
|
|
||||||
METAL_FUNC BlockLoader(
|
|
||||||
const device T* src_,
|
|
||||||
const int src_ld_,
|
|
||||||
threadgroup T* dst_,
|
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
||||||
: src_ld(src_ld_),
|
|
||||||
tile_stride(reduction_dim ? BCOLS : BROWS * src_ld),
|
|
||||||
thread_idx(simd_group_id * 32 + simd_lane_id),
|
|
||||||
bi(thread_idx / TCOLS),
|
|
||||||
bj(vec_size * (thread_idx % TCOLS)),
|
|
||||||
dst(dst_ + bi * dst_ld + bj),
|
|
||||||
src(src_ + bi * src_ld + bj) {}
|
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - without bound checking */
|
|
||||||
METAL_FUNC void load_unsafe() const {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < BROWS; i += TROWS) {
|
|
||||||
*((threadgroup ReadVector*)(&dst[i * dst_ld])) =
|
|
||||||
*((const device ReadVector*)(&src[i * src_ld]));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Load from device memory into threadgroup memory - with bound checking */
|
|
||||||
METAL_FUNC void load_safe(short2 src_tile_dim) const {
|
|
||||||
src_tile_dim = src_tile_dim - short2(bj, bi);
|
|
||||||
|
|
||||||
// Skip loading if thread has no valid reads
|
|
||||||
if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < BROWS; i += TROWS) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
dst[i * dst_ld + j] = T(0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use fast thread memory for bound checks
|
|
||||||
bool tmp_idx[vec_size];
|
|
||||||
T tmp_val[vec_size];
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < BROWS; i += TROWS) {
|
|
||||||
// Make sure tmp_idx only contains valid indices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read valid indices into tmp_val
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Zero out unneeded values
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Copy values to threadgroup memory
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < vec_size; j++) {
|
|
||||||
dst[i * dst_ld + j] = tmp_val[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Iteration helper */
|
|
||||||
METAL_FUNC void next() {
|
|
||||||
src += tile_stride;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,264 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "gemm/transforms.h"
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// MMA helper
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <
|
|
||||||
typename T,
|
|
||||||
typename U,
|
|
||||||
int BM,
|
|
||||||
int BN,
|
|
||||||
int BK,
|
|
||||||
int WM,
|
|
||||||
int WN,
|
|
||||||
bool transpose_a,
|
|
||||||
bool transpose_b,
|
|
||||||
short lda_tgp,
|
|
||||||
short ldb_tgp,
|
|
||||||
typename AccumType = float,
|
|
||||||
typename Epilogue = TransformNone<U, AccumType>>
|
|
||||||
struct BlockMMA {
|
|
||||||
// Warp tile simdgroup matrix strides along M
|
|
||||||
STEEL_CONST short TM_stride = 8 * WM;
|
|
||||||
// Warp tile simdgroup matrix strides along N
|
|
||||||
STEEL_CONST short TN_stride = 8 * WN;
|
|
||||||
|
|
||||||
// Warp tile size along M
|
|
||||||
STEEL_CONST short TM = BM / TM_stride;
|
|
||||||
// Warp tile size along N
|
|
||||||
STEEL_CONST short TN = BN / TN_stride;
|
|
||||||
|
|
||||||
// Strides of A, B along reduction axis
|
|
||||||
STEEL_CONST short simd_stride_a = {
|
|
||||||
transpose_a ? TM_stride : TM_stride * lda_tgp};
|
|
||||||
STEEL_CONST short simd_stride_b = {
|
|
||||||
transpose_b ? TN_stride * ldb_tgp : TN_stride};
|
|
||||||
|
|
||||||
// Jump between elements
|
|
||||||
STEEL_CONST short jump_a = {transpose_a ? lda_tgp : 1};
|
|
||||||
STEEL_CONST short jump_b = {transpose_b ? ldb_tgp : 1};
|
|
||||||
|
|
||||||
STEEL_CONST short tile_stride_a = {transpose_a ? 8 * lda_tgp : 8};
|
|
||||||
STEEL_CONST short tile_stride_b = {transpose_b ? 8 : 8 * ldb_tgp};
|
|
||||||
|
|
||||||
// Simdgroup matrices
|
|
||||||
simdgroup_matrix<AccumType, 8, 8> Asimd[TM];
|
|
||||||
simdgroup_matrix<AccumType, 8, 8> Bsimd[TN];
|
|
||||||
simdgroup_matrix<AccumType, 8, 8> results[TM * TN] = {
|
|
||||||
simdgroup_matrix<AccumType, 8, 8>(0)};
|
|
||||||
|
|
||||||
// Offsets within threadgroup
|
|
||||||
const short tm;
|
|
||||||
const short tn;
|
|
||||||
|
|
||||||
short sm;
|
|
||||||
short sn;
|
|
||||||
|
|
||||||
short As_offset;
|
|
||||||
short Bs_offset;
|
|
||||||
|
|
||||||
/* Constructor */
|
|
||||||
METAL_FUNC BlockMMA(
|
|
||||||
ushort simd_group_id [[simdgroup_index_in_threadgroup]],
|
|
||||||
ushort simd_lane_id [[thread_index_in_simdgroup]])
|
|
||||||
: tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) {
|
|
||||||
// Determine thread position in simdgroup matrix
|
|
||||||
short qid = simd_lane_id / 4;
|
|
||||||
sm = (qid & 4) + (simd_lane_id / 2) % 4;
|
|
||||||
sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2;
|
|
||||||
|
|
||||||
// Determine thread and simdgroup offset
|
|
||||||
As_offset =
|
|
||||||
transpose_a ? ((sn)*lda_tgp + (tm + sm)) : ((sn) + (tm + sm) * lda_tgp);
|
|
||||||
Bs_offset =
|
|
||||||
transpose_b ? ((tn + sn) * ldb_tgp + (sm)) : ((sm)*ldb_tgp + (tn + sn));
|
|
||||||
}
|
|
||||||
|
|
||||||
/* (BM, BK) X (BK, BN) multiply accumulate function */
|
|
||||||
METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
As += As_offset;
|
|
||||||
Bs += Bs_offset;
|
|
||||||
|
|
||||||
// Iterate over BK in blocks of 8
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short kk = 0; kk < BK; kk += 8) {
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Load elements from threadgroup A as simdgroup matrices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
Asimd[i].thread_elements()[0] =
|
|
||||||
static_cast<AccumType>(As[i * simd_stride_a + 0]);
|
|
||||||
Asimd[i].thread_elements()[1] =
|
|
||||||
static_cast<AccumType>(As[i * simd_stride_a + jump_a]);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Load elements from threadgroup B as simdgroup matrices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
Bsimd[j].thread_elements()[0] =
|
|
||||||
static_cast<AccumType>(Bs[j * simd_stride_b + 0]);
|
|
||||||
Bsimd[j].thread_elements()[1] =
|
|
||||||
static_cast<AccumType>(Bs[j * simd_stride_b + jump_b]);
|
|
||||||
}
|
|
||||||
|
|
||||||
simdgroup_barrier(mem_flags::mem_none);
|
|
||||||
|
|
||||||
// Multiply and accumulate into result simdgroup matrices
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
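// Serpentine traversal: on odd rows the N tiles are visited in reverse order,
// so the Bsimd operand used last in one row of accumulates is reused first in the next.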
|
|
||||||
short j_serp = (i % 2) ? (TN - 1 - j) : j;
|
|
||||||
|
|
||||||
simdgroup_multiply_accumulate(
|
|
||||||
results[i * TN + j_serp],
|
|
||||||
Asimd[i],
|
|
||||||
Bsimd[j_serp],
|
|
||||||
results[i * TN + j_serp]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Progress to next simdgroup tile
|
|
||||||
As += tile_stride_a;
|
|
||||||
Bs += tile_stride_b;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Store results from the simdgroup_matrix accumulators into device memory */
|
|
||||||
METAL_FUNC void store_result(device U* C, const int ldc) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + tn + sn;
|
|
||||||
|
|
||||||
// Loop over all simdgroup tiles
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue
|
|
||||||
U outs[2] = {Epilogue::apply(accum[0]), Epilogue::apply(accum[1])};
|
|
||||||
|
|
||||||
// Write out C
|
|
||||||
C[offset] = outs[0];
|
|
||||||
C[offset + 1] = outs[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void
|
|
||||||
store_result_safe(device U* C, const int ldc, short2 dst_tile_dims) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + (tn + sn);
|
|
||||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < TM; i++) {
|
|
||||||
if (i * TM_stride < dst_tile_dims.y) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset = (i * TM_stride) * ldc + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue and output C
|
|
||||||
if (j * TN_stride < dst_tile_dims.x) {
|
|
||||||
C[offset] = Epilogue::apply(accum[0]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
|
||||||
C[offset + 1] = Epilogue::apply(accum[1]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Store results from the simdgroup_matrix accumulators into device memory */
|
|
||||||
METAL_FUNC void store_result(
|
|
||||||
device U* D,
|
|
||||||
const int ldd,
|
|
||||||
const device U* C,
|
|
||||||
const int ldc,
|
|
||||||
const int fdc,
|
|
||||||
thread const Epilogue& epilogue_op) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
|
||||||
D += (sm + tm) * ldd + tn + sn;
|
|
||||||
|
|
||||||
// Loop over all simdgroup tiles
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short i = 0; i < TM; i++) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (short j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
||||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue
|
|
||||||
U outs[2] = {
|
|
||||||
epilogue_op.apply(accum[0], C[offset_c]),
|
|
||||||
epilogue_op.apply(accum[1], C[offset_c + fdc])};
|
|
||||||
|
|
||||||
// Write out D
|
|
||||||
D[offset_d] = outs[0];
|
|
||||||
D[offset_d + 1] = outs[1];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
METAL_FUNC void store_result_safe(
|
|
||||||
device U* D,
|
|
||||||
const int ldd,
|
|
||||||
const device U* C,
|
|
||||||
const int ldc,
|
|
||||||
const int fdc,
|
|
||||||
short2 dst_tile_dims,
|
|
||||||
thread const Epilogue& epilogue_op) const {
|
|
||||||
// Adjust for simdgroup and thread location
|
|
||||||
C += (sm + tm) * ldc + (tn + sn) * fdc;
|
|
||||||
D += (sm + tm) * ldd + tn + sn;
|
|
||||||
dst_tile_dims -= short2(tn + sn, sm + tm);
|
|
||||||
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int i = 0; i < TM; i++) {
|
|
||||||
if (i * TM_stride < dst_tile_dims.y) {
|
|
||||||
STEEL_PRAGMA_UNROLL
|
|
||||||
for (int j = 0; j < TN; j++) {
|
|
||||||
// Get accumulated result and associated offset in C
|
|
||||||
thread const auto& accum = results[i * TN + j].thread_elements();
|
|
||||||
int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc;
|
|
||||||
int offset_d = (i * TM_stride) * ldd + (j * TN_stride);
|
|
||||||
|
|
||||||
// Apply epilogue and write out D
|
|
||||||
if (j * TN_stride < dst_tile_dims.x) {
|
|
||||||
D[offset_d] = epilogue_op.apply(accum[0], C[offset_c]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (j * TN_stride + 1 < dst_tile_dims.x) {
|
|
||||||
D[offset_d + 1] = epilogue_op.apply(accum[1], C[offset_c + fdc]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,79 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// GEMM param classes
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
struct GEMMParams {
|
|
||||||
const int M;
|
|
||||||
const int N;
|
|
||||||
const int K;
|
|
||||||
|
|
||||||
const int lda;
|
|
||||||
const int ldb;
|
|
||||||
const int ldc;
|
|
||||||
|
|
||||||
const int tiles_n;
|
|
||||||
const int tiles_m;
|
|
||||||
|
|
||||||
const int batch_stride_a;
|
|
||||||
const int batch_stride_b;
|
|
||||||
const int batch_stride_c;
|
|
||||||
|
|
||||||
const int swizzle_log;
|
|
||||||
const int gemm_k_iterations_aligned;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GEMMSpiltKParams {
|
|
||||||
const int M;
|
|
||||||
const int N;
|
|
||||||
const int K;
|
|
||||||
|
|
||||||
const int lda;
|
|
||||||
const int ldb;
|
|
||||||
const int ldc;
|
|
||||||
|
|
||||||
const int tiles_n;
|
|
||||||
const int tiles_m;
|
|
||||||
|
|
||||||
const int split_k_partitions;
|
|
||||||
const int split_k_partition_stride;
|
|
||||||
const int split_k_partition_size;
|
|
||||||
|
|
||||||
const int gemm_k_iterations_aligned;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct GEMMAddMMParams {
|
|
||||||
const int M;
|
|
||||||
const int N;
|
|
||||||
const int K;
|
|
||||||
|
|
||||||
const int lda;
|
|
||||||
const int ldb;
|
|
||||||
const int ldc;
|
|
||||||
const int ldd;
|
|
||||||
|
|
||||||
const int tiles_n;
|
|
||||||
const int tiles_m;
|
|
||||||
|
|
||||||
const int batch_stride_a;
|
|
||||||
const int batch_stride_b;
|
|
||||||
const int batch_stride_c;
|
|
||||||
const int batch_stride_d;
|
|
||||||
|
|
||||||
const int swizzle_log;
|
|
||||||
const int gemm_k_iterations_aligned;
|
|
||||||
|
|
||||||
const float alpha;
|
|
||||||
const float beta;
|
|
||||||
|
|
||||||
const int fdc;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
Binary file not shown.
@ -1,63 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Transforms and Epilogues
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
namespace mlx {
|
|
||||||
namespace steel {
|
|
||||||
|
|
||||||
template <typename OutT, typename InT>
|
|
||||||
struct TransformNone {
|
|
||||||
static METAL_FUNC OutT apply(InT x) {
|
|
||||||
return static_cast<OutT>(x);
|
|
||||||
}
|
|
||||||
|
|
||||||
static METAL_FUNC OutT apply(InT x, OutT) {
|
|
||||||
return static_cast<OutT>(x);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OutT, typename InT>
|
|
||||||
struct TransformAdd {
|
|
||||||
TransformAdd(const float, const float) {}
|
|
||||||
|
|
||||||
static METAL_FUNC OutT apply(InT x, OutT c) {
|
|
||||||
return static_cast<OutT>(x) + c;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename OutT, typename InT>
|
|
||||||
struct TransformAxpby {
|
|
||||||
const float alpha;
|
|
||||||
const float beta;
|
|
||||||
|
|
||||||
TransformAxpby(const float alpha_, const float beta_)
|
|
||||||
: alpha(alpha_), beta(beta_) {}
|
|
||||||
|
|
||||||
METAL_FUNC OutT apply(InT x, OutT c) const {
|
|
||||||
return static_cast<OutT>(x * alpha + (beta * c));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
struct AccumHelper {
|
|
||||||
typedef float accum_type;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct BlockSwizzle {
|
|
||||||
static METAL_FUNC int2
|
|
||||||
swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) {
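// Fold groups of 2^swizzle_log consecutive threadgroup x-indices into the y direction,
// so threadgroups scheduled back to back map to the same output-column tile and are
// more likely to reuse the cached B panel.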
|
|
||||||
const int tid_x = (tid.x) >> swizzle_log;
|
|
||||||
const int tid_y =
|
|
||||||
((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1));
|
|
||||||
return int2(tid_x, tid_y);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace steel
|
|
||||||
} // namespace mlx
|
|
@ -1,276 +0,0 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_math>
|
|
||||||
#include "gemm/bf16.h"
|
|
||||||
#include "gemm/complex.h"
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Type limits utils
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
template <typename U>
|
|
||||||
struct Limits {
|
|
||||||
static const constant U max = metal::numeric_limits<U>::max();
|
|
||||||
static const constant U min = metal::numeric_limits<U>::min();
|
|
||||||
static const constant U finite_max = metal::numeric_limits<U>::max();
|
|
||||||
static const constant U finite_min = metal::numeric_limits<U>::min();
|
|
||||||
};
|
|
||||||
|
|
||||||
#define instantiate_default_limit(type) \
|
|
||||||
template <> \
|
|
||||||
struct Limits<type> { \
|
|
||||||
static constexpr constant type max = metal::numeric_limits<type>::max(); \
|
|
||||||
static constexpr constant type min = metal::numeric_limits<type>::min(); \
|
|
||||||
static constexpr constant type finite_max = \
|
|
||||||
metal::numeric_limits<type>::max(); \
|
|
||||||
static constexpr constant type finite_min = \
|
|
||||||
metal::numeric_limits<type>::min(); \
|
|
||||||
};
|
|
||||||
|
|
||||||
instantiate_default_limit(uint8_t);
|
|
||||||
instantiate_default_limit(uint16_t);
|
|
||||||
instantiate_default_limit(uint32_t);
|
|
||||||
instantiate_default_limit(uint64_t);
|
|
||||||
instantiate_default_limit(int8_t);
|
|
||||||
instantiate_default_limit(int16_t);
|
|
||||||
instantiate_default_limit(int32_t);
|
|
||||||
instantiate_default_limit(int64_t);
|
|
||||||
|
|
||||||
#define instantiate_float_limit(type) \
|
|
||||||
template <> \
|
|
||||||
struct Limits<type> { \
|
|
||||||
static constexpr constant type max = \
|
|
||||||
metal::numeric_limits<type>::infinity(); \
|
|
||||||
static constexpr constant type min = \
|
|
||||||
-metal::numeric_limits<type>::infinity(); \
|
|
||||||
static constexpr constant type finite_max = \
|
|
||||||
metal::numeric_limits<type>::max(); \
|
|
||||||
static constexpr constant type finite_min = \
|
|
||||||
-metal::numeric_limits<type>::max(); \
|
|
||||||
};
|
|
||||||
|
|
||||||
instantiate_float_limit(half);
|
|
||||||
instantiate_float_limit(float);
|
|
||||||
instantiate_float_limit(bfloat16_t);
|
|
||||||
|
|
||||||
template <>
|
|
||||||
struct Limits<bool> {
|
|
||||||
static constexpr constant bool max = true;
|
|
||||||
static constexpr constant bool min = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Indexing utils
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
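// These helpers convert flat (or packed x/y/z) element indices into memory offsets by
// combining the array shape with per-dimension strides, starting from the innermost dimension.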
|
|
||||||
|
|
||||||
inline size_t elem_to_loc(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides,
|
|
||||||
int ndim) {
|
|
||||||
size_t loc = 0;
|
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
|
||||||
loc += (elem % shape[i]) * strides[i];
|
|
||||||
elem /= shape[i];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc(
|
|
||||||
uint elem,
|
|
||||||
constant const int* shape,
|
|
||||||
constant const size_t* strides,
|
|
||||||
int ndim) {
|
|
||||||
size_t loc = 0;
|
|
||||||
for (int i = ndim - 1; i >= 0; --i) {
|
|
||||||
loc += (elem % shape[i]) * strides[i];
|
|
||||||
elem /= shape[i];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int NDIM>
|
|
||||||
inline uint2 elem_to_loc_2_nd(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int shape[NDIM],
|
|
||||||
constant const size_t a_strides[NDIM],
|
|
||||||
constant const size_t b_strides[NDIM]) {
|
|
||||||
uint2 loc = {
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
|
|
||||||
for (int d = NDIM - 3; d >= 0; --d) {
|
|
||||||
uint l = elem.z % shape[d];
|
|
||||||
loc.x += l * a_strides[d];
|
|
||||||
loc.y += l * b_strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int NDIM>
|
|
||||||
inline size_t elem_to_loc_nd(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int shape[NDIM],
|
|
||||||
constant const size_t strides[NDIM]) {
|
|
||||||
size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2];
|
|
||||||
for (int d = NDIM - 3; d >= 0; --d) {
|
|
||||||
loc += (elem.z % shape[d]) * strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
|
|
||||||
return elem * stride;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
|
|
||||||
return elem.x * strides[1] + elem.y * strides[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
|
|
||||||
return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// Non-templated version to handle arbitrary dims
|
|
||||||
inline size_t elem_to_loc(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int* shape,
|
|
||||||
constant const size_t* strides,
|
|
||||||
int ndim) {
|
|
||||||
size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2];
|
|
||||||
for (int d = ndim - 3; d >= 0; --d) {
|
|
||||||
loc += (elem.z % shape[d]) * strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline uint2 elem_to_loc_2_nd(
|
|
||||||
uint3 elem,
|
|
||||||
constant const int* shape,
|
|
||||||
constant const size_t* a_strides,
|
|
||||||
constant const size_t* b_strides,
|
|
||||||
int ndim) {
|
|
||||||
uint2 loc = {
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]),
|
|
||||||
static_cast<uint>(
|
|
||||||
elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])};
|
|
||||||
for (int d = ndim - 3; d >= 0; --d) {
|
|
||||||
uint l = elem.z % shape[d];
|
|
||||||
loc.x += l * a_strides[d];
|
|
||||||
loc.y += l * b_strides[d];
|
|
||||||
elem.z /= shape[d];
|
|
||||||
}
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <int NDIM>
|
|
||||||
inline uint elem_to_loc_nd(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides);
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<1>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
return (elem % shape[0]) * strides[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<2>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
uint loc = (elem % shape[1]) * strides[1];
|
|
||||||
elem /= shape[1];
|
|
||||||
loc += (elem % shape[0]) * strides[0];
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<3>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
uint loc = (elem % shape[2]) * strides[2];
|
|
||||||
elem /= shape[2];
|
|
||||||
loc += (elem % shape[1]) * strides[1];
|
|
||||||
elem /= shape[1];
|
|
||||||
loc += (elem % shape[0]) * strides[0];
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <>
|
|
||||||
inline uint elem_to_loc_nd<4>(
|
|
||||||
uint elem,
|
|
||||||
device const int* shape,
|
|
||||||
device const size_t* strides) {
|
|
||||||
uint loc = (elem % shape[3]) * strides[3];
|
|
||||||
elem /= shape[3];
|
|
||||||
loc += (elem % shape[2]) * strides[2];
|
|
||||||
elem /= shape[2];
|
|
||||||
loc += (elem % shape[1]) * strides[1];
|
|
||||||
elem /= shape[1];
|
|
||||||
loc += (elem % shape[0]) * strides[0];
|
|
||||||
return loc;
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// Calculation utils
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
/** Compute ceil((float)N/(float)M) */
|
|
||||||
inline size_t ceildiv(size_t N, size_t M) {
|
|
||||||
return (N + M - 1) / M;
|
|
||||||
}
|
|
||||||
|
|
||||||
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
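// xp1 = 1 + x is rounded, so log(xp1) alone loses accuracy for small x; scaling it by
// x / (xp1 - 1) compensates for that rounding error (see the reference above).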
|
|
||||||
inline float log1p(float x) {
|
|
||||||
float xp1 = 1.0f + x;
|
|
||||||
if (xp1 == Limits<float>::max) {
|
|
||||||
return Limits<float>::max;
|
|
||||||
}
|
|
||||||
if (xp1 == 1.0f) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return x * (metal::log(xp1) / (xp1 - 1.0f));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bfloat16_t log1p(bfloat16_t x) {
|
|
||||||
float xp1 = 1.0f + static_cast<float>(x);
|
|
||||||
if (xp1 == Limits<float>::max) {
|
|
||||||
return Limits<bfloat16_t>::max;
|
|
||||||
}
|
|
||||||
if (xp1 == 1.0f) {
|
|
||||||
return x;
|
|
||||||
}
|
|
||||||
|
|
||||||
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
|
|
||||||
}
|
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
// SIMD shuffle ops
|
|
||||||
///////////////////////////////////////////////////////////////////////////////
|
|
||||||
|
|
||||||
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
|
|
||||||
return as_type<uint64_t>(
|
|
||||||
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
|
|
||||||
return as_type<int64_t>(
|
|
||||||
metal::simd_shuffle_down(as_type<uint2>(data), delta));
|
|
||||||
}
|
|
||||||
|
|
||||||
inline bool simd_shuffle_down(bool data, uint16_t delta) {
|
|
||||||
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
|
|
||||||
}
|
|
@ -1,9 +0,0 @@
|
|||||||
// Copyright © 2024 Apple Inc.
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <metal_stdlib>
|
|
||||||
#include "gemm/host.h"
|
|
||||||
|
|
||||||
#define STEEL_CONST static constant constexpr const
|
|
||||||
#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
|
|
@ -1,7 +1,6 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||||
NSUInteger,
|
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -13,12 +12,9 @@ const UNARY: &str = include_str!("unary.metal");
|
|||||||
const BINARY: &str = include_str!("binary.metal");
|
const BINARY: &str = include_str!("binary.metal");
|
||||||
const TERNARY: &str = include_str!("ternary.metal");
|
const TERNARY: &str = include_str!("ternary.metal");
|
||||||
const CAST: &str = include_str!("cast.metal");
|
const CAST: &str = include_str!("cast.metal");
|
||||||
const CONV: &str = include_str!("conv.metal");
|
|
||||||
const REDUCE: &str = include_str!("reduce.metal");
|
const REDUCE: &str = include_str!("reduce.metal");
|
||||||
const RANDOM: &str = include_str!("random.metal");
|
const CONV: &str = include_str!("conv.metal");
|
||||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||||
const GEMM: &[u8] = include_bytes!("gemm/steel_gemm.metallib");
|
|
||||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
|
||||||
|
|
||||||
/// Most kernels apply similarly across the tensors
|
/// Most kernels apply similarly across the tensors
|
||||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||||
@ -65,12 +61,10 @@ macro_rules! primitive {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
primitive!(bool);
|
|
||||||
primitive!(usize);
|
primitive!(usize);
|
||||||
primitive!(i32);
|
|
||||||
primitive!(i64);
|
primitive!(i64);
|
||||||
|
primitive!(i32);
|
||||||
primitive!(u32);
|
primitive!(u32);
|
||||||
primitive!(u64);
|
|
||||||
primitive!(f32);
|
primitive!(f32);
|
||||||
|
|
||||||
impl<T> EncoderParam for &[T] {
|
impl<T> EncoderParam for &[T] {
|
||||||
@ -124,9 +118,7 @@ pub enum Source {
|
|||||||
Cast,
|
Cast,
|
||||||
Reduce,
|
Reduce,
|
||||||
Mfa,
|
Mfa,
|
||||||
Gemm,
|
|
||||||
Conv,
|
Conv,
|
||||||
Random,
|
|
||||||
Quantized,
|
Quantized,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -248,10 +240,7 @@ impl Kernels {
|
|||||||
Source::Cast => CAST,
|
Source::Cast => CAST,
|
||||||
Source::Reduce => REDUCE,
|
Source::Reduce => REDUCE,
|
||||||
Source::Conv => CONV,
|
Source::Conv => CONV,
|
||||||
Source::Random => RANDOM,
|
|
||||||
Source::Quantized => QUANTIZED,
|
|
||||||
Source::Mfa => panic!("Invalid lib"),
|
Source::Mfa => panic!("Invalid lib"),
|
||||||
Source::Gemm => panic!("Invalid lib"),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -275,14 +264,6 @@ impl Kernels {
|
|||||||
))
|
))
|
||||||
})?
|
})?
|
||||||
}
|
}
|
||||||
Source::Gemm => {
|
|
||||||
let source_data = GEMM;
|
|
||||||
device.new_library_with_data(source_data).map_err(|e| {
|
|
||||||
MetalKernelError::LoadLibraryError(format!(
|
|
||||||
"Candle metal requires macosx > 13.0 or higher, cannot load GEMM: {e}"
|
|
||||||
))
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
source => {
|
source => {
|
||||||
let source_content = self.get_library_source(source);
|
let source_content = self.get_library_source(source);
|
||||||
device
|
device
|
||||||
@ -1242,34 +1223,6 @@ impl ConstantValues {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn string_to_static_str(s: String) -> &'static str {
|
|
||||||
Box::leak(s.into_boxed_str())
|
|
||||||
}
|
|
||||||
|
|
||||||
use core::ffi::c_int;
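// Host-side mirror of the steel GEMMParams struct in the Metal sources; the field order
// and c_int widths must match the MSL definition so the raw buffer bytes line up.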
|
|
||||||
|
|
||||||
#[repr(C)]
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct GEMMParams {
|
|
||||||
m: c_int,
|
|
||||||
n: c_int,
|
|
||||||
k: c_int,
|
|
||||||
|
|
||||||
lda: c_int,
|
|
||||||
ldb: c_int,
|
|
||||||
ldc: c_int,
|
|
||||||
|
|
||||||
tiles_n: c_int,
|
|
||||||
tiles_m: c_int,
|
|
||||||
|
|
||||||
batch_stride_a: c_int,
|
|
||||||
batch_stride_b: c_int,
|
|
||||||
batch_stride_c: c_int,
|
|
||||||
|
|
||||||
swizzle_log: c_int,
|
|
||||||
gemm_k_iterations_aligned: c_int,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_gemm(
|
pub fn call_gemm(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -1291,10 +1244,10 @@ pub fn call_gemm(
|
|||||||
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
|
||||||
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
|
||||||
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
|
||||||
let (a_trans, lda) = if lhs_m1 == 1 && lhs_m2 == k {
|
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
|
||||||
(false, k as c_int)
|
false
|
||||||
} else if lhs_m1 == m && lhs_m2 == 1 {
|
} else if lhs_m1 == m && lhs_m2 == 1 {
|
||||||
(true, n as c_int)
|
true
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -1302,10 +1255,10 @@ pub fn call_gemm(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
let (b_trans, ldb) = if rhs_m1 == 1 && rhs_m2 == n {
|
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
|
||||||
(false, n as c_int)
|
false
|
||||||
} else if rhs_m1 == k && rhs_m2 == 1 {
|
} else if rhs_m1 == k && rhs_m2 == 1 {
|
||||||
(true, k as c_int)
|
true
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_stride.to_vec(),
|
||||||
@ -1313,195 +1266,120 @@ pub fn call_gemm(
|
|||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
// let d_trans = false;
|
let d_trans = false;
|
||||||
// let alpha = 1.0f32;
|
let alpha = 1.0f32;
|
||||||
// let beta = 0.0f32;
|
let beta = 0.0f32;
|
||||||
// let batched = b > 1;
|
let batched = b > 1;
|
||||||
// let fused_activation = false;
|
let fused_activation = false;
|
||||||
// let fused_bias = false;
|
let fused_bias = false;
|
||||||
// let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||||
// let m_simd = 8;
|
let m_simd = 8;
|
||||||
// let n_simd = 8;
|
let n_simd = 8;
|
||||||
// let k_simd = 64;
|
let k_simd = 64;
|
||||||
// let m_splits = 1;
|
let m_splits = 1;
|
||||||
// let n_splits = 1;
|
let n_splits = 1;
|
||||||
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
// } else {
|
} else {
|
||||||
// let m_simd = 40;
|
let m_simd = 40;
|
||||||
// let n_simd = 40;
|
let n_simd = 40;
|
||||||
// let k_simd = 32;
|
let k_simd = 32;
|
||||||
// let m_splits = 1;
|
let m_splits = 1;
|
||||||
// let n_splits = 1;
|
let n_splits = 1;
|
||||||
// (m_simd, n_simd, k_simd, m_splits, n_splits)
|
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
// };
|
};
|
||||||
// let constants = Some(ConstantValues::new(vec![
|
let constants = Some(ConstantValues::new(vec![
|
||||||
// (0, Value::USize(m)),
|
(0, Value::USize(m)),
|
||||||
// (1, Value::USize(n)),
|
(1, Value::USize(n)),
|
||||||
// (2, Value::USize(k)),
|
(2, Value::USize(k)),
|
||||||
// (10, Value::Bool(a_trans)),
|
(10, Value::Bool(a_trans)),
|
||||||
// (11, Value::Bool(b_trans)),
|
(11, Value::Bool(b_trans)),
|
||||||
// (13, Value::Bool(d_trans)),
|
(13, Value::Bool(d_trans)),
|
||||||
// (20, Value::F32(alpha)),
|
(20, Value::F32(alpha)),
|
||||||
// (21, Value::F32(beta)),
|
(21, Value::F32(beta)),
|
||||||
// (100, Value::Bool(batched)),
|
(100, Value::Bool(batched)),
|
||||||
// (101, Value::Bool(fused_activation)),
|
(101, Value::Bool(fused_activation)),
|
||||||
// // Garbage
|
// Garbage
|
||||||
// (102, Value::Bool(false)),
|
(102, Value::Bool(false)),
|
||||||
// (103, Value::Bool(false)),
|
(103, Value::Bool(false)),
|
||||||
// (113, Value::Bool(false)),
|
(113, Value::Bool(false)),
|
||||||
// (50_000, Value::Bool(false)),
|
(50_000, Value::Bool(false)),
|
||||||
// // End garbage
|
// End garbage
|
||||||
// (200, Value::U16(m_simd)),
|
(200, Value::U16(m_simd)),
|
||||||
// (201, Value::U16(n_simd)),
|
(201, Value::U16(n_simd)),
|
||||||
// (202, Value::U16(k_simd)),
|
(202, Value::U16(k_simd)),
|
||||||
// (210, Value::U16(m_splits)),
|
(210, Value::U16(m_splits)),
|
||||||
// (211, Value::U16(n_splits)),
|
(211, Value::U16(n_splits)),
|
||||||
// (50_001, Value::Bool(fused_bias)),
|
(50_001, Value::Bool(fused_bias)),
|
||||||
// ]));
|
]));
|
||||||
let a_trans_name = if a_trans { "t" } else { "n" };
|
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
|
||||||
let b_trans_name = if b_trans { "t" } else { "n" };
|
let m_group = m_simd * m_splits;
|
||||||
let (iname, oname) = match name {
|
let n_group = n_simd * n_splits;
|
||||||
"sgemm" => ("float32", "float32"),
|
|
||||||
"hgemm" => ("float16", "float16"),
|
let a_block_length = m_group * k_simd;
|
||||||
"bgemm" => ("bfloat16", "bfloat16"),
|
let b_block_length = k_simd * n_group;
|
||||||
|
|
||||||
|
let mut block_elements = a_block_length + b_block_length;
|
||||||
|
if (m % 8 != 0) && (n % 8 != 0) {
|
||||||
|
let c_block_length = m_group * n_group;
|
||||||
|
block_elements = std::cmp::max(c_block_length, block_elements)
|
||||||
|
}
|
||||||
|
if fused_bias {
|
||||||
|
if d_trans {
|
||||||
|
block_elements = std::cmp::max(block_elements, m_group);
|
||||||
|
} else {
|
||||||
|
block_elements = std::cmp::max(block_elements, n_group);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let bytes = match name {
|
||||||
|
"sgemm" => 4,
|
||||||
|
"hgemm" => 2,
|
||||||
other => {
|
other => {
|
||||||
return Err(MetalKernelError::LoadLibraryError(format!(
|
return Err(MetalKernelError::LoadLibraryError(format!(
|
||||||
"{other} is not a valid kernel for gemm"
|
"{other} is not a valid kernel for gemm"
|
||||||
)))
|
)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let mut bm = 32;
|
let block_bytes = block_elements * bytes;
|
||||||
let mut bn = 32;
|
|
||||||
let mut bk = 16;
|
|
||||||
let wm = 2;
|
|
||||||
let wn = 2;
|
|
||||||
if b * m * n >= 1 << 20 {
    if !a_trans && b_trans {
        bm = 64;
        bn = if oname == "float32" { 64 } else { 32 };
        bk = if oname == "float32" { 16 } else { 32 };
    } else {
        bm = 64;
        bn = 64;
    }
}
let mnaligned = if m % bm == 0 && n % bn == 0 {
    "taligned"
} else {
    "naligned"
};
let kaligned = if k % bk == 0 { "taligned" } else { "naligned" };
// let bytes = match &name[..] {
//     "sgemm" => 4,
//     "hgemm" => 2,
//     other => {
//         return Err(MetalKernelError::LoadLibraryError(format!(
//             "{other} is not a valid kernel for gemm"
//         )));
//     }
// };
let name = format!("steel_gemm_{a_trans_name}{b_trans_name}_{iname}_{oname}_bm{bm}_bn{bn}_bk{bk}_wm{wm}_wn{wn}_MN_{mnaligned}_K_{kaligned}");
let name = string_to_static_str(name);
let pipeline = kernels.load_pipeline(device, Source::Gemm, name)?;
// let m_group = m_simd * m_splits;
// let n_group = n_simd * n_splits;

// let a_block_length = m_group * k_simd;
// let b_block_length = k_simd * n_group;

// let mut block_elements = a_block_length + b_block_length;
// if (m % 8 != 0) && (n % 8 != 0) {
//     let c_block_length = m_group * n_group;
//     block_elements = std::cmp::max(c_block_length, block_elements)
// }
// if fused_bias {
//     if d_trans {
//         block_elements = std::cmp::max(block_elements, m_group);
//     } else {
//         block_elements = std::cmp::max(block_elements, n_group);
//     }
// }
// let block_bytes = block_elements * bytes;

let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
// encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_threadgroup_memory_length(0, block_bytes.into());

let batch_stride_a: i32 = if lhs_stride.len() > 2 {
    lhs_stride[lhs_stride.len() - 3] as i32
} else {
    0
};
let batch_stride_b: i32 = if rhs_stride.len() > 2 {
    rhs_stride[rhs_stride.len() - 3] as i32
} else {
    0
};
let batch_stride_c = (m * n) as i32;

let swizzle_log = 0;
let tiles_n = ((n + bn - 1) / bn) as c_int;
let tiles_m = ((m + bm - 1) / bm) as c_int;

let params = GEMMParams {
    m: m as c_int,
    n: n as c_int,
    k: k as c_int,
    lda,
    ldb,
    ldc: n as c_int,
    tiles_m,
    tiles_n,
    batch_stride_a,
    batch_stride_b,
    batch_stride_c,
    swizzle_log,
    gemm_k_iterations_aligned: (k / bk) as c_int,
};
let params_buffer = device.new_buffer_with_data(
    &params as *const GEMMParams as *const c_void,
    core::mem::size_of::<GEMMParams>() as u64,
    MTLResourceOptions::StorageModeShared,
);
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0);
encoder.set_buffer(3, Some(&params_buffer), 0);
// TODO Tensor D

let grid_z = b;
// if batched {
//     let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
//     let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
//     let byte_stride_c = m * n * bytes as usize;
//     // TODO byte_stride_d
//     let byte_stride_d = 0;

//     let buffer: Vec<u64> = vec![
//         byte_stride_a as _,
//         byte_stride_b as _,
//         byte_stride_c as _,
//         byte_stride_d as _,
//     ];
//     // encoder.set_bytes(
//     //     10,
//     //     (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
//     //     buffer.as_ptr() as *const NSUInteger as *const c_void,
//     // );
// }
let tile = 1 << swizzle_log;
let tm = (tiles_m + tile - 1) / tile;
let tn = tiles_n * tile;
if batched {
    let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
    let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
    let byte_stride_c = m * n * bytes as usize;
    // TODO byte_stride_d
    let byte_stride_d = 0;

    let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
    for i in 0..b {
        buffer.push((i * byte_stride_a) as u64);
        buffer.push((i * byte_stride_b) as u64);
        buffer.push((i * byte_stride_c) as u64);
        buffer.push((i * byte_stride_d) as u64);
    }
    encoder.set_bytes(
        10,
        (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
        buffer.as_ptr() as *const NSUInteger as *const c_void,
    );
}

let grid_size = MTLSize {
    width: tn as u64,
    height: tm as u64,
    width: divide(n, n_group.into()),
    height: divide(m, m_group.into()),
    depth: grid_z as NSUInteger,
};
let group_size = MTLSize {
    width: 32,
    height: wn,
    depth: wm,
    width: 32 * (m_splits as u64) * (n_splits as u64),
    height: 1,
    depth: 1,
};
encoder.use_resource(lhs_buffer, metal::MTLResourceUsage::Read);
encoder.use_resource(rhs_buffer, metal::MTLResourceUsage::Read);
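The tile and grid arithmetic above is plain ceiling division over the block sizes picked earlier. A small standalone sketch (illustrative only, not part of the diff):

```rust
// Ceil-division tiling, mirroring `tiles_m`, `tiles_n`, `tm` and `tn` above.
fn ceil_div(x: usize, block: usize) -> usize {
    (x + block - 1) / block
}

fn grid_dims(m: usize, n: usize, b: usize, bm: usize, bn: usize, swizzle_log: u32) -> (usize, usize, usize) {
    let tiles_m = ceil_div(m, bm);
    let tiles_n = ceil_div(n, bn);
    let tile = 1usize << swizzle_log;
    let tm = ceil_div(tiles_m, tile);
    let tn = tiles_n * tile;
    // (width, height, depth) of the dispatch grid: one threadgroup per output tile, one layer per batch.
    (tn, tm, b)
}

fn main() {
    // A 100x70 output with 64x32 blocks and no swizzle needs a 3x2 grid per batch element.
    assert_eq!(grid_dims(100, 70, 1, 64, 32, 0), (3, 2, 1));
}
```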
@ -1647,73 +1525,6 @@ pub fn call_upsample_nearest_2d(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_random_uniform(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    min: f32,
    max: f32,
    length: usize,
    seed: &Buffer,
    buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    if min >= max {
        return Err(MetalKernelError::LoadLibraryError(
            "min must be less than max".to_string(),
        ));
    }
    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
    let encoder = command_buffer.new_compute_command_encoder();

    let odd = (length % 2 != 0) as usize;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, min, max, seed, buffer));

    encoder.use_resource(seed, metal::MTLResourceUsage::Read);
    encoder.use_resource(seed, metal::MTLResourceUsage::Write);
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();

    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_random_normal(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: &'static str,
    mean: f32,
    stddev: f32,
    length: usize,
    seed: &Buffer,
    buffer: &Buffer,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Random, name)?;
    let encoder = command_buffer.new_compute_command_encoder();

    let odd = (length % 2 != 0) as usize;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length / 2 + odd);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, mean, stddev, seed, buffer));

    encoder.use_resource(seed, metal::MTLResourceUsage::Read);
    encoder.use_resource(seed, metal::MTLResourceUsage::Write);
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();

    Ok(())
}
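Both entry points dispatch over `length / 2 + odd` elements because each GPU thread in the random kernels writes `out[tid]` and, except for thread 0, a mirrored element `out[size - tid]`. A sketch of that size arithmetic only (illustrative; `linear_split` itself is unchanged):

```rust
// Host-side element count handed to linear_split by the two calls above.
fn random_thread_count(length: usize) -> usize {
    let odd = (length % 2 != 0) as usize;
    length / 2 + odd
}

fn main() {
    assert_eq!(random_thread_count(10), 5);
    assert_eq!(random_thread_count(11), 6);
}
```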

#[derive(Debug, Clone, Copy)]
pub enum GgmlDType {
    Q4_0,
@ -1743,145 +1554,7 @@ pub fn call_quantized_matmul_t(
    rhs: &Buffer,
    output: &Buffer,
) -> Result<(), MetalKernelError> {
    // Everything is in reverse
    todo!("Not implemented yet");
    let ne00 = k as i64;
    let ne01 = n as i64;
    let ne02 = b as i64;
    let ne03 = 1 as i64;

    let nb00 = 0i64;
    let nb01 = 0 as i64;
    let nb02 = 0 as i64;

    let ne10 = k as i64;
    let ne11 = m as i64;
    let ne12 = b as i64;
    let ne13 = 1 as i64;

    let nb10 = 0i64;
    let nb11 = 0i64;
    let nb12 = 0i64;

    let ne0 = n as i64;
    let ne1 = m as i64;
    let r2: u32 = (ne12 / ne02) as u32;
    let r3: u32 = (ne13 / ne03) as u32;

    let (nth0, nth1, align) = match dtype {
        GgmlDType::Q4_0
        | GgmlDType::Q4_1
        | GgmlDType::Q5_0
        | GgmlDType::Q5_1
        | GgmlDType::Q8_0
        | GgmlDType::Q8_1 => {
            let nth0 = 8;
            let nth1 = 8;
            let align = 8;
            (nth0, nth1, align)
        }
        GgmlDType::Q2K => {
            // Fixing a bug in Metal for GGML
            let nth0 = 4;
            let nth1 = 8;
            let align = 4;
            (nth0, nth1, align)
        }
        GgmlDType::Q4K => {
            let nth0 = 4;
            let nth1 = 8;
            let align = 4;
            (nth0, nth1, align)
        }
        GgmlDType::Q3K | GgmlDType::Q5K => {
            let nth0 = 2;
            let nth1 = 32;
            let align = 4;
            (nth0, nth1, align)
        }
        GgmlDType::Q6K => {
            let nth0 = 2;
            let nth1 = 32;
            let align = 2;
            (nth0, nth1, align)
        }
        GgmlDType::F16 | GgmlDType::Q8K => {
            // Original implem uses rows
            let nth0 = 32;
            let nth1 = 1;
            let align = 8;
            (nth0, nth1, align)
        }
        GgmlDType::F32 => {
            let nth0 = 32;
            let nth1 = 1;
            let align = 8;
            (nth0, nth1, align)
        }
    };
    let thread_groups_count = MTLSize {
        width: divide(ne01 as usize, align),
        height: ne11 as u64,
        depth: (ne12 * ne13) as u64,
    };
    let threads_per_threadgroup = MTLSize {
        width: nth0,
        height: nth1,
        depth: 1,
    };
    let name = match dtype {
        GgmlDType::Q4_0 => "kernel_mul_mv_q4_0_f32",
        GgmlDType::Q4_1 => "kernel_mul_mv_q4_1_f32",
        GgmlDType::Q5_0 => "kernel_mul_mv_q5_0_f32",
        GgmlDType::Q5_1 => "kernel_mul_mv_q5_1_f32",
        GgmlDType::Q8_0 => "kernel_mul_mv_q8_0_f32",
        GgmlDType::Q8_1 => "kernel_mul_mv_q8_1_f32",
        GgmlDType::Q2K => "kernel_mul_mv_q2_K_f32",
        GgmlDType::Q3K => "kernel_mul_mv_q3_K_f32",
        GgmlDType::Q4K => "kernel_mul_mv_q4_K_f32",
        GgmlDType::Q5K => "kernel_mul_mv_q5_K_f32",
        GgmlDType::Q6K => "kernel_mul_mv_q6_K_f32",
        GgmlDType::Q8K => "kernel_mul_mv_q8_K_f32",
        GgmlDType::F16 => "kernel_mul_mv_f16_f32",
        GgmlDType::F32 => "kernel_mul_mv_f32_f32",
    };

    let pipeline = kernels.load_pipeline(device, Source::Quantized, name)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(
        encoder,
        (
            rhs,
            (lhs, lhs_offset),
            output,
            ne00,
            ne01,
            ne02,
            nb00,
            nb01,
            nb02,
            ne10,
            ne11,
            ne12,
            nb10,
            nb11,
            nb12,
            ne0,
            ne1,
            r2,
            r3
        )
    );
    encoder.set_threadgroup_memory_length(0, 8192);
    encoder.use_resource(lhs, metal::MTLResourceUsage::Read);
    encoder.use_resource(rhs, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);

    encoder.dispatch_thread_groups(thread_groups_count, threads_per_threadgroup);
    encoder.end_encoding();

    Ok(())
}
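The `kernel_mul_mv_*_f32` names and the `matmul_t` binding in the Python API suggest the quantized tensor acts as a transposed right-hand side. Below is a plain f32 reference of that shape convention, stated as an assumption for illustration; it is not the Metal kernel's algorithm.

```rust
// Reference semantics only: a is m x k, b is n x k, the output is m x n with
// c[i][j] = sum_k a[i][k] * b[j][k] (i.e. C = A * B^T).
fn matmul_t_ref(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    let mut c = vec![0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0f32;
            for l in 0..k {
                acc += a[i * k + l] * b[j * k + l];
            }
            c[i * n + j] = acc;
        }
    }
    c
}

fn main() {
    let a = [1f32, 2.0]; // 1 x 2
    let b = [3f32, 4.0, 5.0, 6.0]; // 2 x 2, rows [3, 4] and [5, 6]
    assert_eq!(matmul_t_ref(&a, &b, 1, 2, 2), vec![11.0, 17.0]);
}
```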

fn divide(m: usize, b: usize) -> NSUInteger {
Binary file not shown.
File diff suppressed because it is too large
@ -1,206 +0,0 @@
#include <metal_stdlib>
#include <metal_integer>
#include <metal_atomic>

using namespace metal;

// Constants
// 2^32 and 1/2^32. Useful for converting between float and uint.
static constexpr constant ulong UNIF01_NORM32 = 4294967296;
static constexpr constant float UNIF01_INV32 = 2.328306436538696289e-10;
// 2 * pi
static constexpr constant float TWO_PI = 2.0 * M_PI_F;
static constexpr constant int3 S1 = {13, 19, 12};
static constexpr constant int3 S2 = {2, 25, 4};
static constexpr constant int3 S3 = {3, 11, 17};

// Used to prevent bad seeds.
static constexpr constant uint64_t PHI[16] = {
    0x9E3779B97F4A7C15,
    0xF39CC0605CEDC834,
    0x1082276BF3A27251,
    0xF86C6A11D0C18E95,
    0x2767F0B153D27B7F,
    0x0347045B5BF1827F,
    0x01886F0928403002,
    0xC1D64BA40F335E36,
    0xF06AD7AE9717877E,
    0x85839D6EFFBD7DC6,
    0x64D325D1C5371682,
    0xCADD0CCCFDFFBBE1,
    0x626E33B8D04B4331,
    0xBBF73C790D94F79D,
    0x471C4AB3ED3D82A5,
    0xFEC507705E4AE6E5,
};

// Combined Tausworthe and LCG Random Number Generator.
// https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-37-efficient-random-number-generation-and-application
// https://indico.cern.ch/event/93877/contributions/2118070/attachments/1104200/1575343/acat3_revised_final.pdf
struct HybridTaus {

    float state;

    HybridTaus() thread = default;
    HybridTaus() threadgroup = default;
    HybridTaus() device = default;
    HybridTaus() constant = default;

    // Generate seeds for each thread.
    METAL_FUNC static uint4 seed_per_thread(const ulong4 seeds) {
        return uint4(ulong4(seeds) * ulong4(PHI[0], PHI[1], PHI[2], PHI[3]) * ulong4(1099087573UL));
    }

    // Tausworthe generator.
    METAL_FUNC static uint taus(const uint z, const int3 s, const uint M) {
        uint b = (((z << s.x) ^ z) >> s.y);
        return (((z & M) << s.z) ^ b);
    }

    // LCG generator.
    METAL_FUNC static uint lcg(const uint z) {
        return (1664525 * z + 1013904223UL);
    }

    // Initialize the RNG state.
    METAL_FUNC static HybridTaus init(const ulong4 seeds) {
        uint4 seed = seed_per_thread(seeds);

        // Seed #1
        uint z1 = taus(seed.x, S1, 4294967294UL);
        uint z2 = taus(seed.y, S2, 4294967288UL);
        uint z3 = taus(seed.z, S3, 4294967280UL);
        uint z4 = lcg(seed.x);

        // Seed #2
        uint r1 = (z1^z2^z3^z4^seed.y);
        z1 = taus(r1, S1, 429496729UL);
        z2 = taus(r1, S2, 4294967288UL);
        z3 = taus(r1, S3, 429496280UL);
        z4 = lcg(r1);

        // Seed #3
        r1 = (z1^z2^z3^z4^seed.z);
        z1 = taus(r1, S1, 429496729UL);
        z2 = taus(r1, S2, 4294967288UL);
        z3 = taus(r1, S3, 429496280UL);
        z4 = lcg(r1);

        // Seed #4
        r1 = (z1^z2^z3^z4^seed.w);
        z1 = taus(r1, S1, 429496729UL);
        z2 = taus(r1, S2, 4294967288UL);
        z3 = taus(r1, S3, 429496280UL);
        z4 = lcg(r1);

        HybridTaus rng;
        rng.state = (z1^z2^z3^z4) * UNIF01_INV32;
        return rng;
    }

    METAL_FUNC float rand() {
        uint seed = this->state * UNIF01_NORM32;
        uint z1 = taus(seed, S1, 429496729UL);
        uint z2 = taus(seed, S2, 4294967288UL);
        uint z3 = taus(seed, S3, 429496280UL);
        uint z4 = lcg(seed);

        thread float result = this->state;
        this->state = (z1^z2^z3^z4) * UNIF01_INV32;
        return result;
    }
};

template<typename T> METAL_FUNC void rand_uniform(
    constant size_t &size,
    constant float &min,
    constant float &max,
    device atomic_uint *seed,
    device T *out,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= size) {
        return;
    }

    float diff = abs(min - max);
    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
    out[tid] = static_cast<T>(rng.rand() * diff + min);
    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0, otherwise we will write to out[size].
        return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}

// Create Gaussian normal distribution using Box-Muller transform:
// https://en.wikipedia.org/wiki/Box–Muller_transform
template<typename T> METAL_FUNC void normal(
    constant size_t &size,
    constant float &mean,
    constant float &stddev,
    device atomic_uint *seed,
    device T *out,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= size) {
        return;
    }
    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
    float u1 = rng.rand();
    float u2 = rng.rand();

    float cosval;
    float sinval = sincos(TWO_PI * u2, cosval);
    float mag = stddev * sqrt(-2.0 * log(u1));
    float z0 = mag * cosval + mean;
    float z1 = mag * sinval + mean;

    out[tid] = static_cast<T>(z0);

    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0, otherwise we will write to out[size].
        return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - tid] = static_cast<T>(z1);
}

#define UNIFORM_OP(NAME, T) \
kernel void rand_uniform_##NAME( \
    constant size_t &size, \
    constant float &min, \
    constant float &max, \
    device atomic_uint *seed, \
    device T *out, \
    uint tid [[thread_position_in_grid]] \
) { \
    rand_uniform<T>(size, min, max, seed, out, tid); \
} \

#define NORMAL_OP(NAME, T) \
kernel void rand_normal_##NAME( \
    constant size_t &size, \
    constant float &mean, \
    constant float &stddev, \
    device atomic_uint *seed, \
    device T *out, \
    uint tid [[thread_position_in_grid]] \
) { \
    normal<T>(size, mean, stddev, seed, out, tid); \
} \

#define RANDOM_OPS(NAME, T) \
UNIFORM_OP(NAME, T) \
NORMAL_OP(NAME, T) \

RANDOM_OPS(f32, float)
RANDOM_OPS(f16, half)

#if __METAL_VERSION__ >= 310
RANDOM_OPS(bf16, bfloat)
#endif
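For reference, a host-side sketch of the Box-Muller step used by the `normal` kernel above: two uniform samples in (0, 1) become two independent Gaussian samples. This is an illustration of the formula, not the shader itself.

```rust
fn box_muller(u1: f32, u2: f32, mean: f32, stddev: f32) -> (f32, f32) {
    let two_pi = 2.0 * std::f32::consts::PI;
    // mag = stddev * sqrt(-2 ln u1), matching `mag` in the kernel.
    let mag = stddev * (-2.0 * u1.ln()).sqrt();
    let z0 = mag * (two_pi * u2).cos() + mean;
    let z1 = mag * (two_pi * u2).sin() + mean;
    (z0, z1)
}

fn main() {
    let (z0, z1) = box_muller(0.5, 0.25, 0.0, 1.0);
    // With u2 = 0.25 the angle is pi/2, so z0 is ~0 and z1 = sqrt(2 ln 2).
    assert!(z0.abs() < 1e-5);
    assert!((z1 - (2.0 * std::f32::consts::LN_2).sqrt()).abs() < 1e-5);
}
```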
@ -11,7 +11,7 @@ fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {

fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
    let options = MTLResourceOptions::StorageModeManaged;
    let ptr = data.as_ptr() as *const c_void;
    let ptr = data.as_ptr() as *const core::ffi::c_void;
    let size = (data.len() * std::mem::size_of::<T>()) as u64;
    device.new_buffer_with_data(ptr, size, options)
}
@ -713,6 +713,7 @@ fn softmax() {
    }
    let results = run_softmax(&v, last_dim, "softmax_f32");
    let results = approx(results, 4);
    println!("{results:?}");
    assert_eq!(
        results.iter().map(|&s| s.round() as usize).sum::<usize>(),
        n
@ -926,124 +927,3 @@ fn gemm() {
        vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
    );
}

fn run_random<T: Clone>(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((length * core::mem::size_of::<T>()) as NSUInteger, options);

    let seed = device.new_buffer_with_data(
        &seed as *const u32 as *const core::ffi::c_void,
        std::mem::size_of::<u32>() as NSUInteger,
        options,
    );

    if name.starts_with("rand_uniform") {
        call_random_uniform(
            &device,
            command_buffer,
            &kernels,
            name,
            a,
            b,
            length,
            &seed,
            &output,
        )
        .unwrap();
    } else {
        call_random_normal(
            &device,
            command_buffer,
            &kernels,
            name,
            a,
            b,
            length,
            &seed,
            &output,
        )
        .unwrap();
    }
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, length)
}

#[test]
fn random() {
    fn calc_mean(data: &[f32]) -> f32 {
        let sum = data.iter().sum::<f32>() as f32;
        let count = data.len();
        assert!(count > 0);
        sum / count as f32
    }

    fn calc_stddev(data: &[f32]) -> f32 {
        let mean = calc_mean(data);
        let count = data.len();
        assert!(count > 0);

        let variance = data
            .iter()
            .map(|value| {
                let diff = mean - (*value as f32);
                diff * diff
            })
            .sum::<f32>()
            / count as f32;

        variance.sqrt()
    }

    let shape = vec![1024, 10];

    let length = shape.iter().product::<usize>();
    let seed = 299792458;

    let min = -30.0;
    let max = 30.0;
    let mean = 100.0;
    let stddev = 50.0;

    macro_rules! validate_random {
        ($type:ty) => {
            let results: Vec<f32> = run_random::<$type>(
                concat!("rand_uniform_", stringify!($type)),
                seed,
                length,
                min,
                max,
            )
            .into_iter()
            .map(f32::from)
            .collect();
            results.iter().for_each(|v| {
                assert!(*v >= min && *v <= max);
            });
            assert!(calc_mean(&results) > -1.0 && calc_mean(&results) < 1.0);

            let results: Vec<f32> = run_random::<$type>(
                concat!("rand_normal_", stringify!($type)),
                seed,
                length,
                mean,
                stddev,
            )
            .into_iter()
            .map(f32::from)
            .collect();
            assert!((calc_mean(&results) - mean).abs() < mean / 10.0);
            assert!((calc_stddev(&results) - stddev).abs() < stddev / 10.0);
        };
    }

    validate_random!(f32);
    validate_random!(f16);
    validate_random!(bf16);
}
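For context on the loose bounds in the test above: a uniform distribution on (a, b) has mean (a + b) / 2 and standard deviation (b - a) / sqrt(12), so with min = -30 and max = 30 the expected sample mean is 0 and the expected stddev is about 17.3. A quick check of that arithmetic:

```rust
fn uniform_stats(a: f32, b: f32) -> (f32, f32) {
    let mean = (a + b) / 2.0;
    let stddev = (b - a) / 12f32.sqrt();
    (mean, stddev)
}

fn main() {
    let (mean, stddev) = uniform_stats(-30.0, 30.0);
    assert_eq!(mean, 0.0);
    assert!((stddev - 17.32).abs() < 0.01);
}
```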
@ -262,19 +262,9 @@ impl BatchNorm {
    let target_shape = target_shape.as_slice();

    let x = x
        .broadcast_sub(
            &self
                .running_mean
                .as_detached_tensor()
                .reshape(target_shape)?,
        )?
        .broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
        .broadcast_div(
            &(self
                .running_var
                .as_detached_tensor()
                .reshape(target_shape)?
                + self.eps)?
                .sqrt()?,
            &(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
        )?;
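The two sides of this hunk differ only in how the running statistics tensor is accessed (`as_detached_tensor` vs `as_tensor`); the inference-mode computation itself is x_hat = (x - running_mean) / sqrt(running_var + eps). A minimal scalar sketch of that formula, as an assumed reading of the broadcast chain rather than candle's per-channel implementation:

```rust
fn batch_norm_1d(x: &[f32], running_mean: f32, running_var: f32, eps: f32) -> Vec<f32> {
    let denom = (running_var + eps).sqrt();
    x.iter().map(|v| (v - running_mean) / denom).collect()
}

fn main() {
    assert_eq!(batch_norm_1d(&[1.0, 2.0, 3.0], 2.0, 1.0, 0.0), vec![-1.0, 0.0, 1.0]);
}
```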

    match &self.weight_and_bias {
@ -124,7 +124,7 @@ fn set_at_index<D: WithDType, I: Into<i64>>(
    value: I,
    offset: usize,
    depth: usize,
    v: &mut [D],
    v: &mut Vec<D>,
    on_value: D,
) -> Result<()> {
    let value = value.into();
@ -412,16 +412,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
}

impl<'a> VarBuilder<'a> {
    /// Initializes a `VarBuilder` using a custom backend.
    ///
    /// It is preferred to use one of the more specific constructors. This
    /// constructor is provided to allow downstream users to define their own
    /// backends.
    pub fn from_backend(
        backend: Box<dyn SimpleBackend + 'a>,
        dtype: DType,
        device: Device,
    ) -> Self {
    fn new(backend: Box<dyn SimpleBackend + 'a>, dtype: DType, device: Device) -> Self {
        let data = TensorData {
            backend,
            dtype,
@ -436,13 +427,13 @@ impl<'a> VarBuilder<'a> {

    /// Initializes a `VarBuilder` that uses zeros for any tensor.
    pub fn zeros(dtype: DType, dev: &Device) -> Self {
        Self::from_backend(Box::new(Zeros), dtype, dev.clone())
        Self::new(Box::new(Zeros), dtype, dev.clone())
    }

    /// Initializes a `VarBuilder` that retrieves tensors stored in a hashtable. An error is
    /// returned if no tensor is available under the requested path or on shape mismatches.
    pub fn from_tensors(ts: HashMap<String, Tensor>, dtype: DType, dev: &Device) -> Self {
        Self::from_backend(Box::new(ts), dtype, dev.clone())
        Self::new(Box::new(ts), dtype, dev.clone())
    }

    /// Initializes a `VarBuilder` using a `VarMap`. The requested tensors are created and
@ -452,7 +443,7 @@ impl<'a> VarBuilder<'a> {
    /// Note that it is possible to load the tensor values after model creation using the `load`
    /// method on `varmap`, this can be used to start model training from an existing checkpoint.
    pub fn from_varmap(varmap: &VarMap, dtype: DType, dev: &Device) -> Self {
        Self::from_backend(Box::new(varmap.clone()), dtype, dev.clone())
        Self::new(Box::new(varmap.clone()), dtype, dev.clone())
    }

    /// Initializes a `VarBuilder` that retrieves tensors stored in a collection of safetensors
@ -467,25 +458,25 @@ impl<'a> VarBuilder<'a> {
        dev: &Device,
    ) -> Result<Self> {
        let tensors = candle::safetensors::MmapedSafetensors::multi(paths)?;
        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
        Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
    }

    /// Initializes a `VarBuilder` from a binary builder in the safetensor format.
    pub fn from_buffered_safetensors(data: Vec<u8>, dtype: DType, dev: &Device) -> Result<Self> {
        let tensors = candle::safetensors::BufferedSafetensors::new(data)?;
        Ok(Self::from_backend(Box::new(tensors), dtype, dev.clone()))
        Ok(Self::new(Box::new(tensors), dtype, dev.clone()))
    }

    /// Initializes a `VarBuilder` that retrieves tensors stored in a numpy npz file.
    pub fn from_npz<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
        let npz = candle::npy::NpzTensors::new(p)?;
        Ok(Self::from_backend(Box::new(npz), dtype, dev.clone()))
        Ok(Self::new(Box::new(npz), dtype, dev.clone()))
    }

    /// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
    pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
        let pth = candle::pickle::PthTensors::new(p, None)?;
        let pth = candle::pickle::PthTensors::new(p)?;
        Ok(Self::from_backend(Box::new(pth), dtype, dev.clone()))
        Ok(Self::new(Box::new(pth), dtype, dev.clone()))
    }
}
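A hedged usage sketch of the constructors touched above, building a `VarBuilder` from a `VarMap`. The `get` call, the "weight" name and the (4, 2) shape are illustrative assumptions rather than anything taken from this diff:

```rust
use candle::{DType, Device, Result};
use candle_nn::{VarBuilder, VarMap};

fn build_vars() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // Hypothetical tensor name and shape, created on demand and tracked by the VarMap.
    let _w = vb.get((4, 2), "weight")?;
    Ok(())
}
```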
@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.4.0"
version = "0.3.3"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" }
candle = { path = "../candle-core", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.4.0" }
candle-nn = { path = "../candle-nn" }
prost = "0.12.1"

[build-dependencies]
@ -766,16 +766,6 @@ pub fn simple_eval(
            let output = input.cumsum(axis as usize)?;
            values.insert(node.output[0].clone(), output);
        }
        // https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
        "Flatten" => {
            let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
            let input = get(&node.input[0])?;
            let first_part: usize = input.shape().dims().iter().take(axis).product();
            let end_index = input.shape().dims().iter().product::<usize>();
            let new_shape = (first_part, end_index / first_part);
            let output = input.reshape(new_shape)?;
            values.insert(node.output[0].clone(), output);
        }
        op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
    }
}
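The Flatten handler above collapses every dimension before `axis` into one leading dimension and everything else into the second. The same arithmetic as a standalone sketch, matching the test expectations further down:

```rust
fn flatten_shape(dims: &[usize], axis: usize) -> (usize, usize) {
    let first_part: usize = dims.iter().take(axis).product();
    let total: usize = dims.iter().product();
    (first_part, total / first_part)
}

fn main() {
    assert_eq!(flatten_shape(&[2, 2, 2], 0), (1, 8)); // one row of 8
    assert_eq!(flatten_shape(&[2, 2, 2], 1), (2, 4)); // two rows of 4
}
```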
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;

use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap;

const INPUT_X: &str = "x";
@ -677,134 +677,6 @@ fn test_dropout_operation() -> Result<()> {
    Ok(())
}

// "Flatten"
#[test]
fn test_flatten_operation() -> Result<()> {
    let mut att_axis = AttributeProto {
        name: "axis".to_string(),
        ref_attr_name: "axis".to_string(),
        i: 0,
        doc_string: "axis".to_string(),
        r#type: 2,
        f: 0.0,
        s: vec![],
        t: None,
        g: None,
        sparse_tensor: None,
        tp: None,
        floats: vec![],
        ints: vec![],
        strings: vec![],
        tensors: vec![],
        graphs: vec![],
        sparse_tensors: vec![],
        type_protos: vec![],
    };
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Flatten".to_string(),
            domain: "".to_string(),
            attribute: vec![att_axis.clone()],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x = Tensor::from_vec(
        vec![
            1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
        ],
        &[2, 2, 2],
        &Device::Cpu,
    )?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]);

    att_axis.i = 1;
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "Flatten".to_string(),
            domain: "".to_string(),
            attribute: vec![att_axis.clone()],
            input: vec![INPUT_X.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

    let results = z.to_vec2::<f32>()?;

    assert_eq!(
        results,
        vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
    );

    Ok(())
}

// Below are ops that are implemented but not tested yet

// "MaxPool"
@ -88,27 +88,23 @@ class QTensor:
        Dequantizes the tensor.
        """
        pass

    @property
    def ggml_dtype(self) -> str:
        """
        Gets the tensors quantized dtype.
        """
        pass

    def matmul_t(self, lhs: Tensor) -> Tensor:
        """
        Performs a quantized matrix multiplication, with the quantized tensor as the right hand side.
        """
        pass

    @property
    def rank(self) -> int:
        """
        Gets the rank of the tensor.
        """
        pass

    @property
    def shape(self) -> Tuple[int]:
        """
@ -123,213 +119,178 @@ class Tensor:

    def __init__(self, data: _ArrayLike):
        pass

    def __add__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Add a scalar to a tensor or two tensors together.
        """
        pass

    def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
        """
        Return a slice of a tensor.
        """
        pass

    def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Multiply a tensor by a scalar or one tensor by another.
        """
        pass

    def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Add a scalar to a tensor or two tensors together.
        """
        pass

    def __richcmp__(self, rhs: Union[Tensor, Scalar], op) -> "Tensor":
        """
        Compare a tensor with a scalar or one tensor with another.
        """
        pass

    def __rmul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Multiply a tensor by a scalar or one tensor by another.
        """
        pass

    def __sub__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Subtract a scalar from a tensor or one tensor from another.
        """
        pass

    def __truediv__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
        """
        Divide a tensor by a scalar or one tensor by another.
        """
        pass

    def abs(self) -> Tensor:
        """
        Performs the `abs` operation on the tensor.
        """
        pass

    def argmax_keepdim(self, dim: int) -> Tensor:
        """
        Returns the indices of the maximum value(s) across the selected dimension.
        """
        pass

    def argmin_keepdim(self, dim: int) -> Tensor:
        """
        Returns the indices of the minimum value(s) across the selected dimension.
        """
        pass

    def broadcast_add(self, rhs: Tensor) -> Tensor:
        """
        Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
        """
        pass

    def broadcast_as(self, *shape: Shape) -> Tensor:
        """
        Broadcasts the tensor to the given shape.
        """
        pass

    def broadcast_div(self, rhs: Tensor) -> Tensor:
        """
        Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
        """
        pass

    def broadcast_left(self, *shape: Shape) -> Tensor:
        """
        Broadcasts the tensor to the given shape, adding new dimensions on the left.
        """
        pass

    def broadcast_mul(self, rhs: Tensor) -> Tensor:
        """
        Multiplies the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
        """
        pass

    def broadcast_sub(self, rhs: Tensor) -> Tensor:
        """
        Subtracts the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
        """
        pass

    def contiguous(self) -> Tensor:
        """
        Makes the tensor contiguous in memory.
        """
        pass

    def copy(self) -> Tensor:
        """
        Returns a copy of the tensor.
        """
        pass

    def cos(self) -> Tensor:
        """
        Performs the `cos` operation on the tensor.
        """
        pass

    def detach(self) -> Tensor:
        """
        Detach the tensor from the computation graph.
        """
        pass

    @property
    def device(self) -> Device:
        """
        Gets the tensor's device.
        """
        pass

    @property
    def dtype(self) -> DType:
        """
        Gets the tensor's dtype.
        """
        pass

    def exp(self) -> Tensor:
        """
        Performs the `exp` operation on the tensor.
        """
        pass

    def flatten_all(self) -> Tensor:
        """
        Flattens the tensor into a 1D tensor.
        """
        pass

    def flatten_from(self, dim: int) -> Tensor:
        """
        Flattens the tensor on the dimension indexes from `dim` (inclusive) to the last dimension.
        """
        pass

    def flatten_to(self, dim: int) -> Tensor:
        """
        Flattens the tensor on the dimension indexes from `0` to `dim` (inclusive).
        """
        pass

    def get(self, index: int) -> Tensor:
        """
        Gets the value at the specified index.
        """
        pass

    def index_select(self, rhs: Tensor, dim: int) -> Tensor:
        """
        Select values for the input tensor at the target indexes across the specified dimension.
@ -341,192 +302,161 @@ class Tensor:
|
|||||||
tensor.
|
tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_contiguous(self) -> bool:
|
def is_contiguous(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns true if the tensor is contiguous in C order.
|
Returns true if the tensor is contiguous in C order.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def is_fortran_contiguous(self) -> bool:
|
def is_fortran_contiguous(self) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns true if the tensor is contiguous in Fortran order.
|
Returns true if the tensor is contiguous in Fortran order.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def log(self) -> Tensor:
|
def log(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the `log` operation on the tensor.
|
Performs the `log` operation on the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def matmul(self, rhs: Tensor) -> Tensor:
|
def matmul(self, rhs: Tensor) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Performs a matrix multiplication between the two tensors.
|
Performs a matrix multiplication between the two tensors.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def max_keepdim(self, dim: int) -> Tensor:
|
def max_keepdim(self, dim: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Gathers the maximum value across the selected dimension.
|
Gathers the maximum value across the selected dimension.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def mean_all(self) -> Tensor:
|
def mean_all(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns the mean of the tensor.
|
Returns the mean of the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def min_keepdim(self, dim: int) -> Tensor:
|
def min_keepdim(self, dim: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Gathers the minimum value across the selected dimension.
|
Gathers the minimum value across the selected dimension.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def narrow(self, dim: int, start: int, len: int) -> Tensor:
|
def narrow(self, dim: int, start: int, len: int) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
Returns a new tensor that is a narrowed version of the input, the dimension `dim`
|
||||||
ranges from `start` to `start + len`.
|
ranges from `start` to `start + len`.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nelement(self) -> int:
|
def nelement(self) -> int:
|
||||||
"""
|
"""
|
||||||
Gets the tensor's element count.
|
Gets the tensor's element count.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def powf(self, p: float) -> Tensor:
|
def powf(self, p: float) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the `pow` operation on the tensor with the given exponent.
|
Performs the `pow` operation on the tensor with the given exponent.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def quantize(self, quantized_dtype: str) -> QTensor:
|
def quantize(self, quantized_dtype: str) -> QTensor:
|
||||||
"""
|
"""
|
||||||
Quantize the tensor.
|
Quantize the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rank(self) -> int:
|
def rank(self) -> int:
|
||||||
"""
|
"""
|
||||||
Gets the tensor's rank.
|
Gets the tensor's rank.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def recip(self) -> Tensor:
|
def recip(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Get the `recip` of the tensor.
|
Get the `recip` of the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def reshape(self, *shape: Shape) -> Tensor:
|
def reshape(self, *shape: Shape) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Reshapes the tensor to the given shape.
|
Reshapes the tensor to the given shape.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Tuple[int]:
|
def shape(self) -> Tuple[int]:
|
||||||
"""
|
"""
|
||||||
Gets the tensor's shape.
|
Gets the tensor's shape.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sin(self) -> Tensor:
|
def sin(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Performs the `sin` operation on the tensor.
|
Performs the `sin` operation on the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sqr(self) -> Tensor:
|
def sqr(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Squares the tensor.
|
Squares the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sqrt(self) -> Tensor:
|
def sqrt(self) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Calculates the square root of the tensor.
|
Calculates the square root of the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
def squeeze(self, dim: int) -> Tensor:
    """
    Creates a new tensor with the specified dimension removed if its size was one.
    """
    pass

@property
def stride(self) -> Tuple[int]:
    """
    Gets the tensor's strides.
    """
    pass

def sum_all(self) -> Tensor:
    """
    Returns the sum of the tensor.
    """
    pass

def sum_keepdim(self, dim: Union[int, List[int]]) -> Tensor:
    """
    Returns the sum of all elements in the input tensor. The sum is performed over all the input dimensions.
    """
    pass

def t(self) -> Tensor:
    """
    Transposes the tensor.
    """
    pass

def to(self, *args, **kwargs) -> Tensor:
    """
    Performs Tensor dtype and/or device conversion.
    """
    pass

def to_device(self, device: Union[str, Device]) -> Tensor:
    """
    Move the tensor to a new device.
    """
    pass

def to_dtype(self, dtype: Union[str, DType]) -> Tensor:
    """
    Convert the tensor to a new dtype.
    """
    pass

def to_torch(self) -> torch.Tensor:
    """
    Converts candle's tensor to pytorch's tensor
    """
    pass

def transpose(self, dim1: int, dim2: int) -> Tensor:
    """
    Returns a tensor that is a transposed version of the input, the given dimensions are swapped.
    """
    pass

def unsqueeze(self, dim: int) -> Tensor:
    """
    Creates a new tensor with a dimension of size one inserted at the specified position.
    """
    pass

def values(self) -> _ArrayLike:
    """
    Gets the tensor's data as a Python scalar or array-like object.
    """
    pass

def where_cond(self, on_true: Tensor, on_false: Tensor) -> Tensor:
    """
    Returns a tensor with the same shape as the input tensor, the values are taken from
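A minimal sketch of how a few of the Tensor methods stubbed above compose from Python; the constructor call and the sample values are illustrative assumptions, not part of the diff:

    # Illustrative use of the candle Python Tensor API described by the stubs above.
    from candle import Tensor

    t = Tensor([[1.0, 2.0, 3.0]])   # assumed nested-list constructor, shape (1, 3)
    t = t.squeeze(0)                # shape (3,)
    t = t.unsqueeze(1)              # shape (3, 1)
    total = t.sum_all()             # scalar tensor holding 6.0
    t32 = t.to_dtype("f32")         # dtype conversion via the string form of DType
    print(t.shape, total.values())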
@@ -57,10 +57,12 @@ class Sequential(Module):

_modules: Dict[str, Module]  # type: ignore[assignment]

@overload
-def __init__(self, *args: Module) -> None: ...
+def __init__(self, *args: Module) -> None:
+    ...

@overload
-def __init__(self, arg: "OrderedDict[str, Module]") -> None: ...
+def __init__(self, arg: "OrderedDict[str, Module]") -> None:
+    ...

def __init__(self, *args):
    super().__init__()
@@ -204,10 +204,12 @@ class Module:

T_destination = TypeVar("T_destination", bound=Dict[str, Any])

@overload
-def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination: ...
+def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
+    ...

@overload
-def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]: ...
+def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
+    ...

def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
    r"""Returns a dictionary containing references to the whole state of the module.
@@ -584,10 +586,12 @@ class Module:
    self: T,
    device: str = ...,
    dtype: Optional[Union[DType, str]] = ...,
-) -> T: ...
+) -> T:
+    ...

@overload
-def to(self: T, dtype: Union[DType, str]) -> T: ...
+def to(self: T, dtype: Union[DType, str]) -> T:
+    ...

def to(self, *args, **kwargs):
    r"""Moves and/or casts the parameters and buffers.
@@ -14,7 +14,6 @@ class LayerNorm(Module):
math::
    y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
"""

__constants__ = ["normalized_shape", "eps"]
normalized_shape: Tuple[int, ...]
eps: float
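For reference, a tiny dependency-free sketch of the formula quoted in the LayerNorm docstring above; the sample vector, gamma, beta and eps values are illustrative only:

    # y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, applied to a toy vector.
    import math

    x = [1.0, 2.0, 3.0]
    gamma, beta, eps = 1.0, 0.0, 1e-5

    mean = sum(x) / len(x)                          # E[x] = 2.0
    var = sum((v - mean) ** 2 for v in x) / len(x)  # Var[x] = 2/3 (biased estimator)
    y = [(v - mean) / math.sqrt(var + eps) * gamma + beta for v in x]
    print(y)  # approximately [-1.2247, 0.0, 1.2247]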
@@ -11,69 +11,59 @@ class ONNXModel:

def __init__(self, path: str):
    pass

@property
def doc_string(self) -> str:
    """
    The doc string of the model.
    """
    pass

@property
def domain(self) -> str:
    """
    The domain of the operator set of the model.
    """
    pass

def initializers(self) -> Dict[str, Tensor]:
    """
    Get the weights of the model.
    """
    pass

@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
    """
    The inputs of the model.
    """
    pass

@property
def ir_version(self) -> int:
    """
    The version of the IR this model targets.
    """
    pass

@property
def model_version(self) -> int:
    """
    The version of the model.
    """
    pass

@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
    """
    The outputs of the model.
    """
    pass

@property
def producer_name(self) -> str:
    """
    The producer of the model.
    """
    pass

@property
def producer_version(self) -> str:
    """
    The version of the producer of the model.
    """
    pass

def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """
    Run the model on the given inputs.
|
|||||||
The data type of the tensor.
|
The data type of the tensor.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape(self) -> Tuple[Union[int, str, Any]]:
|
def shape(self) -> Tuple[Union[int, str, Any]]:
|
||||||
"""
|
"""
|
||||||
|
@@ -938,8 +938,8 @@ impl PyTensor {

    /// Detach the tensor from the computation graph.
    /// &RETURNS&: Tensor
-    fn detach(&self) -> Self {
-        PyTensor(self.0.detach())
+    fn detach(&self) -> PyResult<Self> {
+        Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
    }

    /// Returns a copy of the tensor.
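On the Python side this change is invisible except for error handling: detach still returns a Tensor, but an error from the underlying operation is now propagated through wrap_err as a Python exception. A minimal, assumed usage:

    # detach() drops the computation-graph link; the returned tensor shares the data.
    from candle import Tensor

    t = Tensor([1.0, 2.0, 3.0])
    d = t.detach()
    print(d.values())  # [1.0, 2.0, 3.0]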
@@ -189,6 +189,7 @@ def do_black(content, is_pyi):
        line_length=119,
        is_pyi=is_pyi,
        string_normalization=True,
+        experimental_string_processing=False,
    )
    try:
        return black.format_file_contents(content, fast=True, mode=mode)
@@ -23,6 +23,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
serde_plain = { workspace = true }
tracing = { workspace = true }
+wav = { workspace = true }

[features]
default = []
@@ -1,593 +0,0 @@
use crate::models::with_tracing::Linear;
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;

#[derive(Debug, Clone)]
pub struct Config {
    pub num_layers: usize,
    pub padded_vocab_size: usize,
    pub hidden_size: usize,
    pub ffn_hidden_size: usize,
    pub kv_channels: usize,
    pub num_attention_heads: usize,
    pub seq_length: usize,
    pub layernorm_epsilon: f64,
    pub rmsnorm: bool,
    pub apply_residual_connection_post_layernorm: bool,
    pub post_layer_norm: bool,
    pub add_bias_linear: bool,
    pub add_qkv_bias: bool,
    pub bias_dropout_fusion: bool,
    pub multi_query_attention: bool,
    pub multi_query_group_num: usize,
    pub apply_query_key_layer_scaling: bool,
    pub attention_softmax_in_fp32: bool,
    pub fp32_residual_connection: bool,
}

impl Config {
    pub fn glm3_6b() -> Self {
        Self {
            num_layers: 28,
            padded_vocab_size: 65024,
            hidden_size: 4096,
            ffn_hidden_size: 13696,
            kv_channels: 128,
            num_attention_heads: 32,
            seq_length: 8192,
            layernorm_epsilon: 1e-5,
            rmsnorm: true,
            apply_residual_connection_post_layernorm: false,
            post_layer_norm: true,
            add_bias_linear: false,
            add_qkv_bias: true,
            bias_dropout_fusion: true,
            multi_query_attention: true,
            multi_query_group_num: 2,
            apply_query_key_layer_scaling: true,
            attention_softmax_in_fp32: true,
            fp32_residual_connection: false,
        }
    }
}

fn linear(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    if bias {
        crate::models::with_tracing::linear(in_dim, out_dim, vb)
    } else {
        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    cache: Tensor,
}

impl RotaryEmbedding {
    fn new(cfg: &Config, dtype: DType, dev: &Device) -> Result<Self> {
        let rotary_dim = cfg.kv_channels;
        let n_elem = rotary_dim / 2;
        let inv_freq: Vec<_> = (0..n_elem)
            .step_by(2)
            .map(|i| 1f32 / 10_000f64.powf(i as f64 / n_elem as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, cfg.seq_length as u32, dev)?
            .to_dtype(dtype)?
            .reshape((cfg.seq_length, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let cache = Tensor::stack(&[&freqs.cos()?, &freqs.sin()?], D::Minus1)?;
        Ok(Self { cache })
    }

    fn apply(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (seqlen, _b, np, _hn) = xs.dims4()?;
        let cache = self.cache.narrow(0, seqlen_offset, seqlen)?;
        let rot_dim = cache.dim(D::Minus2)? * 2;
        let (xs, xs_pass) = (
            xs.narrow(D::Minus1, 0, rot_dim)?,
            xs.narrow(D::Minus1, rot_dim, rot_dim)?,
        );
        let xshaped = xs.reshape((seqlen, (), np, rot_dim / 2, 2))?;
        let cache = cache.reshape((seqlen, (), 1, rot_dim / 2, 2))?;
        let (xshaped0, xshaped1) = (
            xshaped.i((.., .., .., .., 0))?,
            xshaped.i((.., .., .., .., 1))?,
        );
        let (cache0, cache1) = (cache.i((.., .., .., .., 0))?, cache.i((.., .., .., .., 1))?);
        let xs_out = Tensor::stack(
            &[
                (xshaped0.broadcast_mul(&cache0)? - xshaped1.broadcast_mul(&cache1)?)?,
                (xshaped1.broadcast_mul(&cache0)? + xshaped0.broadcast_mul(&cache1)?)?,
            ],
            D::Minus1,
        )?;
        let xs_out = xs_out.flatten_from(3)?;
        Tensor::cat(&[xs_out, xs_pass], D::Minus1)
    }
}

#[derive(Debug, Clone)]
struct CoreAttention {
    coeff: Option<f64>,
    norm_factor: f64,
}

fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

impl CoreAttention {
    fn new(layer_number: usize, cfg: &Config) -> Result<Self> {
        let norm_factor = (cfg.kv_channels as f64).sqrt();
        let (norm_factor, coeff) = if cfg.apply_query_key_layer_scaling {
            let coeff = f64::max(1.0, layer_number as f64);
            (norm_factor * coeff, Some(coeff))
        } else {
            (norm_factor, None)
        };
        Ok(Self { coeff, norm_factor })
    }

    fn forward(
        &self,
        query_layer: &Tensor,
        key_layer: &Tensor,
        value_layer: &Tensor,
        attention_mask: &Option<Tensor>,
    ) -> Result<Tensor> {
        let output_size = (
            query_layer.dim(1)?, // b
            query_layer.dim(2)?, // np
            query_layer.dim(0)?, // sq
            key_layer.dim(0)?,   // sk
        );
        let query_layer =
            query_layer.reshape((output_size.2, output_size.0 * output_size.1, ()))?;
        let key_layer = key_layer.reshape((output_size.3, output_size.0 * output_size.1, ()))?;
        let matmul_result = Tensor::matmul(
            &query_layer.transpose(0, 1)?,
            &key_layer.transpose(0, 1)?.transpose(1, 2)?,
        )?;
        let matmul_result = (matmul_result / self.norm_factor)?.reshape(output_size)?;
        let matmul_result = match self.coeff {
            None => matmul_result,
            Some(coeff) => (matmul_result * coeff)?,
        };
        let attention_scores = match attention_mask {
            Some(mask) => masked_fill(
                &matmul_result,
                &mask.broadcast_left((matmul_result.dim(0)?, matmul_result.dim(1)?))?,
                f32::NEG_INFINITY,
            )?,
            None => matmul_result,
        };
        let attention_probs = candle_nn::ops::softmax_last_dim(&attention_scores)?;

        let output_size = (
            value_layer.dim(1)?,
            value_layer.dim(2)?,
            query_layer.dim(0)?,
            value_layer.dim(3)?,
        );
        let value_layer =
            value_layer.reshape((value_layer.dim(0)?, output_size.0 * output_size.1, ()))?;
        let attention_probs =
            attention_probs.reshape((output_size.0 * output_size.1, output_size.2, ()))?;
        let context_layer = Tensor::matmul(&attention_probs, &value_layer.transpose(0, 1)?)?;
        let context_layer = context_layer.reshape(output_size)?;
        let context_layer = context_layer.permute((2, 0, 1, 3))?.contiguous()?;
        context_layer.flatten_from(D::Minus2)
    }
}

#[derive(Debug, Clone)]
struct SelfAttention {
    query_key_value: Linear,
    core_attention: CoreAttention,
    dense: Linear,
    multi_query_attention: bool,
    num_attention_heads_per_partition: usize,
    num_multi_query_groups_per_partition: usize,
    hidden_size_per_attention_head: usize,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl SelfAttention {
    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let projection_size = cfg.kv_channels * cfg.num_attention_heads;
        let hidden_size_per_attention_head = projection_size / cfg.num_attention_heads;
        let qkv_hidden_size = if cfg.multi_query_attention {
            projection_size + 2 * hidden_size_per_attention_head * cfg.multi_query_group_num
        } else {
            3 * projection_size
        };
        let query_key_value = linear(
            cfg.hidden_size,
            qkv_hidden_size,
            cfg.add_bias_linear || cfg.add_qkv_bias,
            vb.pp("query_key_value"),
        )?;
        let core_attention = CoreAttention::new(layer_number, cfg)?;
        let dense = linear(
            cfg.hidden_size,
            cfg.hidden_size,
            cfg.add_bias_linear,
            vb.pp("dense"),
        )?;
        Ok(Self {
            query_key_value,
            core_attention,
            dense,
            multi_query_attention: cfg.multi_query_attention,
            num_attention_heads_per_partition: cfg.num_attention_heads,
            num_multi_query_groups_per_partition: cfg.multi_query_group_num,
            hidden_size_per_attention_head: cfg.kv_channels,
            kv_cache: None,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.kv_cache = None
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: &Option<Tensor>,
        rotary_emb: &RotaryEmbedding,
    ) -> Result<Tensor> {
        let mixed_x_layer = xs.apply(&self.query_key_value)?;
        if !self.multi_query_attention {
            candle::bail!("only multi_query_attention=true is supported")
        }
        let hpa = self.hidden_size_per_attention_head;
        let query_layer =
            mixed_x_layer.narrow(D::Minus1, 0, self.num_attention_heads_per_partition * hpa)?;
        let key_layer = mixed_x_layer.narrow(
            D::Minus1,
            self.num_attention_heads_per_partition * hpa,
            self.num_multi_query_groups_per_partition * hpa,
        )?;
        let value_layer = mixed_x_layer.narrow(
            D::Minus1,
            self.num_attention_heads_per_partition * hpa
                + self.num_multi_query_groups_per_partition * hpa,
            self.num_multi_query_groups_per_partition * hpa,
        )?;
        let query_layer = query_layer.reshape((
            query_layer.dim(0)?,
            query_layer.dim(1)?,
            self.num_attention_heads_per_partition,
            hpa,
        ))?;
        let key_layer = key_layer.reshape((
            key_layer.dim(0)?,
            key_layer.dim(1)?,
            self.num_multi_query_groups_per_partition,
            hpa,
        ))?;
        let value_layer = value_layer.reshape((
            value_layer.dim(0)?,
            value_layer.dim(1)?,
            self.num_multi_query_groups_per_partition,
            hpa,
        ))?;

        // Rotary embeddings.
        let seqlen_offset = match &self.kv_cache {
            None => 0,
            Some((prev_k, _)) => prev_k.dim(0)?,
        };
        let query_layer = rotary_emb.apply(&query_layer, seqlen_offset)?;
        let key_layer = rotary_emb.apply(&key_layer, seqlen_offset)?;

        // KV cache.
        let (key_layer, value_layer) = match &self.kv_cache {
            None => (key_layer, value_layer),
            Some((prev_k, prev_v)) => {
                let k = Tensor::cat(&[prev_k, &key_layer], 0)?;
                let v = Tensor::cat(&[prev_v, &value_layer], 0)?;
                (k, v)
            }
        };
        self.kv_cache = Some((key_layer.clone(), value_layer.clone()));

        // Repeat KV.
        let ratio =
            self.num_attention_heads_per_partition / self.num_multi_query_groups_per_partition;
        let key_layer = {
            let (d0, d1, d2, d3) = key_layer.dims4()?;
            key_layer
                .unsqueeze(D::Minus2)?
                .expand((d0, d1, d2, ratio, d3))?
                .reshape((
                    d0,
                    d1,
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                ))?
        };
        let value_layer = {
            let (d0, d1, d2, d3) = value_layer.dims4()?;
            value_layer
                .unsqueeze(D::Minus2)?
                .expand((d0, d1, d2, ratio, d3))?
                .reshape((
                    d0,
                    d1,
                    self.num_attention_heads_per_partition,
                    self.hidden_size_per_attention_head,
                ))?
        };

        let context_layer =
            self.core_attention
                .forward(&query_layer, &key_layer, &value_layer, attention_mask)?;
        let output = context_layer.apply(&self.dense)?;
        Ok(output)
    }
}

#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
struct MLP {
    dense_h_to_4h: Linear,
    dense_4h_to_h: Linear,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let dense_h_to_4h = linear(
            cfg.hidden_size,
            cfg.ffn_hidden_size * 2,
            cfg.add_bias_linear,
            vb.pp("dense_h_to_4h"),
        )?;
        let dense_4h_to_h = linear(
            cfg.ffn_hidden_size,
            cfg.hidden_size,
            cfg.add_bias_linear,
            vb.pp("dense_4h_to_h"),
        )?;
        Ok(Self {
            dense_4h_to_h,
            dense_h_to_4h,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.dense_h_to_4h)?
            .apply(&candle_nn::Activation::Swiglu)?
            .apply(&self.dense_4h_to_h)
    }
}

#[derive(Debug, Clone)]
struct Block {
    input_layernorm: candle_nn::LayerNorm,
    self_attention: SelfAttention,
    post_attention_layernorm: candle_nn::LayerNorm,
    mlp: MLP,
    apply_residual_connection_post_layernorm: bool,
}

impl Block {
    fn new(layer_number: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let input_layernorm = if cfg.rmsnorm {
            candle_nn::rms_norm(
                cfg.hidden_size,
                cfg.layernorm_epsilon,
                vb.pp("input_layernorm"),
            )?
            .into_inner()
        } else {
            candle_nn::layer_norm(
                cfg.hidden_size,
                cfg.layernorm_epsilon,
                vb.pp("input_layernorm"),
            )?
        };
        let post_attention_layernorm = if cfg.rmsnorm {
            candle_nn::rms_norm(
                cfg.hidden_size,
                cfg.layernorm_epsilon,
                vb.pp("post_attention_layernorm"),
            )?
            .into_inner()
        } else {
            candle_nn::layer_norm(
                cfg.hidden_size,
                cfg.layernorm_epsilon,
                vb.pp("post_attention_layernorm"),
            )?
        };
        let self_attention = SelfAttention::new(layer_number, cfg, vb.pp("self_attention"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        Ok(Self {
            input_layernorm,
            self_attention,
            post_attention_layernorm,
            mlp,
            apply_residual_connection_post_layernorm: cfg.apply_residual_connection_post_layernorm,
        })
    }

    fn reset_kv_cache(&mut self) {
        self.self_attention.reset_kv_cache()
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: &Option<Tensor>,
        rotary_emb: &RotaryEmbedding,
    ) -> Result<Tensor> {
        let layernorm_output = xs.apply(&self.input_layernorm)?;
        let attention_output =
            self.self_attention
                .forward(&layernorm_output, attention_mask, rotary_emb)?;
        let residual = if self.apply_residual_connection_post_layernorm {
            &layernorm_output
        } else {
            xs
        };
        let layernorm_input = (residual + attention_output)?;
        let layernorm_output = layernorm_input.apply(&self.post_attention_layernorm)?;
        let mlp_output = layernorm_output.apply(&self.mlp)?;
        let residual = if self.apply_residual_connection_post_layernorm {
            &layernorm_output
        } else {
            &layernorm_input
        };
        mlp_output + residual
    }
}

#[derive(Debug, Clone)]
struct Transformer {
    layers: Vec<Block>,
    final_layernorm: Option<candle_nn::LayerNorm>,
    rotary_emb: RotaryEmbedding,
}

impl Transformer {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_l = vb.pp("layers");
        let mut layers = Vec::with_capacity(cfg.num_layers);
        for layer_index in 0..cfg.num_layers {
            let block = Block::new(layer_index + 1, cfg, vb_l.pp(layer_index))?;
            layers.push(block)
        }
        let final_layernorm = if cfg.post_layer_norm {
            let ln = if cfg.rmsnorm {
                candle_nn::rms_norm(
                    cfg.hidden_size,
                    cfg.layernorm_epsilon,
                    vb.pp("final_layernorm"),
                )?
                .into_inner()
            } else {
                candle_nn::layer_norm(
                    cfg.hidden_size,
                    cfg.layernorm_epsilon,
                    vb.pp("final_layernorm"),
                )?
            };
            Some(ln)
        } else {
            None
        };
        let rotary_emb = RotaryEmbedding::new(cfg, vb.dtype(), vb.device())?;
        Ok(Self {
            layers,
            final_layernorm,
            rotary_emb,
        })
    }

    fn reset_kv_cache(&mut self) {
        for block in self.layers.iter_mut() {
            block.reset_kv_cache()
        }
    }

    fn forward(&mut self, xs: &Tensor, attention_mask: &Option<Tensor>) -> Result<Tensor> {
        let mut xs = xs.clone();
        for block in self.layers.iter_mut() {
            xs = block.forward(&xs, attention_mask, &self.rotary_emb)?
        }
        match self.final_layernorm.as_ref() {
            None => Ok(xs),
            Some(ln) => xs.apply(ln),
        }
    }
}

#[derive(Debug, Clone)]
struct Embedding {
    word_embeddings: candle_nn::Embedding,
    fp32_residual_connection: bool,
}

impl Embedding {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let word_embeddings = candle_nn::embedding(
            cfg.padded_vocab_size,
            cfg.hidden_size,
            vb.pp("word_embeddings"),
        )?;
        Ok(Self {
            word_embeddings,
            fp32_residual_connection: cfg.fp32_residual_connection,
        })
    }
}

impl Module for Embedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.word_embeddings.forward(xs)?.transpose(0, 1)?; // b,s,h -> s,b,h
        if self.fp32_residual_connection {
            xs.to_dtype(candle::DType::F32)
        } else {
            xs.contiguous()
        }
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embedding: Embedding,
    encoder: Transformer,
    output_layer: Linear,
}

fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb = vb.pp("transformer");
        let embedding = Embedding::new(cfg, vb.pp("embedding"))?;
        let encoder = Transformer::new(cfg, vb.pp("encoder"))?;
        let output_layer = linear(
            cfg.hidden_size,
            cfg.padded_vocab_size,
            false,
            vb.pp("output_layer"),
        )?;
        Ok(Self {
            embedding,
            encoder,
            output_layer,
        })
    }

    pub fn reset_kv_cache(&mut self) {
        self.encoder.reset_kv_cache()
    }

    pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
        let (_b_size, seq_len) = xs.dims2()?;
        let input_embeds = xs.apply(&self.embedding)?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            Some(get_mask(seq_len, xs.device())?)
        };
        let xs = self.encoder.forward(&input_embeds, &attention_mask)?;
        let lm_logits = xs.i(seq_len - 1)?.apply(&self.output_layer)?;
        Ok(lm_logits)
    }
}
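The RotaryEmbedding::apply routine above rotates each (even, odd) feature pair by the position-dependent angle stored in the precomputed cos/sin cache; features beyond rot_dim are passed through unchanged (xs_pass). Written out for position p and frequency index i, matching the two broadcast_mul expressions in the code:

    x'_{2i}   = x_{2i} \cos\theta_{p,i} - x_{2i+1} \sin\theta_{p,i}
    x'_{2i+1} = x_{2i+1} \cos\theta_{p,i} + x_{2i} \sin\theta_{p,i}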
@@ -1,201 +0,0 @@
//! ConvNeXt implementation.
//!
//! See "A ConvNet for the 2020s" Liu et al. 2022
//! <https://arxiv.org/abs/2201.03545>

//! Original code: https://github.com/facebookresearch/ConvNeXt/
//! timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/convnext.py

use candle::{Result, D};
use candle_nn::{conv2d, layer_norm, linear, Conv2dConfig, Func, VarBuilder};

#[derive(Clone)]
pub struct Config {
    blocks: [usize; 4],
    channels: [usize; 4],
}

impl Config {
    pub fn tiny() -> Self {
        Self {
            blocks: [3, 3, 9, 3],
            channels: [96, 192, 384, 768],
        }
    }
    pub fn small() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [96, 192, 384, 768],
        }
    }
    pub fn base() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [128, 256, 512, 1024],
        }
    }
    pub fn large() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [192, 384, 768, 1536],
        }
    }

    pub fn xlarge() -> Self {
        Self {
            blocks: [3, 3, 27, 3],
            channels: [256, 512, 1024, 2048],
        }
    }
}

// Initial downsampling via a patchify layer.
fn convnext_stem(out_channels: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride: 4,
        ..Default::default()
    };
    let patchify = conv2d(3, out_channels, 4, conv2d_cfg, vb.pp(0))?;
    let norm = layer_norm(out_channels, 1e-6, vb.pp(1))?;
    Ok(Func::new(move |xs| {
        // The layer norm works with channels-last format.
        let xs = xs
            .apply(&patchify)?
            .permute((0, 2, 3, 1))?
            .apply(&norm)?
            .permute((0, 3, 1, 2))?;
        Ok(xs)
    }))
}

// Downsampling applied after the stages.
fn convnext_downsample(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        stride: 2,
        ..Default::default()
    };
    let norm = layer_norm(dim / 2, 1e-5, vb.pp(0))?;
    let conv = conv2d(dim / 2, dim, 2, conv2d_cfg, vb.pp(1))?;
    Ok(Func::new(move |xs| {
        let xs = xs
            .permute((0, 2, 3, 1))?
            .apply(&norm)?
            .permute((0, 3, 1, 2))?
            .apply(&conv)?;
        Ok(xs)
    }))
}

// MLP equivalent of pointwise convolutions.
fn convnext_mlp(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let fc1 = linear(dim, 4 * dim, vb.pp("fc1"))?;
    let fc2 = linear(4 * dim, dim, vb.pp("fc2"))?;

    Ok(Func::new(move |xs| {
        let xs = xs.apply(&fc1)?.gelu_erf()?.apply(&fc2)?;
        Ok(xs)
    }))
}

// A block consisting of a depthwise convolution, a MLP and layer scaling.
fn convnext_block(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let conv2d_cfg = Conv2dConfig {
        groups: dim,
        padding: 3,
        ..Default::default()
    };

    let conv_dw = conv2d(dim, dim, 7, conv2d_cfg, vb.pp("conv_dw"))?;

    let gamma = vb.get(dim, "gamma")?;
    let mlp = convnext_mlp(dim, vb.pp("mlp"))?;
    let norm = layer_norm(dim, 1e-6, vb.pp("norm"))?;

    Ok(Func::new(move |xs| {
        let residual = xs;
        let xs = xs
            .apply(&conv_dw)?
            .permute((0, 2, 3, 1))?
            .apply(&norm)?
            .apply(&mlp)?
            .broadcast_mul(&gamma)?
            .permute((0, 3, 1, 2))?;

        xs + residual
    }))
}

// Each stage contains blocks and a downsampling layer for the previous stage.
fn convnext_stage(cfg: &Config, stage_idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let nblocks = cfg.blocks[stage_idx];
    let mut blocks = Vec::with_capacity(nblocks);

    let dim = cfg.channels[stage_idx];

    if stage_idx > 0 {
        blocks.push(convnext_downsample(dim, vb.pp("downsample"))?);
    }

    for block_idx in 0..nblocks {
        blocks.push(convnext_block(dim, vb.pp(format!("blocks.{block_idx}")))?);
    }

    Ok(Func::new(move |xs| {
        let mut xs = xs.clone();
        for block in blocks.iter() {
            xs = xs.apply(block)?
        }
        Ok(xs)
    }))
}

fn convnext_head(outputs: usize, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    let norm = layer_norm(outputs, 1e-6, vb.pp("norm"))?;
    let linear = linear(outputs, nclasses, vb.pp("fc"))?;
    Ok(Func::new(move |xs| xs.apply(&norm)?.apply(&linear)))
}

// Build a convnext model for a given configuration.
fn convnext_model(
    config: &Config,
    nclasses: Option<usize>,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let head = match nclasses {
        None => None,
        Some(nclasses) => {
            let head = convnext_head(config.channels[3], nclasses, vb.pp("head"))?;
            Some(head)
        }
    };

    let stem = convnext_stem(config.channels[0], vb.pp("stem"))?;
    let vb = vb.pp("stages");
    let stage1 = convnext_stage(config, 0, vb.pp(0))?;
    let stage2 = convnext_stage(config, 1, vb.pp(1))?;
    let stage3 = convnext_stage(config, 2, vb.pp(2))?;
    let stage4 = convnext_stage(config, 3, vb.pp(3))?;

    Ok(Func::new(move |xs| {
        let xs = xs
            .apply(&stem)?
            .apply(&stage1)?
            .apply(&stage2)?
            .apply(&stage3)?
            .apply(&stage4)?
            .mean(D::Minus2)?
            .mean(D::Minus1)?;
        match &head {
            None => Ok(xs),
            Some(head) => xs.apply(head),
        }
    }))
}

pub fn convnext(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    convnext_model(cfg, Some(nclasses), vb)
}

pub fn convnext_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    convnext_model(cfg, None, vb)
}
@@ -7,7 +7,7 @@ use std::sync::{Arc, Mutex};

pub const MAX_SEQ_LEN: usize = 4096;

-#[derive(Debug, Clone, Deserialize)]
+#[derive(Deserialize)]
pub struct LlamaConfig {
    pub hidden_size: usize,
    pub intermediate_size: usize,
@@ -40,7 +40,6 @@ impl LlamaConfig {
    }
}

-#[derive(Debug, Clone)]
pub struct Config {
    pub hidden_size: usize,
    pub intermediate_size: usize,
@@ -83,7 +82,7 @@ impl Config {
    }
}

-#[derive(Debug, Clone)]
+#[derive(Clone)]
pub struct Cache {
    masks: Arc<Mutex<HashMap<usize, Tensor>>>,
    pub use_kv_cache: bool,
@@ -137,7 +136,6 @@ impl Cache {
    }
}

-#[derive(Debug, Clone)]
struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
@@ -156,7 +154,6 @@ impl RmsNorm {
    }
}

-#[derive(Debug, Clone)]
struct CausalSelfAttention {
    q_proj: Linear,
    k_proj: Linear,
@@ -317,7 +314,6 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
    Ok(m)
}

-#[derive(Debug, Clone)]
struct Mlp {
    c_fc1: Linear,
    c_fc2: Linear,
@@ -348,7 +344,6 @@ impl Mlp {
    }
}

-#[derive(Debug, Clone)]
struct Block {
    rms_1: RmsNorm,
    attn: CausalSelfAttention,
@@ -388,7 +383,6 @@ impl Block {
    }
}

-#[derive(Debug, Clone)]
pub struct Llama {
    wte: Embedding,
    blocks: Vec<Block>,
@@ -1,211 +0,0 @@
#![allow(unused)]
/// A fast implementation of mamba for inference only.
/// This is based on: https://github.com/LaurentMazare/mamba.rs
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{RmsNorm, VarBuilder};

const D_CONV: usize = 4;
const D_STATE: usize = 16;

#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
    d_model: usize,
    n_layer: usize,
    vocab_size: usize,
    pad_vocab_size_multiple: usize,
}

impl Config {
    fn vocab_size(&self) -> usize {
        let pad = self.pad_vocab_size_multiple;
        (self.vocab_size + pad - 1) / pad * pad
    }

    fn dt_rank(&self) -> usize {
        (self.d_model + 15) / 16
    }

    fn d_inner(&self) -> usize {
        self.d_model * 2
    }
}

pub struct State {
    hs: Vec<Tensor>,
    prev_xs: Vec<[Tensor; D_CONV]>,
    pos: usize,
}

impl State {
    pub fn new(batch_size: usize, cfg: &Config, device: &Device) -> Result<Self> {
        let mut hs = Vec::with_capacity(cfg.n_layer);
        let mut prev_xs = Vec::with_capacity(cfg.n_layer);
        for _i in 0..cfg.n_layer {
            let h = Tensor::zeros((batch_size, cfg.d_inner(), D_STATE), DType::F32, device)?;
            let x = Tensor::zeros((batch_size, cfg.d_inner()), DType::F32, device)?;
            hs.push(h);
            prev_xs.push([x.clone(), x.clone(), x.clone(), x.clone()]);
        }
        Ok(Self {
            hs,
            prev_xs,
            pos: 0,
        })
    }
}

#[derive(Clone, Debug)]
pub struct MambaBlock {
    in_proj: Linear,
    conv1d_bias: Tensor,
    conv1d_weights: [Tensor; D_CONV],
    x_proj: Linear,
    dt_proj: Linear,
    a_log: Tensor,
    d: Tensor,
    out_proj: Linear,
    dt_rank: usize,
    layer_index: usize,
    d_inner: usize,
}

impl MambaBlock {
    pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let d_inner = cfg.d_inner();
        let dt_rank = cfg.dt_rank();
        let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
        let x_proj = linear_no_bias(d_inner, dt_rank + D_STATE * 2, vb.pp("x_proj"))?;
        let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
        let a_log = vb.get((d_inner, D_STATE), "A_log")?;
        let d = vb.get(d_inner, "D")?;
        let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
        let conv1d_bias = vb.get(d_inner, "conv1d.bias")?;
        let conv1d_weight = vb.get((d_inner, 1, D_CONV), "conv1d.weight")?;
        let conv1d_weights = [
            conv1d_weight.i((.., 0, 0))?,
            conv1d_weight.i((.., 0, 1))?,
            conv1d_weight.i((.., 0, 2))?,
            conv1d_weight.i((.., 0, 3))?,
        ];
        Ok(Self {
            in_proj,
            conv1d_bias,
            conv1d_weights,
            x_proj,
            dt_proj,
            a_log,
            d,
            out_proj,
            dt_rank,
            layer_index,
            d_inner,
        })
    }

    pub fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        let (b_sz, _dim) = xs.dims2()?;
        let li = self.layer_index;
        let mut xs = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
        let proj_for_silu = xs.remove(1);
        state.prev_xs[li][state.pos % D_CONV] = xs.remove(0);
        let mut proj_for_conv = self.conv1d_bias.broadcast_as((b_sz, self.d_inner))?;
        for d_c in 0..D_CONV {
            proj_for_conv = (proj_for_conv
                + self.conv1d_weights[d_c]
                    .broadcast_mul(&state.prev_xs[li][(d_c + 1 + state.pos) % D_CONV])?)?;
        }
        let proj_for_conv = candle_nn::ops::silu(&proj_for_conv)?;
        // SSM + Selection, we're doing inference here so only need the last step of
        // the sequence.
        // Algorithm 3.2 on page 6, https://arxiv.org/pdf/2312.00752.pdf

        let x_proj = self.x_proj.forward(&proj_for_conv)?;
        let delta = x_proj.narrow(D::Minus1, 0, self.dt_rank)?;
        let b = x_proj.narrow(D::Minus1, self.dt_rank, D_STATE)?;
        let c = x_proj.narrow(D::Minus1, self.dt_rank + D_STATE, D_STATE)?;

        let delta = delta.apply(&self.dt_proj)?;
        // softplus
        let delta = (delta.exp()? + 1.)?.log()?;
        let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
        let d = self.d.to_dtype(candle::DType::F32)?;

        // Selective scan part
        // Eqn (2a), page 3, h_t = Ab h_{t-1} + Bb x_t
        let delta = delta
            .unsqueeze(D::Minus1)?
            .broadcast_as((b_sz, self.d_inner, D_STATE))?;
        let a = a.broadcast_as((b_sz, self.d_inner, D_STATE))?;
        let b = b.broadcast_as((b_sz, self.d_inner, D_STATE))?;
        let proj_for_conv_b =
            proj_for_conv
                .unsqueeze(D::Minus1)?
                .broadcast_as((b_sz, self.d_inner, D_STATE))?;
        state.hs[li] = ((&state.hs[li] * (&delta * &a)?.exp()?)? + &delta * &b * &proj_for_conv_b)?;
        let ss = (state.hs[li]
            .matmul(&c.unsqueeze(D::Minus1)?)?
            .squeeze(D::Minus1)?
            + proj_for_conv.broadcast_mul(&d)?)?;

        let ys = (ss * candle_nn::ops::silu(&proj_for_silu))?;
        ys.apply(&self.out_proj)
    }
}

#[derive(Clone, Debug)]
pub struct ResidualBlock {
    mixer: MambaBlock,
    norm: RmsNorm,
}

impl ResidualBlock {
    pub fn new(layer_index: usize, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
        let mixer = MambaBlock::new(layer_index, cfg, vb.pp("mixer"))?;
        Ok(Self { mixer, norm })
    }

    fn forward(&self, xs: &Tensor, state: &mut State) -> Result<Tensor> {
        self.mixer.forward(&xs.apply(&self.norm)?, state)? + xs
    }
}

// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
#[derive(Clone, Debug)]
pub struct Model {
    embedding: candle_nn::Embedding,
    layers: Vec<ResidualBlock>,
    norm_f: RmsNorm,
    lm_head: Linear,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
        let mut layers = Vec::with_capacity(cfg.n_layer);
        let vb_l = vb.pp("layers");
        for layer_idx in 0..cfg.n_layer {
            let layer = ResidualBlock::new(layer_idx, cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
        let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
        Ok(Self {
            embedding,
            layers,
            norm_f,
            lm_head,
        })
    }

    pub fn forward(&self, input_ids: &Tensor, state: &mut State) -> Result<Tensor> {
        let _b_size = input_ids.dims1()?;
        let mut xs = self.embedding.forward(input_ids)?;
        for layer in self.layers.iter() {
            xs = layer.forward(&xs, state)?
        }
        state.pos += 1;
        xs.apply(&self.norm_f)?.apply(&self.lm_head)
    }
}
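The inline comments in MambaBlock::forward above refer to two formulas; written out, with all products taken elementwise over the (d_inner, D_STATE) grid, the single-step update the code performs at inference time is roughly:

    \Delta = \mathrm{softplus}(W_\Delta x) = \log(1 + e^{W_\Delta x}), \qquad A = -\exp(A_{\log})
    h_t = \exp(\Delta \odot A) \odot h_{t-1} + (\Delta \odot B) \odot x_t
    y_t = C^\top h_t + D \odot x_t

The result y_t is then gated by the SiLU of the second half of the in_proj output (proj_for_silu) before out_proj is applied.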
@@ -8,7 +8,7 @@ use serde::Deserialize;

const MAX_SEQ_LEN: usize = 4096;

-// https://huggingface.co/microsoft/phi-1_5/blob/d38e6f954ec29b96fe2cf033937dad64e279b5d9/configuration_mixformer_sequential.py
+// https://huggingface.co/microsoft/phi-1_5/blob/main/configuration_mixformer_sequential.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub(crate) vocab_size: usize,
@ -1,333 +0,0 @@
|
|||||||
//! MobileOne inference implementation based on timm and candle-repvgg
|
|
||||||
//!
|
|
||||||
//! See "MobileOne: An Improved One millisecond Mobile Backbone"
|
|
||||||
//! https://arxiv.org/abs/2206.04040
|
|
||||||
|
|
||||||
use candle::{DType, Result, Tensor, D};
|
|
||||||
use candle_nn::{
|
|
||||||
batch_norm, conv2d, conv2d_no_bias, linear, ops::sigmoid, BatchNorm, Conv2d, Conv2dConfig,
|
|
||||||
Func, VarBuilder,
|
|
||||||
};
|
|
||||||
|
|
||||||
struct StageConfig {
|
|
||||||
blocks: usize,
|
|
||||||
channels: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
// The architecture in the paper has 6 stages. The timm implementation uses an equivalent form
|
|
||||||
// by concatenating the 5th stage (starts with stride 1) to the previous one.
|
|
||||||
const STAGES: [StageConfig; 5] = [
|
|
||||||
StageConfig {
|
|
||||||
blocks: 1,
|
|
||||||
channels: 64,
|
|
||||||
},
|
|
||||||
StageConfig {
|
|
||||||
blocks: 2,
|
|
||||||
channels: 64,
|
|
||||||
},
|
|
||||||
StageConfig {
|
|
||||||
blocks: 8,
|
|
||||||
channels: 128,
|
|
||||||
},
|
|
||||||
StageConfig {
|
|
||||||
blocks: 10,
|
|
||||||
channels: 256,
|
|
||||||
},
|
|
||||||
StageConfig {
|
|
||||||
blocks: 1,
|
|
||||||
channels: 512,
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct Config {
|
|
||||||
/// overparameterization factor
|
|
||||||
k: usize,
|
|
||||||
/// per-stage channel number multipliers
|
|
||||||
alphas: [f32; 5],
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Config {
|
|
||||||
pub fn s0() -> Self {
|
|
||||||
Self {
|
|
||||||
k: 4,
|
|
||||||
alphas: [0.75, 0.75, 1.0, 1.0, 2.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pub fn s1() -> Self {
|
|
||||||
Self {
|
|
||||||
k: 1,
|
|
||||||
alphas: [1.5, 1.5, 1.5, 2.0, 2.5],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pub fn s2() -> Self {
|
|
||||||
Self {
|
|
||||||
k: 1,
|
|
||||||
alphas: [1.5, 1.5, 2.0, 2.5, 4.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pub fn s3() -> Self {
|
|
||||||
Self {
|
|
||||||
k: 1,
|
|
||||||
alphas: [2.0, 2.0, 2.5, 3.0, 4.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pub fn s4() -> Self {
|
|
||||||
Self {
|
|
||||||
k: 1,
|
|
||||||
alphas: [3.0, 3.0, 3.5, 3.5, 4.0],
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SE blocks are used in the last stages of the s4 variant.
|
|
||||||
fn squeeze_and_excitation(
|
|
||||||
in_channels: usize,
|
|
||||||
squeeze_channels: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Func<'static>> {
|
|
||||||
let conv2d_cfg = Conv2dConfig {
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
let fc1 = conv2d(in_channels, squeeze_channels, 1, conv2d_cfg, vb.pp("fc1"))?;
|
|
||||||
let fc2 = conv2d(squeeze_channels, in_channels, 1, conv2d_cfg, vb.pp("fc2"))?;
|
|
||||||
|
|
||||||
Ok(Func::new(move |xs| {
|
|
||||||
let residual = xs;
|
|
||||||
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
|
|
||||||
let xs = sigmoid(&xs.apply(&fc1)?.relu()?.apply(&fc2)?)?;
|
|
||||||
|
|
||||||
residual.broadcast_mul(&xs)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
// fuses a convolutional kernel and a batchnorm layer into a convolutional layer
|
|
||||||
// based on the _fuse_bn_tensor method in timm
|
|
||||||
// see https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L602
|
|
||||||
fn fuse_conv_bn(weights: &Tensor, bn: BatchNorm) -> Result<(Tensor, Tensor)> {
|
|
||||||
let (gamma, beta) = bn.weight_and_bias().unwrap();
|
|
||||||
let mu = bn.running_mean();
|
|
||||||
let sigma = (bn.running_var() + bn.eps())?.sqrt();
|
|
||||||
let gps = (gamma / sigma)?;
|
|
||||||
let bias = (beta - mu * &gps)?;
|
|
||||||
let weights = weights.broadcast_mul(&gps.reshape(((), 1, 1, 1))?)?;
|
|
||||||
|
|
||||||
Ok((weights, bias))
|
|
||||||
}
|
|
||||||
|
|
||||||
// A mobileone block has a different training time and inference time architecture.
|
|
||||||
// The latter is a simple and efficient equivalent transformation of the former
|
|
||||||
// realized by a structural reparameterization technique, where convolutions
|
|
||||||
// along with identity branches and batchnorm layers are fused into a single convolution.
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
fn mobileone_block(
|
|
||||||
has_identity: bool,
|
|
||||||
k: usize,
|
|
||||||
dim: usize,
|
|
||||||
stride: usize,
|
|
||||||
padding: usize,
|
|
||||||
groups: usize,
|
|
||||||
kernel: usize,
|
|
||||||
in_channels: usize,
|
|
||||||
out_channels: usize,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Func<'static>> {
|
|
||||||
let conv2d_cfg = Conv2dConfig {
|
|
||||||
stride,
|
|
||||||
padding,
|
|
||||||
groups,
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut w = Tensor::zeros(
|
|
||||||
(out_channels, in_channels / groups, kernel, kernel),
|
|
||||||
DType::F32,
|
|
||||||
vb.device(),
|
|
||||||
)?;
|
|
||||||
let mut b = Tensor::zeros(dim, DType::F32, vb.device())?;
|
|
||||||
|
|
||||||
// k is the training-time overparameterization factor, larger than 1 only in the s0 variant
|
|
||||||
for i in 0..k {
|
|
||||||
let conv_kxk_bn = batch_norm(dim, 1e-5, vb.pp(format!("conv_kxk.{i}.bn")))?;
|
|
||||||
let conv_kxk = conv2d_no_bias(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
kernel,
|
|
||||||
conv2d_cfg,
|
|
||||||
vb.pp(format!("conv_kxk.{i}.conv")),
|
|
||||||
)?;
|
|
||||||
let (wk, bk) = fuse_conv_bn(conv_kxk.weight(), conv_kxk_bn)?;
|
|
||||||
w = (w + wk)?;
|
|
||||||
b = (b + bk)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
if kernel > 1 {
|
|
||||||
let conv_scale_bn = batch_norm(dim, 1e-5, vb.pp("conv_scale.bn"))?;
|
|
||||||
let conv_scale = conv2d_no_bias(
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
1,
|
|
||||||
conv2d_cfg,
|
|
||||||
vb.pp("conv_scale.conv"),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let (mut ws, bs) = fuse_conv_bn(conv_scale.weight(), conv_scale_bn)?;
|
|
||||||
// resize to 3x3
|
|
||||||
ws = ws.pad_with_zeros(D::Minus1, 1, 1)?;
|
|
||||||
ws = ws.pad_with_zeros(D::Minus2, 1, 1)?;
|
|
||||||
|
|
||||||
w = (w + ws)?;
|
|
||||||
b = (b + bs)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Use SE blocks if present (last layers of the s4 variant)
|
|
||||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("attn"));
|
|
||||||
|
|
||||||
// read and reparameterize the identity bn into wi and bi
|
|
||||||
if has_identity {
|
|
||||||
let identity_bn = batch_norm(dim, 1e-5, vb.pp("identity"))?;
|
|
||||||
|
|
||||||
let mut weights: Vec<f32> = vec![0.0; w.elem_count()];
|
|
||||||
|
|
||||||
let id = in_channels / groups;
|
|
||||||
// See https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/byobnet.py#L809
|
|
||||||
for i in 0..in_channels {
|
|
||||||
if kernel > 1 {
|
|
||||||
weights[i * kernel * kernel + 4] = 1.0;
|
|
||||||
} else {
|
|
||||||
weights[i * (id + 1)] = 1.0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let weights = &Tensor::from_vec(weights, w.shape(), w.device())?;
|
|
||||||
let (wi, bi) = fuse_conv_bn(weights, identity_bn)?;
|
|
||||||
|
|
||||||
w = (w + wi)?;
|
|
||||||
b = (b + bi)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let reparam_conv = Conv2d::new(w, Some(b), conv2d_cfg);
|
|
||||||
|
|
||||||
Ok(Func::new(move |xs| {
|
|
||||||
let mut xs = xs.apply(&reparam_conv)?;
|
|
||||||
if let Ok(f) = &se {
|
|
||||||
xs = xs.apply(f)?;
|
|
||||||
}
|
|
||||||
xs = xs.relu()?;
|
|
||||||
Ok(xs)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the number of output channels per stage taking into account the multipliers
|
|
||||||
fn output_channels_per_stage(cfg: &Config, stage: usize) -> usize {
|
|
||||||
let channels = STAGES[stage].channels as f32;
|
|
||||||
let alpha = cfg.alphas[stage];
|
|
||||||
|
|
||||||
match stage {
|
|
||||||
0 => std::cmp::min(64, (channels * alpha) as usize),
|
|
||||||
_ => (channels * alpha) as usize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Each stage is made of blocks. The first layer always downsamples with stride 2.
|
|
||||||
// All but the first block have a residual connection.
|
|
||||||
fn mobileone_stage(cfg: &Config, idx: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
|
||||||
let nblocks = STAGES[idx].blocks;
|
|
||||||
let mut blocks = Vec::with_capacity(nblocks);
|
|
||||||
|
|
||||||
let mut in_channels = output_channels_per_stage(cfg, idx - 1);
|
|
||||||
|
|
||||||
for block_idx in 0..nblocks {
|
|
||||||
let out_channels = output_channels_per_stage(cfg, idx);
|
|
||||||
let (has_identity, stride) = if block_idx == 0 {
|
|
||||||
(false, 2)
|
|
||||||
} else {
|
|
||||||
(true, 1)
|
|
||||||
};
|
|
||||||
|
|
||||||
// depthwise convolution layer
|
|
||||||
blocks.push(mobileone_block(
|
|
||||||
has_identity,
|
|
||||||
cfg.k,
|
|
||||||
in_channels,
|
|
||||||
stride,
|
|
||||||
1,
|
|
||||||
in_channels,
|
|
||||||
3,
|
|
||||||
in_channels,
|
|
||||||
in_channels,
|
|
||||||
vb.pp(block_idx * 2),
|
|
||||||
)?);
|
|
||||||
|
|
||||||
// pointwise convolution layer
|
|
||||||
blocks.push(mobileone_block(
|
|
||||||
has_identity,
|
|
||||||
cfg.k,
|
|
||||||
out_channels,
|
|
||||||
1, // stride
|
|
||||||
0, // padding
|
|
||||||
1, // groups
|
|
||||||
1, // kernel
|
|
||||||
in_channels,
|
|
||||||
out_channels,
|
|
||||||
vb.pp(block_idx * 2 + 1),
|
|
||||||
)?);
|
|
||||||
|
|
||||||
in_channels = out_channels;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Func::new(move |xs| {
|
|
||||||
let mut xs = xs.clone();
|
|
||||||
for block in blocks.iter() {
|
|
||||||
xs = xs.apply(block)?
|
|
||||||
}
|
|
||||||
Ok(xs)
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build a mobileone model for a given configuration.
fn mobileone_model(
    config: &Config,
    nclasses: Option<usize>,
    vb: VarBuilder,
) -> Result<Func<'static>> {
    let cls = match nclasses {
        None => None,
        Some(nclasses) => {
            let outputs = output_channels_per_stage(config, 4);
            let linear = linear(outputs, nclasses, vb.pp("head.fc"))?;
            Some(linear)
        }
    };

    let stem_dim = output_channels_per_stage(config, 0);
    let stem = mobileone_block(false, 1, stem_dim, 2, 1, 1, 3, 3, stem_dim, vb.pp("stem"))?;
    let vb = vb.pp("stages");
    let stage1 = mobileone_stage(config, 1, vb.pp(0))?;
    let stage2 = mobileone_stage(config, 2, vb.pp(1))?;
    let stage3 = mobileone_stage(config, 3, vb.pp(2))?;
    let stage4 = mobileone_stage(config, 4, vb.pp(3))?;

    Ok(Func::new(move |xs| {
        let xs = xs
            .apply(&stem)?
            .apply(&stage1)?
            .apply(&stage2)?
            .apply(&stage3)?
            .apply(&stage4)?
            .mean(D::Minus2)?
            .mean(D::Minus1)?;
        match &cls {
            None => Ok(xs),
            Some(cls) => xs.apply(cls),
        }
    }))
}

pub fn mobileone(cfg: &Config, nclasses: usize, vb: VarBuilder) -> Result<Func<'static>> {
    mobileone_model(cfg, Some(nclasses), vb)
}

pub fn mobileone_no_final_layer(cfg: &Config, vb: VarBuilder) -> Result<Func<'static>> {
    mobileone_model(cfg, None, vb)
}
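A rough usage sketch for the public constructors above. The safetensors path, the Config::s0() variant constructor, the crate path, and the 224x224 input shape are assumptions for illustration, not taken from this diff.

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::mobileone::{mobileone, Config};

fn run_classifier() -> candle::Result<Tensor> {
    let device = Device::Cpu;
    // Assumed: a MobileOne checkpoint stored as safetensors and an s0-style config.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mobileone_s0.safetensors"], DType::F32, &device)?
    };
    let model = mobileone(&Config::s0(), 1000, vb)?;
    // Assumed input layout: a single 224x224 RGB image, NCHW.
    let image = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    image.apply(&model) // logits of shape (1, 1000)
}
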
@@ -2,9 +2,7 @@ pub mod bert;
 pub mod bigcode;
 pub mod blip;
 pub mod blip_text;
-pub mod chatglm;
 pub mod convmixer;
-pub mod convnext;
 pub mod dinov2;
 pub mod distilbert;
 pub mod efficientnet;
@@ -13,12 +11,10 @@ pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;
 pub mod llama2_c_weights;
-pub mod mamba;
 pub mod marian;
 pub mod mistral;
 pub mod mixformer;
 pub mod mixtral;
-pub mod mobileone;
 pub mod mpt;
 pub mod persimmon;
 pub mod phi;
@@ -31,7 +27,6 @@ pub mod quantized_mixformer;
 pub mod quantized_mpt;
 pub mod quantized_stable_lm;
 pub mod quantized_t5;
-pub mod qwen2;
 pub mod repvgg;
 pub mod resnet;
 pub mod segment_anything;
@@ -16,7 +16,7 @@ struct RmsNorm {
 impl RmsNorm {
     fn new(scale: QTensor, eps: f32) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
-        let scale = scale.dequantize(&scale.device())?;
+        let scale = scale.dequantize(&Device::Cpu)?;
         let inner = candle_nn::LayerNorm::rms_norm(scale, eps as f64);
         Ok(Self { inner, span })
     }
@@ -275,17 +275,13 @@ pub struct ModelWeights {
     span_output: tracing::Span,
 }
 
-fn precomput_freqs_cis(
-    head_dim: usize,
-    freq_base: f32,
-    device: &Device,
-) -> Result<(Tensor, Tensor)> {
+fn precomput_freqs_cis(head_dim: usize, freq_base: f32) -> Result<(Tensor, Tensor)> {
     let theta: Vec<_> = (0..head_dim)
         .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
-    let theta = Tensor::new(theta.as_slice(), device)?;
-    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
+    let theta = Tensor::new(theta.as_slice(), &Device::Cpu)?;
+    let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, &Device::Cpu)?
         .to_dtype(DType::F32)?
         .reshape((MAX_SEQ_LEN, 1))?
         .matmul(&theta.reshape((1, theta.elem_count()))?)?;
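precomput_freqs_cis builds the rotary-embedding tables: per-pair frequencies 1 / base^(i/d) multiplied by every position up to MAX_SEQ_LEN, then passed through cos and sin. A minimal candle-free sketch of the same computation; the names and the example sizes are illustrative.

/// Sketch: RoPE cos/sin tables as plain nested vectors.
/// Row `p` holds the angles for position `p`; column `i` corresponds to the
/// i-th (even, odd) channel pair of a head of dimension `head_dim`.
fn rope_tables_sketch(
    head_dim: usize,
    freq_base: f32,
    max_seq_len: usize,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let theta: Vec<f32> = (0..head_dim)
        .step_by(2)
        .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
        .collect();
    let mut cos = Vec::with_capacity(max_seq_len);
    let mut sin = Vec::with_capacity(max_seq_len);
    for pos in 0..max_seq_len {
        cos.push(theta.iter().map(|t| (pos as f32 * t).cos()).collect());
        sin.push(theta.iter().map(|t| (pos as f32 * t).sin()).collect());
    }
    (cos, sin)
}
// e.g. rope_tables_sketch(128, 10000., 4096) mirrors the tensors built above.
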
@@ -296,10 +292,11 @@ fn precomput_freqs_cis(
 
 impl ModelWeights {
     pub fn from_ggml(mut ct: ggml_file::Content, gqa: usize) -> Result<Self> {
+        let cpu = &Device::Cpu;
         let head_dim = (ct.hparams.n_embd / ct.hparams.n_head) as usize;
-        let (cos, sin) = precomput_freqs_cis(head_dim, 10000., &ct.device)?;
+        let (cos, sin) = precomput_freqs_cis(head_dim, 10000.)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
-        let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
+        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
         let norm = RmsNorm::new(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
@@ -361,6 +358,7 @@ impl ModelWeights {
         reader: &mut R,
         device: &Device,
     ) -> Result<Self> {
+        let cpu = &Device::Cpu;
         let md_get = |s: &str| match ct.metadata.get(s) {
             None => candle::bail!("cannot find {s} in metadata"),
             Some(v) => Ok(v),
@@ -384,10 +382,10 @@ impl ModelWeights {
         let rope_freq_base = md_get("llama.rope.freq_base")
             .and_then(|m| m.to_f32())
             .unwrap_or(10000f32);
-        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
+        let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base)?;
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
-        let tok_embeddings = tok_embeddings.dequantize(device)?;
+        let tok_embeddings = tok_embeddings.dequantize(cpu)?;
         let norm = RmsNorm::new(
             ct.tensor(reader, "output_norm.weight", device)?,
             rms_norm_eps,
@@ -474,14 +472,14 @@ impl ModelWeights {
         })
     }
 
-    fn mask(&mut self, t: usize, device: &Device) -> Result<Tensor> {
+    fn mask(&mut self, t: usize) -> Result<Tensor> {
         if let Some(mask) = self.masks.get(&t) {
             Ok(mask.clone())
         } else {
             let mask: Vec<_> = (0..t)
                 .flat_map(|i| (0..t).map(move |j| u8::from(j > i)))
                 .collect();
-            let mask = Tensor::from_slice(&mask, (t, t), device)?;
+            let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?;
             self.masks.insert(t, mask.clone());
             Ok(mask)
         }
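The mask helper above caches a t x t matrix whose 1 entries mark the future positions a token must not attend to. A tiny standalone sketch of the same construction; the example output is illustrative.

/// Sketch: causal mask for a sequence of length t.
/// mask[i][j] == 1 means position i is not allowed to attend to position j.
fn causal_mask_sketch(t: usize) -> Vec<Vec<u8>> {
    (0..t)
        .map(|i| (0..t).map(|j| u8::from(j > i)).collect())
        .collect()
}
// causal_mask_sketch(3) == [[0, 1, 1],
//                           [0, 0, 1],
//                           [0, 0, 0]]
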
@@ -489,7 +487,7 @@ impl ModelWeights {
 
     pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
-        let mask = self.mask(seq_len, x.device())?;
+        let mask = self.mask(seq_len)?;
         let _enter = self.span.enter();
         let mut layer_in = self.tok_embeddings.forward(x)?;
         for layer in self.layers.iter_mut() {
@@ -1,4 +1,4 @@
-use crate::quantized_nn::{layer_norm, linear, linear_no_bias, Embedding, Linear};
+use crate::quantized_nn::{layer_norm, linear_no_bias, Embedding, Linear};
 pub use crate::quantized_var_builder::VarBuilder;
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm};
@@ -67,14 +67,9 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
         let num_kv_heads = cfg.num_key_value_heads;
-        let linear_layer = if cfg.use_qkv_bias {
-            linear
-        } else {
-            linear_no_bias
-        };
-        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
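On the left-hand side of this hunk, the Q/K/V projection constructor is picked from a config flag, relying on the fact that linear and linear_no_bias are plain function items with the same signature. A small standalone sketch of that pattern; the types are simplified placeholders, not the candle ones.

// Sketch: choosing between two constructors with the same signature via a
// function pointer, as the `use_qkv_bias` branch does.
struct Layer { in_dim: usize, out_dim: usize, bias: bool }

fn with_bias(in_dim: usize, out_dim: usize) -> Layer {
    Layer { in_dim, out_dim, bias: true }
}
fn without_bias(in_dim: usize, out_dim: usize) -> Layer {
    Layer { in_dim, out_dim, bias: false }
}

fn build_qkv(hidden: usize, use_qkv_bias: bool) -> (Layer, Layer, Layer) {
    let linear_layer: fn(usize, usize) -> Layer =
        if use_qkv_bias { with_bias } else { without_bias };
    (
        linear_layer(hidden, hidden), // q_proj
        linear_layer(hidden, hidden), // k_proj
        linear_layer(hidden, hidden), // v_proj
    )
}
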
@@ -1,377 +0,0 @@
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
pub struct Config {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub max_position_embeddings: usize,
    pub sliding_window: usize,
    pub max_window_layers: usize,
    pub tie_word_embeddings: bool,
    pub rope_theta: f64,
    pub rms_norm_eps: f64,
    pub use_sliding_window: bool,
    pub hidden_act: Activation,
}

#[derive(Debug, Clone)]
struct RmsNorm {
    inner: candle_nn::RmsNorm,
    span: tracing::Span,
}

impl RmsNorm {
    fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
        let inner = candle_nn::rms_norm(size, eps, vb)?;
        Ok(Self { inner, span })
    }
}

impl Module for RmsNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

fn rotate_half(xs: &Tensor) -> Result<Tensor> {
    let last_dim = xs.dim(D::Minus1)?;
    let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
    let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
    Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}

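rotate_half swaps the two halves of the last dimension and negates the half that moves to the front, which is what the rotary-embedding formula below relies on. A tiny worked example on a plain vector; the values are illustrative.

/// Sketch of rotate_half on a flat slice: [x1, x2, x3, x4] -> [-x3, -x4, x1, x2].
fn rotate_half_sketch(xs: &[f32]) -> Vec<f32> {
    let half = xs.len() / 2;
    let (first, second) = xs.split_at(half);
    second.iter().map(|v| -v).chain(first.iter().copied()).collect()
}
// rotate_half_sketch(&[1.0, 2.0, 3.0, 4.0]) == [-3.0, -4.0, 1.0, 2.0]
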
impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
        let dim = cfg.hidden_size / cfg.num_attention_heads;
        let max_seq_len = cfg.max_position_embeddings;
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
        let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
            .to_dtype(dtype)?
            .reshape((max_seq_len, 1))?;
        let freqs = t.matmul(&inv_freq)?;
        let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
        Ok(Self {
            sin: freqs.sin()?,
            cos: freqs.cos()?,
        })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        seqlen_offset: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
        let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
        let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
        let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
        let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
        Ok((q_embed, k_embed))
    }
}

#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
    act_fn: Activation,
}

impl MLP {
    fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let intermediate_sz = cfg.intermediate_size;
        let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
        let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
        let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
        Ok(Self {
            gate_proj,
            up_proj,
            down_proj,
            act_fn: cfg.hidden_act,
        })
    }
}

impl Module for MLP {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
        let rhs = xs.apply(&self.up_proj)?;
        (lhs * rhs)?.apply(&self.down_proj)
    }
}

#[derive(Debug, Clone)]
struct Attention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    num_heads: usize,
    num_kv_heads: usize,
    num_kv_groups: usize,
    head_dim: usize,
    hidden_size: usize,
    rotary_emb: Arc<RotaryEmbedding>,
    kv_cache: Option<(Tensor, Tensor)>,
}

impl Attention {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
        let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
        let num_kv_groups = num_heads / num_kv_heads;
        let head_dim = hidden_sz / num_heads;
        let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
        let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
        let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
        let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            num_heads,
            num_kv_heads,
            num_kv_groups,
            head_dim,
            hidden_size: hidden_sz,
            rotary_emb,
            kv_cache: None,
        })
    }

    fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
        let n_rep = self.num_kv_groups;
        if n_rep == 1 {
            Ok(xs)
        } else {
            let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
            xs.unsqueeze(2)?
                .expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
                .reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
        }
    }

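repeat_kv expands the grouped K/V heads so that each of the num_kv_groups query heads sharing a K/V head sees its own copy. A quick shape-only illustration; the head counts are made up for the example.

/// Sketch: with 32 query heads and 8 KV heads, each KV head is repeated
/// n_rep = 32 / 8 = 4 times along the head axis.
fn repeat_kv_shape(
    kv_shape: (usize, usize, usize, usize),
    n_rep: usize,
) -> (usize, usize, usize, usize) {
    let (b_sz, num_kv_heads, seq_len, head_dim) = kv_shape;
    (b_sz, num_kv_heads * n_rep, seq_len, head_dim)
}
// repeat_kv_shape((1, 8, 128, 64), 4) == (1, 32, 128, 64)
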
    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let (b_sz, q_len, _) = xs.dims3()?;

        let query_states = self.q_proj.forward(xs)?;
        let key_states = self.k_proj.forward(xs)?;
        let value_states = self.v_proj.forward(xs)?;

        let query_states = query_states
            .reshape((b_sz, q_len, self.num_heads, self.head_dim))?
            .transpose(1, 2)?;
        let key_states = key_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;
        let value_states = value_states
            .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
            .transpose(1, 2)?;

        let (query_states, key_states) =
            self.rotary_emb
                .apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;

        let (key_states, value_states) = match &self.kv_cache {
            None => (key_states, value_states),
            Some((prev_k, prev_v)) => {
                let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
                let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
                (key_states, value_states)
            }
        };
        self.kv_cache = Some((key_states.clone(), value_states.clone()));

        let key_states = self.repeat_kv(key_states)?.contiguous()?;
        let value_states = self.repeat_kv(value_states)?.contiguous()?;

        let attn_output = {
            let scale = 1f64 / f64::sqrt(self.head_dim as f64);
            let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;

            let attn_weights = match attention_mask {
                None => attn_weights,
                Some(mask) => attn_weights.broadcast_add(mask)?,
            };
            let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
            attn_weights.matmul(&value_states)?
        };
        attn_output
            .transpose(1, 2)?
            .reshape((b_sz, q_len, self.hidden_size))?
            .apply(&self.o_proj)
    }

    fn clear_kv_cache(&mut self) {
        self.kv_cache = None
    }
}

#[derive(Debug, Clone)]
struct DecoderLayer {
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
}

impl DecoderLayer {
    fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
        let input_layernorm =
            RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
        let post_attention_layernorm = RmsNorm::new(
            cfg.hidden_size,
            cfg.rms_norm_eps,
            vb.pp("post_attention_layernorm"),
        )?;
        Ok(Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        })
    }

    fn forward(
        &mut self,
        xs: &Tensor,
        attention_mask: Option<&Tensor>,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let residual = xs;
        let xs = self.input_layernorm.forward(xs)?;
        let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
        let xs = (xs + residual)?;
        let residual = &xs;
        let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
        residual + xs
    }

    fn clear_kv_cache(&mut self) {
        self.self_attn.clear_kv_cache()
    }
}

#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
    layers: Vec<DecoderLayer>,
    norm: RmsNorm,
    lm_head: Linear,
    sliding_window: usize,
    device: Device,
    dtype: DType,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
            layers.push(layer)
        }
        let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
        let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
        Ok(Self {
            embed_tokens,
            layers,
            norm,
            lm_head,
            sliding_window: cfg.sliding_window,
            device: vb.device().clone(),
            dtype: vb.dtype(),
        })
    }

    fn prepare_decoder_attention_mask(
        &self,
        b_size: usize,
        tgt_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        // Sliding window mask?
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + self.sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let mut xs = self.embed_tokens.forward(input_ids)?;
        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
        }
        xs.narrow(1, seq_len - 1, 1)?
            .apply(&self.norm)?
            .apply(&self.lm_head)
    }

    pub fn clear_kv_cache(&mut self) {
        for layer in self.layers.iter_mut() {
            layer.clear_kv_cache()
        }
    }
}
@@ -1,11 +1,10 @@
-use crate::models::with_tracing::{linear, linear_no_bias, Linear};
+use crate::models::with_tracing::{linear_no_bias, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, LayerNorm, VarBuilder};
-use serde::Deserialize;
 use std::sync::Arc;
 
 // https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/configuration_stablelm_epoch.py
-#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[derive(Debug, Clone, PartialEq)]
 pub struct Config {
     pub(crate) vocab_size: usize,
     pub(crate) intermediate_size: usize,
@@ -19,10 +18,7 @@ pub struct Config {
     pub(crate) max_position_embeddings: usize,
     pub(crate) norm_eps: f64,
     pub(crate) use_cache: bool,
-    #[serde(default)]
-    pub(crate) use_qkv_bias: bool, // Used in StableLM-2
-    #[serde(default)]
-    pub(crate) use_flash_attn: bool, // Not in config.json
+    pub(crate) use_flash_attn: bool,
 }
 
 impl Config {
@@ -39,7 +35,6 @@ impl Config {
             rope_theta: 10_000.,
             max_position_embeddings: 4096,
             norm_eps: 1e-5,
-            use_qkv_bias: false,
             use_cache: true,
             use_flash_attn,
         }
@@ -56,10 +51,6 @@ impl Config {
     pub fn num_kv_groups(&self) -> usize {
         self.num_attention_heads / self.num_key_value_heads
     }
-
-    pub fn set_use_flash_attn(&mut self, use_flash_attn: bool) {
-        self.use_flash_attn = use_flash_attn
-    }
 }
 
 #[derive(Debug)]
@@ -188,15 +179,9 @@ impl Attention {
         let head_dim = cfg.head_dim();
         let num_heads = cfg.num_attention_heads;
        let num_kv_heads = cfg.num_key_value_heads;
-        let linear_layer = if cfg.use_qkv_bias {
-            linear
-        } else {
-            linear_no_bias
-        };
-
-        let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
-        let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
-        let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
+        let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
+        let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
+        let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
         let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
         Ok(Self {
             q_proj,
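The #[serde(default)] attributes on the left-hand side let fields such as use_qkv_bias be omitted from config.json and fall back to bool::default() (false). A self-contained sketch of that behaviour; the struct and the JSON string are illustrative, not the real StableLM config.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ConfigSketch {
    vocab_size: usize,
    #[serde(default)]
    use_qkv_bias: bool, // absent from the JSON -> defaults to false
}

fn main() -> Result<(), serde_json::Error> {
    let cfg: ConfigSketch = serde_json::from_str(r#"{ "vocab_size": 50304 }"#)?;
    assert!(!cfg.use_qkv_bias);
    println!("{cfg:?}");
    Ok(())
}
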
@@ -1,21 +1,15 @@
 use crate::models::vit::{Config, Embeddings, Encoder};
-use candle::{DType, Result, Tensor};
+use candle::{Result, Tensor};
 use candle_nn::{
     embedding, layer_norm, linear_no_bias, Embedding, LayerNorm, Linear, Module, VarBuilder,
 };
+use serde::Deserialize;
 
-fn default_tie_word_embeddings() -> bool {
-    true
-}
-fn default_use_learned_position_embeddings() -> bool {
-    true
-}
-
-#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct TrOCRConfig {
     pub vocab_size: usize,
     pub d_model: usize,
-    pub cross_attention_hidden_size: usize,
+    pub hidden_size: usize,
     pub decoder_layers: usize,
     pub decoder_attention_heads: usize,
     pub decoder_ffn_dim: usize,
@@ -29,14 +23,13 @@ pub struct TrOCRConfig {
     pub decoder_layerdrop: f64,
     pub use_cache: bool,
     pub scale_embedding: bool,
+    pub use_learned_position_embeddings: bool,
+    pub layernorm_embedding: bool,
     pub pad_token_id: usize,
     pub bos_token_id: usize,
     pub eos_token_id: u32,
+    pub num_attention_heads: usize,
     pub decoder_vocab_size: Option<usize>,
-    #[serde(default = "default_use_learned_position_embeddings")]
-    pub use_learned_position_embeddings: bool,
-    #[serde(default = "default_tie_word_embeddings")]
-    pub tie_word_embeddings: bool,
 }
 
 impl Default for TrOCRConfig {
@@ -44,7 +37,7 @@ impl Default for TrOCRConfig {
         Self {
             vocab_size: 50265,
             d_model: 1024,
-            cross_attention_hidden_size: 768,
+            hidden_size: 768,
             decoder_layers: 12,
             decoder_attention_heads: 16,
             decoder_ffn_dim: 4096,
@@ -58,12 +51,13 @@ impl Default for TrOCRConfig {
             decoder_layerdrop: 0.0,
             use_cache: true,
             scale_embedding: false,
+            use_learned_position_embeddings: true,
+            layernorm_embedding: true,
             pad_token_id: 1,
             bos_token_id: 0,
             eos_token_id: 2,
+            num_attention_heads: 12,
             decoder_vocab_size: Some(50265),
-            use_learned_position_embeddings: true,
-            tie_word_embeddings: true,
         }
     }
 }
@@ -84,49 +78,17 @@ impl TrOCRLearnedPositionalEmbedding {
         Ok(Self { offset, weights })
     }
 
-    fn new_sinusoidal(vb: VarBuilder, cfg: &TrOCRConfig) -> Result<Self> {
-        // https://github.com/huggingface/transformers/blob/58e3d23e97078f361a533b9ec4a6a2de674ea52a/src/transformers/models/trocr/modeling_trocr.py#L81
-        let embedding_dim = cfg.d_model;
-        let half_dim = embedding_dim / 2;
-        let num_positions = cfg.max_position_embeddings + cfg.pad_token_id + 1;
-        let dev = vb.device();
-        let inv_freq: Vec<_> = (0..half_dim)
-            .map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32))
-            .collect();
-        let inv_freq_len = inv_freq.len();
-        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
-        let t = Tensor::arange(0u32, num_positions as u32, dev)?
-            .to_dtype(DType::F32)?
-            .reshape((num_positions, 1))?;
-        let freqs = t.matmul(&inv_freq)?;
-        let emb = Tensor::cat(&[freqs.sin()?, freqs.cos()?], 1)?;
-        let emb = Tensor::cat(
-            &[
-                emb.narrow(0, 0, cfg.pad_token_id)?,
-                Tensor::zeros((1, embedding_dim), DType::F32, dev)?,
-                emb.narrow(0, cfg.pad_token_id + 1, cfg.max_position_embeddings)?,
-            ],
-            0,
-        )?
-        .contiguous()?;
-        let emb = Embedding::new(emb, embedding_dim);
-        Ok(Self {
-            offset: cfg.pad_token_id + 1,
-            weights: emb,
-        })
-    }
-
     fn forward(&mut self, input_ids: &Tensor, past_key_values_length: u32) -> Result<Tensor> {
         let (b_sz, seq_len) = input_ids.dims2()?;
 
-        let positions = Tensor::arange(
+        let mut positions = Tensor::arange(
             past_key_values_length,
             seq_len as u32 + past_key_values_length,
             input_ids.device(),
         )?
         .expand((b_sz, seq_len))?;
 
-        let positions =
+        positions =
             positions.broadcast_add(&Tensor::new(self.offset as u32, input_ids.device())?)?;
         self.weights.forward(&positions)
     }
@@ -259,17 +221,19 @@ impl TrOCRDecoderLayer {
         let encoder_attn = TrOCRAttention::load(
             vb.pp("encoder_attn"),
             cfg,
-            Some(cfg.cross_attention_hidden_size),
-            Some(cfg.cross_attention_hidden_size),
+            Some(cfg.hidden_size),
+            Some(cfg.hidden_size),
         )?;
         let encoder_attn_layer_norm =
             layer_norm(embed_dim, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
         let fc1 = linear_no_bias(embed_dim, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
         let fc2 = linear_no_bias(cfg.decoder_ffn_dim, embed_dim, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(embed_dim, 1e-5, vb.pp("final_layer_norm"))?;
+        let activation_fn = candle_nn::Activation::Gelu;
+
         Ok(Self {
             self_attn,
-            activation_fn: cfg.activation_function,
+            activation_fn,
             self_attn_layer_norm,
             encoder_attn,
             encoder_attn_layer_norm,
@@ -330,11 +294,7 @@ impl TrOCRDecoder {
         let vb = vb.pp("decoder.model.decoder");
 
         let embed_tokens = embedding(cfg.vocab_size, cfg.d_model, vb.pp("embed_tokens"))?;
-        let embed_positions = if cfg.use_learned_position_embeddings {
-            TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?
-        } else {
-            TrOCRLearnedPositionalEmbedding::new_sinusoidal(vb.pp("embed_positions"), cfg)?
-        };
+        let embed_positions = TrOCRLearnedPositionalEmbedding::load(vb.pp("embed_positions"), cfg)?;
         let mut layers = Vec::with_capacity(cfg.decoder_layers);
         let vb_l = vb.pp("layers");
         for idx in 0..cfg.decoder_layers {
@@ -423,15 +383,8 @@ pub struct TrOCRForCausalLM {
 impl TrOCRForCausalLM {
     pub fn new(decoder_cfg: &TrOCRConfig, vb: VarBuilder) -> Result<Self> {
         let decoder = TrOCRDecoder::new(decoder_cfg, vb.clone())?;
-        let output_projection = if decoder_cfg.tie_word_embeddings {
-            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None)
-        } else {
-            candle_nn::linear_no_bias(
-                decoder_cfg.d_model,
-                decoder_cfg.vocab_size,
-                vb.pp("decoder.output_projection"),
-            )?
-        };
+        let output_projection =
+            candle_nn::Linear::new(decoder.embed_tokens.embeddings().clone(), None);
         Ok(Self {
             decoder,
             output_projection,
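In the right-hand version above, the decoder's output projection is always tied to the token-embedding matrix: the same (vocab_size, d_model) tensor serves as embedding lookup and, transposed, as the LM head. A minimal sketch of that tying with candle; tensor shapes are illustrative.

use candle::{Device, Result, Tensor};
use candle_nn::{Embedding, Linear, Module};

fn tied_head_sketch() -> Result<Tensor> {
    let dev = Device::Cpu;
    // Illustrative sizes: vocab_size = 8, d_model = 4.
    let weight = Tensor::randn(0f32, 1., (8, 4), &dev)?;
    let embed_tokens = Embedding::new(weight.clone(), 4);
    // Tied LM head: reuse the embedding matrix, no separate projection weights.
    let lm_head = Linear::new(embed_tokens.embeddings().clone(), None);
    let hidden = Tensor::randn(0f32, 1., (1, 3, 4), &dev)?; // (batch, seq, d_model)
    lm_head.forward(&hidden) // logits of shape (1, 3, 8)
}
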
Some files were not shown because too many files have changed in this diff.