Compare commits

..

1 Commits

Author SHA1 Message Date
a0282751d5 Tmp. 2023-12-11 19:51:46 +01:00
103 changed files with 1118 additions and 5625 deletions

View File

@ -63,7 +63,7 @@ This documents the main changes to the `candle` crate.
[760](https://github.com/huggingface/candle/pull/760).
- Add the Segment-Anything Model (SAM) as an example
[773](https://github.com/huggingface/candle/pull/773).
- TinyViT backbone for the segment anything example
- TinyViT backbone for the segemnt anything example
[787](https://github.com/huggingface/candle/pull/787).
- Shape with holes support
[770](https://github.com/huggingface/candle/pull/770).

View File

@ -19,7 +19,7 @@ exclude = [
resolver = "2"
[workspace.package]
version = "0.3.3"
version = "0.3.1"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@ -32,7 +32,6 @@ accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
@ -62,7 +61,7 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { version = "0.27.0", features = ["mps"]}
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
[profile.release-with-debug]
inherits = "release"

View File

@ -54,25 +54,19 @@ These online demos run entirely in your browser:
- [whisper](https://huggingface.co/spaces/lmz/candle-whisper): speech recognition.
- [LLaMA2](https://huggingface.co/spaces/lmz/candle-llama2): text generation.
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
We also provide a some command line based examples using state of the art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets.
- [Minimal Mamba](./candle-examples/examples/minimal-mamba/): a minimal
implementation of the Mamba state space model.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
better performance than all publicly available 13b models as of 2023-09-28.
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
experts 8x7b general LLM with better performance than a Llama 2 70B model with
much faster inference.
performance larger than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
@ -84,7 +78,7 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/quantized/assets/aoc.gif" width="600">
- [Stable Diffusion](./candle-examples/examples/stable-diffusion/): text to
image generative model, support for the 1.5, 2.1, SDXL 1.0 and Turbo versions.
image generative model, support for the 1.5, 2.1, and SDXL 1.0 versions.
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg" width="200">
@ -128,7 +122,7 @@ There are also some wasm examples for whisper and
[whisper](https://huggingface.co/spaces/lmz/candle-whisper),
[llama2](https://huggingface.co/spaces/lmz/candle-llama2),
[T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm),
[Phi-1.5, and Phi-2](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm),
[Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm).
For LLaMA2, run the following command to retrieve the weight files and start a
@ -147,10 +141,8 @@ And then head over to
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): A
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and
ergonomic LoRA implementation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found
[here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): Efficient and ergonomic LoRA implemenation for Candle. `candle-lora` has
out-of-the-box LoRA support for many models from Candle, which can be found [here](https://github.com/EricLBuehler/candle-lora/tree/master/candle-lora-transformers/examples).
- [`optimisers`](https://github.com/KGrewal1/optimisers): A collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-vllm`](https://github.com/EricLBuehler/candle-vllm): Efficient platform for inference and
@ -176,13 +168,11 @@ If you have an addition to this list, please submit a pull request.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1 and v2 with variants such as SOLAR-10.7B.
- LLaMA v1 and v2.
- Falcon.
- StarCoder.
- Phi 1, 1.5, and 2.
- Minimal Mamba
- Phi v1.5.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- Bert.
@ -190,9 +180,8 @@ If you have an addition to this list, please submit a pull request.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
- Zephyr 7b a and b (Mistral based).
- OpenChat 3.5 (Mistral based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).

View File

@ -11,11 +11,11 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View File

@ -12,8 +12,8 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.3", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.3", optional = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.1", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.1", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
@ -34,8 +34,6 @@ zip = { workspace = true }
[dev-dependencies]
anyhow = { workspace = true }
clap = { workspace = true }
criterion = { workspace = true }
[features]
default = []
@ -44,8 +42,3 @@ cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:metal", "dep:candle-metal-kernels"]
[[bench]]
name = "matmul"
harness = false

View File

@ -1,42 +0,0 @@
use candle_core::{DType, Device, Tensor};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};
use std::time::Instant;
fn run(a: &Tensor, b: &Tensor) {
a.matmul(&b.t().unwrap()).unwrap();
}
fn criterion_benchmark(c: &mut Criterion) {
let b = 1;
let m = 1;
let n = 2048;
let k = 2048;
let device = Device::new_metal(0).unwrap();
let dtype = DType::F32;
let lhs = Tensor::zeros((b, m, k), dtype, &device).unwrap();
let rhs = Tensor::zeros((b, n, k), dtype, &device).unwrap();
let flops = b * m * n * k;
let mut group = c.benchmark_group("matmul_metal");
group.throughput(Throughput::Bytes(flops as u64));
group.bench_function("iter", move |b| {
b.iter_custom(|iters| {
let start = Instant::now();
for _i in 0..iters {
run(black_box(&lhs), black_box(&rhs));
}
if let Device::Metal(device) = &device {
device.wait_until_completed().unwrap();
} else {
panic!("Expected metal device");
}
start.elapsed()
})
});
group.finish();
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@ -114,7 +114,7 @@ impl Tensor {
| Op::Unary(_node, UnaryOp::Round) => nodes,
Op::Reshape(node)
| Op::UpsampleNearest1D(node)
| Op::UpsampleNearest2D { arg: node, .. }
| Op::UpsampleNearest2D(node)
| Op::AvgPool2D { arg: node, .. }
| Op::MaxPool2D { arg: node, .. }
| Op::Copy(node)
@ -350,27 +350,9 @@ impl Tensor {
Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest1d",
})?,
Op::UpsampleNearest2D {
arg,
target_h,
target_w,
} => {
let (_n, c, h, w) = arg.dims4()?;
if target_h % h != 0 || target_w % w != 0 {
crate::bail!("backward not supported for non integer upscaling factors")
}
let scale_h = target_h / h;
let scale_w = target_w / w;
if scale_h != scale_w {
crate::bail!("backward not supported for non uniform upscaling factors")
};
let kernel =
Tensor::ones((c, 1, scale_h, scale_w), arg.dtype(), arg.device())?;
let conv_sum = grad.conv2d(&kernel, 0, scale_h, 1, c)?;
let sum_grad = grads.or_insert(arg)?;
*sum_grad = conv_sum;
}
Op::UpsampleNearest2D { .. } => Err(Error::BackwardNotSupported {
op: "upsample-nearest2d",
})?,
Op::SliceScatter0(lhs, rhs, start_rhs) => {
let rhs_sum_grad = grads.or_insert(rhs)?;
let rhs_grad = grad.narrow(0, *start_rhs, rhs.dim(0)?)?;

View File

@ -201,9 +201,10 @@ impl Device {
Ok(Storage::Cuda(storage))
}
}
Device::Metal(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Metal(storage))
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
crate::bail!("Metal rand_uniform not implemented")
}
}
}

View File

@ -64,7 +64,7 @@ impl Tensor {
#[derive(Debug)]
/// Generic structure used to index a slice of the tensor
pub enum TensorIndexer {
/// This selects the elements for which an index has some specific value.
/// This selects the elemnts for which an index has some specific value.
Select(usize),
/// This is a regular slice, purely indexing a chunk of the tensor
Narrow(Bound<usize>, Bound<usize>),

File diff suppressed because it is too large Load Diff

View File

@ -132,11 +132,7 @@ pub enum Op {
},
UpsampleNearest1D(Tensor),
UpsampleNearest2D {
arg: Tensor,
target_h: usize,
target_w: usize,
},
UpsampleNearest2D(Tensor),
Cat(Vec<Tensor>, usize),

View File

@ -353,7 +353,7 @@ pub(crate) fn vec_dot_q3k_q8k(n: usize, xs: &[BlockQ3K], ys: &[BlockQ8K]) -> Res
q3 = q3.add(32);
// Prepare low and high bits
// We hardcode the shifts here to avoid loading them into a separate register
// We hardcode the shifts here to avoid loading them into a seperate register
let q3l_0 = _mm256_and_si256(q3bits, m3);
let q3h_0 = if j == 0 {
_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, 0)), 0)
@ -586,7 +586,7 @@ pub(crate) fn vec_dot_q5k_q8k(n: usize, xs: &[BlockQ5K], ys: &[BlockQ8K]) -> Res
let q5bits = _mm256_loadu_si256(q5 as *const __m256i);
q5 = q5.add(32);
//Similar to q3k we hardcode the shifts here to avoid loading them into a separate register
//Similar to q3k we hardcode the shifts here to avoid loading them into a seperate register
let q5l_0 = _mm256_and_si256(q5bits, m4);
let q5l_0_shift_input = _mm256_and_si256(hbits, hmask);
let q5l_0_right_shift = match j {

View File

@ -41,7 +41,7 @@ impl VersionedMagic {
(Magic::Gguf, 1) => Self::GgufV1,
(Magic::Gguf, 2) => Self::GgufV2,
(Magic::Gguf, 3) => Self::GgufV3,
_ => crate::bail!("gguf: unsupported magic/version {magic:?}/{version}"),
_ => crate::bail!("ggml: unsupported magic/version {magic:?}/{version}"),
};
Ok(versioned_magic)
}
@ -463,7 +463,7 @@ impl Content {
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"),
None => crate::bail!("cannot find tensor-infor for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset)
}

View File

@ -1,4 +1,4 @@
//! Tensors are N-dimensional matrixes of elements using a single data type.
//! Tensors are N-dimenional matrixes of elements using a single data type.
#![allow(clippy::redundant_closure_call)]
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{
@ -361,16 +361,6 @@ impl Tensor {
Self::new_impl(array, shape, device, false)
}
/// Returns a new tensor with all the elements having the same specified value. Note that
/// the tensor is not contiguous so you would have to call `.contiguous()` on it if needed.
pub fn full<D: crate::WithDType, S: Into<Shape>>(
value: D,
shape: S,
device: &Device,
) -> Result<Self> {
Self::from_vec_impl(vec![value], (), device, false)?.broadcast_as(shape)
}
/// Creates a new 1D tensor from an iterator.
pub fn from_iter<D: crate::WithDType>(
iter: impl IntoIterator<Item = D>,
@ -396,7 +386,7 @@ impl Tensor {
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
bail!("step cannot be zero")
crate::bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
@ -679,7 +669,7 @@ impl Tensor {
}
/// Split a tensor into the specified number of chunks, this may return less chunks than
/// specified.
/// specificed.
pub fn chunk<D: Dim>(&self, chunks: usize, dim: D) -> Result<Vec<Self>> {
let dim = dim.to_index(self.shape(), "chunk")?;
let size = self.dim(dim)?;
@ -1004,11 +994,7 @@ impl Tensor {
/// tensor also has four dimensions, `(batch, channels, target_h, target_w)`.
pub fn interpolate2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
let (n, c, _h, _w) = self.dims4()?;
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest2D {
arg,
target_h,
target_w,
});
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
let storage = self
.storage()
.upsample_nearest2d(self.layout(), target_h, target_w)?;
@ -1041,9 +1027,6 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
if h < kernel_size.0 || w < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
}
// https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
@ -1079,9 +1062,6 @@ impl Tensor {
let kernel_size = kernel_size.to_usize2();
let stride = stride.to_usize2();
let (n, c, h, w) = self.dims4()?;
if h < kernel_size.0 || w < kernel_size.1 {
bail!("kernel-size {kernel_size:?} is larger than the input size {h},{w}")
}
// https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html#torch.nn.MaxPool2d
let h_out = (h - kernel_size.0) / stride.0 + 1;
let w_out = (w - kernel_size.1) / stride.1 + 1;
@ -1804,7 +1784,7 @@ impl Tensor {
let is_permutation =
dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
if !is_permutation {
bail!(
crate::bail!(
"dimension mismatch in permute, tensor {:?}, dims: {:?}",
self.dims(),
dims
@ -1883,7 +1863,10 @@ impl Tensor {
Storage::Metal(metal.storage_from_cpu_storage(storage)?)
}
(Storage::Cuda(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => Storage::Cpu(storage.to_cpu_storage()?),
(Storage::Metal(storage), Device::Cpu) => {
println!("{storage:?} - {:?}", storage.to_cpu_storage()?);
Storage::Cpu(storage.to_cpu_storage()?)
}
(Storage::Cuda(storage), Device::Cuda(cuda)) => {
// TODO: Avoid passing through the cpu storage here, especially if the gpu ids
// are the same.
@ -2299,7 +2282,7 @@ impl Tensor {
if left == 0 && right == 0 {
Ok(self.clone())
} else if self.elem_count() == 0 {
bail!("cannot use pad_with_same on an empty tensor")
crate::bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
@ -2463,13 +2446,13 @@ impl Tensor {
pub fn normalize_axis(&self, axis: i64) -> Result<usize> {
let rank = self.rank() as i64;
if rank <= axis {
bail!("axis {axis} is too large, tensor rank {rank}")
crate::bail!("axis {axis} is too large, tensor rank {rank}")
} else if 0 <= axis {
Ok(axis as usize)
} else {
let naxis = rank + axis;
if naxis < 0 {
bail!("axis {axis} is too small, tensor rank {rank}")
crate::bail!("axis {axis} is too small, tensor rank {rank}")
}
Ok(naxis as usize)
}
@ -2531,14 +2514,14 @@ impl Tensor {
let src_dims = src.dims();
let self_dims = self.dims();
if self_dims.len() != src_dims.len() {
bail!(
crate::bail!(
"slice-assign requires input with the same rank {} <> {}",
self_dims.len(),
src_dims.len()
)
}
if self_dims.len() != ranges.len() {
bail!(
crate::bail!(
"slice-assign requires input with the same rank as there are ranges {} <> {}",
self_dims.len(),
ranges.len()
@ -2558,16 +2541,18 @@ impl Tensor {
std::ops::Bound::Excluded(v) => *v,
};
if end_excluded <= start_included {
bail!("slice-assign: empty range for dim {i}, {start_included} {end_excluded}")
crate::bail!(
"slice-assign: empty range for dim {i}, {start_included} {end_excluded}"
)
}
if self_dims[i] < end_excluded {
bail!(
crate::bail!(
"slice-assign: upper bound is out of range for dim {i}, {end_excluded} {}",
self_dims[i]
)
}
if end_excluded - start_included != src_dims[i] {
bail!(
crate::bail!(
"slice-assign: the range for dim {i} ({start_included}..{end_excluded}) does not match the size of src {}", src_dims[i]
)
}
@ -2576,13 +2561,6 @@ impl Tensor {
}
mask.where_cond(/* on_true= */ &src, /* on_false= */ self)
}
/// Returns log(sum(exp(tensor), dim)).
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
}
macro_rules! bin_trait {

View File

@ -270,166 +270,6 @@ fn unary_grad(device: &Device) -> Result<()> {
[0.7358, 2.0000, 0.2707, 1.0000]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
// gradient should be
// row 1
// 1+2+7+8 = 18
// 3+4+9+10 = 26
// 5+6+11+12 = 34
// row 2
// 13+14+19+20 = 66
// 15+16+21+22 = 74
// 17+18+23+24 = 82
// row 3
// 25+26+31+32 = 114
// 27+28+33+34 = 122
// 29+30+35+36 = 130
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
[[18_f32, 26., 34.], [66., 74., 82.], [114., 122., 130.]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
let y = x.interpolate2d(6, 6)?.reshape(36)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04., 05., 06.,
07., 08., 09., 10., 11., 12.,
13., 14., 15., 16., 17., 18.,
19., 20., 21., 22., 23., 24.,
25., 26., 27., 28., 29., 30.,
31., 32., 33., 34., 35., 36.,
],
device,
)?;
// gradient should be
// row 1
// 1+2+3+7+8+9+13+14+15 = 72
// 4+5+6+10+11+12+16+17+18 = 99
// row 2
// 19+20+21+25+26+27+31+32+33 = 234
// 22+23+24+28+29+30+34+35+36 = 243
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec2_round(&grad_x.flatten(0, 2)?, 4)?,
[[72_f32, 99.], [234., 261.]]
);
// manually checked: see comments
let x = Var::new(&[[[[1f32, 2.], [4., 5.]], [[6f32, 7.], [8., 9.]]]], device)?;
let y = x.interpolate2d(4, 4)?.reshape(32)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
// manually checked: see comments
let x = Var::new(
&[[[[1f32, 2.], [4., 5.]]], [[[6f32, 7.], [8., 9.]]]],
device,
)?;
let y = x.interpolate2d(4, 4)?.reshape(32)?;
#[rustfmt::skip]
let z = Tensor::new(
&[
1_f32, 02., 03., 04.,
05., 06., 07., 08.,
09., 10., 11., 12.,
13., 14., 15., 16.,
17., 18., 19., 20.,
21., 22., 23., 24.,
25., 26., 27., 28.,
29., 30., 31., 32.
],
device,
)?;
// gradient should be
// m1r1
// 1+2+5+6=14
// 3+4+7+8=22
// m1r2
// 9+10+13+14=46
// 11+12+15+16=54
// m2r1
// 17+18+21+22=78
// 19+20+23+24=86
// m2r2
// 25+26+29+30=110
// 27+28+31+32=118
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
let grads = loss.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec3_round(&grad_x.flatten(0, 1)?, 4)?,
[[[14_f32, 22.], [46., 54.]], [[78., 86.], [110., 118.]]]
);
Ok(())
}

View File

@ -1,4 +1,4 @@
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor, D};
use candle_core::{test_device, test_utils, DType, Device, IndexOp, Result, Tensor};
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
@ -32,14 +32,6 @@ fn ones(device: &Device) -> Result<()> {
Ok(())
}
fn full(device: &Device) -> Result<()> {
assert_eq!(
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
[[42, 42, 42], [42, 42, 42]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
@ -1080,7 +1072,6 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@ -1230,26 +1221,3 @@ fn cumsum() -> Result<()> {
);
Ok(())
}
/// A helper function for floating point comparison. Both a and b must be 1D Tensor and contains the same amount of data.
/// Assertion passes if the difference of all pairs of a and b is smaller than epsilon.
fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
let a_vec: Vec<f64> = a.to_vec1()?;
let b_vec: Vec<f64> = b.to_vec1()?;
assert_eq!(a_vec.len(), b_vec.len());
for (a, b) in a_vec.iter().zip(b_vec.iter()) {
assert!((a - b).abs() < epsilon);
}
Ok(())
}
#[test]
fn logsumexp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.logsumexp(D::Minus1)?;
// The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?;
Ok(())
}

View File

@ -11,8 +11,8 @@ readme = "README.md"
[dependencies]
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
hf-hub = { workspace = true}
intel-mkl-src = { workspace = true, optional = true }
memmap2 = { workspace = true }

View File

@ -11,17 +11,14 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.3" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../candle-transformers", version = "0.3.3" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.3", optional = true }
csv = "1.3.0"
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.3.1" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../candle-transformers", version = "0.3.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-onnx = { path = "../candle-onnx", version = "0.3.1", optional = true }
cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true }
hf-hub = { workspace = true, features=["tokio"]}
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
@ -36,6 +33,7 @@ tokenizers = { workspace = true, features = ["onig"] }
anyhow = { workspace = true }
byteorder = { workspace = true }
clap = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
imageproc = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }

View File

@ -32,8 +32,6 @@ impl KernelDirectories {
if should_compile {
#[cfg(feature = "cuda")]
{
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let mut command = std::process::Command::new("nvcc");
let out_dir = ptx_file.parent().context("no parent for ptx file")?;
let include_dirs: Vec<String> =
@ -46,11 +44,6 @@ impl KernelDirectories {
.arg(format!("-I/{}", self.kernel_dir))
.args(include_dirs)
.arg(cu_file);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
let output = command
.spawn()
.context("failed spawning nvcc")?
@ -175,16 +168,8 @@ fn set_cuda_include_dir() -> Result<()> {
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute cap from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Grab compute cap from nvidia-smi
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
@ -200,7 +185,6 @@ fn compute_cap() -> Result<usize> {
.next()
.context("missing line in stdout")?
.replace('.', "");
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};

View File

@ -2,10 +2,10 @@
Bert is a general large language model. In this example it can be used for two
different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Bert is used to compute the sentence embeddings for a prompt. The model weights
@ -24,48 +24,6 @@ cargo run --example bert --release -- --prompt "Here is a test sentence"
> Tensor[[1, 7, 384], f32]
```
### Custom models
You can specify different models, such as BGE, with the `--model-id` flag:
```bash
cargo run --example bert --release -- \
--model-id BAAI/bge-large-zh-v1.5 \
--prompt "Here is a test sentence"
Loaded and encoded 435.70775ms
[[[ 3.0944e-1, -7.8455e-5, -1.2768e0, ..., 1.3755e-2, -3.2371e-1, 2.3819e-1],
[-2.8506e-1, 1.9953e-1, -1.3076e0, ..., 6.9819e-2, 1.0833e-2, -1.1512e0],
[ 3.9892e-1, 2.0000e-1, -9.3178e-1, ..., -4.1393e-1, -4.9644e-2, -3.3786e-1],
...
[ 6.0345e-1, 3.5744e-1, -1.2672e0, ..., -6.9165e-1, -3.4973e-3, -8.4214e-1],
[ 3.9218e-1, -3.2735e-1, -1.3123e0, ..., -4.9318e-1, -5.1334e-1, -3.6391e-1],
[ 3.0978e-1, 2.5662e-4, -1.2773e0, ..., 1.3357e-2, -3.2390e-1, 2.3858e-1]]]
Tensor[[1, 9, 1024], f32]
Took 176.744667ms
```
### Gelu approximation
You can get a speedup by using an approximation of the gelu activation, with a
small loss of precision, by passing the `--approximate-gelu` flag:
```bash
$ cargo run --example bert --release -- \
--model-id BAAI/bge-large-zh-v1.5 \
--prompt "Here is a test sentence" \
--approximate-gelu
Loaded and encoded 244.388042ms
[[[ 3.1048e-1, -6.0339e-4, -1.2758e0, ..., 1.3718e-2, -3.2362e-1, 2.3775e-1],
[-2.8354e-1, 1.9984e-1, -1.3077e0, ..., 6.9390e-2, 9.9681e-3, -1.1531e0],
[ 3.9947e-1, 1.9917e-1, -9.3178e-1, ..., -4.1301e-1, -5.0719e-2, -3.3955e-1],
...
[ 6.0499e-1, 3.5664e-1, -1.2642e0, ..., -6.9134e-1, -3.4581e-3, -8.4471e-1],
[ 3.9311e-1, -3.2812e-1, -1.3105e0, ..., -4.9291e-1, -5.1270e-1, -3.6543e-1],
[ 3.1082e-1, -2.6737e-4, -1.2762e0, ..., 1.3319e-2, -3.2381e-1, 2.3815e-1]]]
Tensor[[1, 9, 1024], f32]
Took 116.840791ms
```
## Similarities
In this example, Bert is used to compute the sentence embeddings for a set of

View File

@ -3,7 +3,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, HiddenAct, DTYPE};
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{Error as E, Result};
use candle::Tensor;
@ -45,10 +45,6 @@ struct Args {
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
/// Use tanh based approximation for Gelu instead of erf implementation.
#[arg(long, default_value = "false")]
approximate_gelu: bool,
}
impl Args {
@ -77,7 +73,7 @@ impl Args {
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let mut config: Config = serde_json::from_str(&config)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
@ -85,9 +81,6 @@ impl Args {
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
if self.approximate_gelu {
config.hidden_act = HiddenAct::GeluApproximate;
}
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}

View File

@ -165,7 +165,14 @@ fn main() -> Result<()> {
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = repo.get(rfilename)?;
filenames.push(filename);
}
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

View File

@ -13,7 +13,7 @@ extern crate accelerate_src;
extern crate intel_mkl_src;
use anyhow::{bail, Error as E, Result};
use clap::{Parser, ValueEnum};
use clap::Parser;
use candle::{DType, Tensor};
use candle_nn::VarBuilder;
@ -22,19 +22,11 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write;
use candle_transformers::models::llama as model;
use model::{Llama, LlamaConfig};
use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>";
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
V1,
V2,
#[value(name = "solar-10.7b")]
Solar10_7B,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -42,6 +34,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Use npy instead of safetensors
#[arg(long)]
npy: Option<String>,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
@ -80,13 +76,17 @@ struct Args {
#[arg(long)]
revision: Option<String>,
/// The model size to use.
#[arg(long, default_value = "v2")]
which: Which,
#[arg(long)]
v1: bool,
#[arg(long)]
use_flash_attn: bool,
/// The folder name that contains safetensor weights and json files
/// (same structure as huggingface online)
#[arg(long)]
local_weights: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.0)]
repeat_penalty: f32,
@ -118,29 +118,65 @@ fn main() -> Result<()> {
Some(dtype) => bail!("Unsupported dtype {dtype}"),
None => DType::F16,
};
let (llama, tokenizer_filename, cache) = {
let (llama, tokenizer_filename, cache) = match args.npy {
Some(filename) => {
let config = if args.v1 {
Config::config_7b_v1(args.use_flash_attn)
} else {
Config::config_7b_v2(args.use_flash_attn)
};
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_npz(filename, dtype, &device)?;
let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
(Llama::load(vb, &cache, &config)?, tokenizer, cache)
}
None => {
let api = Api::new()?;
let model_id = args.model_id.unwrap_or_else(|| match args.which {
Which::V1 => "Narsil/amall-7b".to_string(),
Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
let model_id = args.model_id.unwrap_or_else(|| {
if args.v1 {
"Narsil/amall-7b".to_string()
} else {
"meta-llama/Llama-2-7b-hf".to_string()
}
});
println!("loading the model weights from {model_id}");
let revision = args.revision.unwrap_or("main".to_string());
let api = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = api.get("tokenizer.json")?;
let config_filename = api.get("config.json")?;
let tokenizer_filename = match &args.local_weights {
Some(path) => (path.to_owned() + "tokenizer.json").into(),
_ => api.get("tokenizer.json")?,
};
let config_filename = match &args.local_weights {
Some(path) => (path.to_owned() + "config.json").into(),
_ => api.get("config.json")?,
};
let config: LlamaConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let config = config.into_config(args.use_flash_attn);
let filenames =
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
match &args.local_weights {
Some(path) => {
filenames.push((path.to_owned() + rfilename).into());
}
_ => {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
};
}
println!("building the model");
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
}
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);

View File

@ -143,7 +143,14 @@ fn main() -> Result<()> {
let config_filename = api.get("config.json")?;
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer_filename = api.get("tokenizer.json")?;
let filenames = candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let mut filenames = vec![];
for rfilename in [
"model-00001-of-00002.safetensors",
"model-00002-of-00002.safetensors",
] {
let filename = api.get(rfilename)?;
filenames.push(filename);
}
if args.rank.is_none() {
let children: Vec<_> = (0..args.num_shards)

View File

@ -1,12 +0,0 @@
# candle-mamba-minimal: minimal implementation of Mamba
This is based on [mamba-minimal](https://github.com/johnma2006/mamba-minimal).
## Running the example
```bash
$ cargo run --example mamba-minimal --release -- --prompt "Mamba is the"
Mamba is the most popular and best-selling game in the world. It has been downloaded more than 1,000 times by over 1 million people worldwide since its release on March 18th 2016.
The Mamba series of games are a collection that combines elements from all genres including action, adventure, strategy & puzzle games with some unique gameplay features such as stealth and survival. The game is also known for its innovative graphics and the ability to play in a variety of different modes like single player or multiplayer.
```

View File

@ -1,287 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
mod model;
use model::{Config, Model};
use candle::{DType, Device, Module, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for _ in 0..sample_len {
let input = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, ValueEnum, Clone, Copy, PartialEq, Eq, Debug)]
enum Which {
Mamba130m,
Mamba370m,
Mamba790m,
Mamba1_4b,
Mamba2_8b,
Mamba2_8bSlimPj,
}
impl std::fmt::Display for Which {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
impl Which {
fn model_id(&self) -> &'static str {
match self {
Self::Mamba130m => "state-spaces/mamba-130m",
Self::Mamba370m => "state-spaces/mamba-370m",
Self::Mamba790m => "state-spaces/mamba-790m",
Self::Mamba1_4b => "state-spaces/mamba-1.4b",
Self::Mamba2_8b => "state-spaces/mamba-2.8b",
Self::Mamba2_8bSlimPj => "state-spaces/mamba-2.8b-slimpj'",
}
}
fn revision(&self) -> &'static str {
match self {
Self::Mamba130m
| Self::Mamba370m
| Self::Mamba790m
| Self::Mamba1_4b
| Self::Mamba2_8bSlimPj => "refs/pr/1",
Self::Mamba2_8b => "refs/pr/4",
}
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
sample_len: usize,
#[arg(long, default_value = "mamba130m")]
which: Which,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
config_file: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id
.unwrap_or_else(|| args.which.model_id().to_string()),
RepoType::Model,
args.revision
.unwrap_or_else(|| args.which.revision().to_string()),
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("EleutherAI/gpt-neox-20b".to_string())
.get("tokenizer.json")?,
};
let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => {
vec![repo.get("model.safetensors")?]
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = Model::new(&config, vb.pp("backbone"))?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,204 +0,0 @@
/// This follows the lines of:
/// https://github.com/johnma2006/mamba-minimal/blob/master/model.py
/// Simple, minimal implementation of Mamba in one file of PyTorch.
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::{RmsNorm, VarBuilder};
use candle_transformers::models::with_tracing::{linear, linear_no_bias, Linear};
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
d_model: usize,
n_layer: usize,
vocab_size: usize,
pad_vocab_size_multiple: usize,
}
impl Config {
fn vocab_size(&self) -> usize {
let pad = self.pad_vocab_size_multiple;
(self.vocab_size + pad - 1) / pad * pad
}
fn dt_rank(&self) -> usize {
(self.d_model + 15) / 16
}
fn d_conv(&self) -> usize {
4
}
fn d_state(&self) -> usize {
16
}
fn d_inner(&self) -> usize {
self.d_model * 2
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L177
#[derive(Clone, Debug)]
pub struct MambaBlock {
in_proj: Linear,
conv1d: candle_nn::Conv1d,
x_proj: Linear,
dt_proj: Linear,
a_log: Tensor,
d: Tensor,
out_proj: Linear,
dt_rank: usize,
}
impl MambaBlock {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let d_inner = cfg.d_inner();
let d_conv = cfg.d_conv();
let d_state = cfg.d_state();
let dt_rank = cfg.dt_rank();
let in_proj = linear_no_bias(cfg.d_model, d_inner * 2, vb.pp("in_proj"))?;
let conv_cfg = candle_nn::Conv1dConfig {
groups: d_inner,
padding: d_conv - 1,
..Default::default()
};
let conv1d = candle_nn::conv1d(d_inner, d_inner, d_conv, conv_cfg, vb.pp("conv1d"))?;
let x_proj = linear_no_bias(d_inner, dt_rank + d_state * 2, vb.pp("x_proj"))?;
let dt_proj = linear(dt_rank, d_inner, vb.pp("dt_proj"))?;
let a_log = vb.get((d_inner, d_state), "A_log")?;
let d = vb.get(d_inner, "D")?;
let out_proj = linear_no_bias(d_inner, cfg.d_model, vb.pp("out_proj"))?;
Ok(Self {
in_proj,
conv1d,
x_proj,
dt_proj,
a_log,
d,
out_proj,
dt_rank,
})
}
fn ssm(&self, xs: &Tensor) -> Result<Tensor> {
let (_d_in, n) = self.a_log.dims2()?;
let a = self.a_log.to_dtype(candle::DType::F32)?.exp()?.neg()?;
let d = self.d.to_dtype(candle::DType::F32)?;
let x_dbl = xs.apply(&self.x_proj)?;
let delta = x_dbl.narrow(D::Minus1, 0, self.dt_rank)?;
let b = x_dbl.narrow(D::Minus1, self.dt_rank, n)?;
let c = x_dbl.narrow(D::Minus1, self.dt_rank + n, n)?;
let delta = delta.contiguous()?.apply(&self.dt_proj)?;
// softplus without threshold
let delta = (delta.exp()? + 1.)?.log()?;
let ss = selective_scan(xs, &delta, &a, &b, &c, &d)?;
Ok(ss)
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L275
fn selective_scan(
u: &Tensor,
delta: &Tensor,
a: &Tensor,
b: &Tensor,
c: &Tensor,
d: &Tensor,
) -> Result<Tensor> {
let (b_sz, l, d_in) = u.dims3()?;
let n = a.dim(1)?;
let delta = delta.t()?.reshape((b_sz, d_in, l, 1))?; // b d_in l 1
let delta_a = delta.broadcast_mul(&a.reshape((1, d_in, 1, n))?)?.exp()?;
let delta_b_u = delta
.broadcast_mul(&b.reshape((b_sz, 1, l, n))?)?
.broadcast_mul(&u.t()?.reshape((b_sz, d_in, l, 1))?)?;
let mut xs = Tensor::zeros((b_sz, d_in, n), delta_a.dtype(), delta_a.device())?;
let mut ys = Vec::with_capacity(l);
for i in 0..l {
xs = ((delta_a.i((.., .., i))? * xs)? + delta_b_u.i((.., .., i))?)?;
let y = xs.matmul(&c.i((.., i, ..))?.unsqueeze(2)?)?.squeeze(2)?;
ys.push(y)
}
let ys = Tensor::stack(ys.as_slice(), 1)?;
ys + u.broadcast_mul(d)
}
impl Module for MambaBlock {
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L206
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b_sz, seq_len, _dim) = xs.dims3()?;
let xs_and_res = xs.apply(&self.in_proj)?.chunk(2, D::Minus1)?;
let (xs, res) = (&xs_and_res[0], &xs_and_res[1]);
let xs = xs
.t()?
.apply(&self.conv1d)?
.narrow(D::Minus1, 0, seq_len)?
.t()?;
let xs = candle_nn::ops::silu(&xs)?;
let ys = (self.ssm(&xs)? * candle_nn::ops::silu(res))?;
ys.apply(&self.out_proj)
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L143
#[derive(Clone, Debug)]
pub struct ResidualBlock {
mixer: MambaBlock,
norm: RmsNorm,
}
impl ResidualBlock {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm"))?;
let mixer = MambaBlock::new(cfg, vb.pp("mixer"))?;
Ok(Self { mixer, norm })
}
}
impl Module for ResidualBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.norm)?.apply(&self.mixer)? + xs
}
}
// https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/model.py#L56
#[derive(Clone, Debug)]
pub struct Model {
embedding: candle_nn::Embedding,
layers: Vec<ResidualBlock>,
norm_f: RmsNorm,
lm_head: Linear,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let embedding = candle_nn::embedding(cfg.vocab_size(), cfg.d_model, vb.pp("embedding"))?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = ResidualBlock::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm_f = candle_nn::rms_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
let lm_head = Linear::from_weights(embedding.embeddings().clone(), None);
Ok(Self {
embedding,
layers,
norm_f,
lm_head,
})
}
}
impl Module for Model {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
let mut xs = self.embedding.forward(input_ids)?;
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm_f)?
.apply(&self.lm_head)
}
}

View File

@ -155,8 +155,8 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "lmz/candle-mistral")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
@ -207,18 +207,8 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
if args.quantized {
"lmz/candle-mistral".to_string()
} else {
"mistralai/Mistral-7B-v0.1".to_string()
}
}
};
let repo = api.repo(Repo::with_revision(
model_id,
args.model_id,
RepoType::Model,
args.revision,
));
@ -235,7 +225,10 @@ fn main() -> Result<()> {
if args.quantized {
vec![repo.get("model-q4k.gguf")?]
} else {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
vec![
repo.get("pytorch_model-00001-of-00002.safetensors")?,
repo.get("pytorch_model-00002-of-00002.safetensors")?,
]
}
}
};

View File

@ -1,25 +0,0 @@
# candle-mixtral: 8x7b LLM using a sparse mixture of experts.
Mixtral-8x7B-v0.1 is a pretrained generative LLM with 56 billion parameters.
- [Blog post](https://mistral.ai/news/mixtral-of-experts/) from Mistral announcing the model release.
- [Model card](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) on the HuggingFace Hub.
## Running the example
```bash
$ cargo run --example mixtral --release -- --prompt "def print_prime(n): "
def print_prime(n): # n is the number of prime numbers to be printed
i = 2
count = 0
while (count < n):
if (isPrime(i)):
print(i)
count += 1
i += 1
def isPrime(n):
for x in range(2, int(n**0.5)+1):
if (n % x == 0):
...
```

View File

@ -1,241 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mixtral::{Config, Model};
use candle::{DType, Device, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
struct TextGeneration {
model: Model,
device: Device,
tokenizer: TokenOutputStream,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer: TokenOutputStream::new(tokenizer),
logits_processor,
repeat_penalty,
repeat_last_n,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
self.tokenizer.clear();
let mut tokens = self
.tokenizer
.tokenizer()
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
for &t in tokens.iter() {
if let Some(t) = self.tokenizer.next_token(t)? {
print!("{t}")
}
}
std::io::stdout().flush()?;
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_token("</s>") {
Some(token) => token,
None => anyhow::bail!("cannot find the </s> token"),
};
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let start_pos = tokens.len().saturating_sub(context_size);
let ctxt = &tokens[start_pos..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input, start_pos)?;
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
if let Some(t) = self.tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
let dt = start_gen.elapsed();
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
std::io::stdout().flush()?;
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
use_flash_attn: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long, default_value = "mistralai/Mixtral-8x7B-v0.1")]
model_id: String,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.1)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::v0_1_8x7b(args.use_flash_attn);
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,33 +1,14 @@
# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
# candle-phi: 1.3b LLM with state of the art performance for <10b models.
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
only 1.3 and 2.7 billion parameters but with state of the art performance compared to
[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
only 1.3 billion parameters but with state of the art performance compared to
models with up to 10 billion parameters.
The candle implementation provides both the standard version as well as a
quantized variant.
## Running some examples
## Running some example
For the v2 version.
```bash
$ cargo run --example phi --release -- --model 2 \
--prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?"
A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?
Solution:
The potential energy of the skier is converted into kinetic energy as it slides down the slope. The formula for potential energy is mgh, where m is mass, g is acceleration due to gravity (9.8 m/s^2), and h is height. Since there's no friction, all the potential energy is converted into kinetic energy at the bottom of the slope. The formula for kinetic energy is 1/2mv^2, where v is velocity. We can equate these two formulas:
mgh = 1/2mv^2
Solving for v, we get:
v = sqrt(2gh)
Substituting the given values, we get:
v = sqrt(2*9.8*40) = 28 m/s
Therefore, the skier speed at the bottom of the slope is 28 m/s.
```
For the v1.5 version.
```bash
$ cargo run --example phi --release -- --prompt "def print_prime(n): "

View File

@ -123,8 +123,6 @@ enum WhichModel {
V1,
#[value(name = "1.5")]
V1_5,
#[value(name = "2")]
V2,
PuffinPhiV2,
PhiHermes,
}
@ -145,10 +143,7 @@ struct Args {
verbose_prompt: bool,
#[arg(long)]
prompt: Option<String>,
#[arg(long)]
mmlu_dir: Option<String>,
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
@ -163,7 +158,7 @@ struct Args {
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 5000)]
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long)]
@ -230,7 +225,6 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::V2 => "microsoft/phi-2".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
@ -247,9 +241,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::V2 | WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"main".to_string()
}
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
}
}
}
@ -258,32 +250,27 @@ fn main() -> Result<()> {
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => repo.get("tokenizer.json")?,
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
},
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => {
if args.quantized {
match args.model {
WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
WhichModel::V2 => candle_examples::hub_load_safetensors(
&repo,
"model.safetensors.index.json",
)?,
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
}
}
}
@ -295,33 +282,21 @@ fn main() -> Result<()> {
let config = match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
WhichModel::V2 => Config::v2(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
let model = match args.model {
WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = QMixFormer::new(&config, vb)?;
(Model::Quantized(model), Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = match args.model {
WhichModel::V2 => MixFormer::new_v2(&config, vb)?,
_ => MixFormer::new(&config, vb)?,
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = MixFormer::new(&config, vb)?;
(Model::MixFormer(model), device)
};
println!("loaded the model in {:?}", start.elapsed());
match (args.prompt, args.mmlu_dir) {
(None, None) | (Some(_), Some(_)) => {
anyhow::bail!("exactly one of --prompt and --mmlu-dir must be specified")
}
(Some(prompt), None) => {
let mut pipeline = TextGeneration::new(
model,
tokenizer,
@ -333,89 +308,6 @@ fn main() -> Result<()> {
args.verbose_prompt,
&device,
);
pipeline.run(&prompt, args.sample_len)?;
}
(None, Some(mmlu_dir)) => mmlu(model, tokenizer, &device, mmlu_dir)?,
}
Ok(())
}
fn mmlu<P: AsRef<std::path::Path>>(
mut model: Model,
tokenizer: Tokenizer,
device: &Device,
mmlu_dir: P,
) -> anyhow::Result<()> {
for dir_entry in mmlu_dir.as_ref().read_dir()?.flatten() {
let dir_entry = dir_entry.path();
let theme = match dir_entry.file_stem().and_then(|v| v.to_str()) {
None => "".to_string(),
Some(v) => match v.strip_suffix("_test") {
None => v.replace('_', " "),
Some(v) => v.replace('_', " "),
},
};
if dir_entry.extension().as_ref().and_then(|v| v.to_str()) != Some("csv") {
continue;
}
println!("reading {dir_entry:?}");
let dir_entry = std::fs::File::open(dir_entry)?;
let mut reader = csv::ReaderBuilder::new()
.has_headers(false)
.from_reader(dir_entry);
let token_a = tokenizer.token_to_id("A").unwrap();
let token_b = tokenizer.token_to_id("B").unwrap();
let token_c = tokenizer.token_to_id("C").unwrap();
let token_d = tokenizer.token_to_id("D").unwrap();
for row in reader.records() {
let row = match row {
Err(_) => continue,
Ok(row) => row,
};
if row.len() < 5 {
continue;
}
let question = row.get(0).unwrap();
let answer_a = row.get(1).unwrap();
let answer_b = row.get(2).unwrap();
let answer_c = row.get(3).unwrap();
let answer_d = row.get(4).unwrap();
let answer = row.get(5).unwrap();
let prompt = format!(
"{} {theme}.\n{question}\nA. {answer_a}\nB. {answer_b}\nC. {answer_c}\nD. {answer_d}\nAnswer:\n",
"The following are multiple choice questions (with answers) about"
);
let tokens = tokenizer.encode(prompt.as_str(), true).map_err(E::msg)?;
let tokens = tokens.get_ids().to_vec();
let input = Tensor::new(tokens, device)?.unsqueeze(0)?;
let logits = match &mut model {
Model::MixFormer(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
Model::Quantized(m) => {
m.clear_kv_cache();
m.forward(&input)?
}
};
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits_v: Vec<f32> = logits.to_vec1()?;
let pr_a = logits_v[token_a as usize];
let pr_b = logits_v[token_b as usize];
let pr_c = logits_v[token_c as usize];
let pr_d = logits_v[token_d as usize];
let model_answer = if pr_a > pr_b && pr_a > pr_c && pr_a > pr_d {
"A"
} else if pr_b > pr_c && pr_b > pr_d {
"B"
} else if pr_c > pr_d {
"C"
} else {
"D"
};
println!("{prompt}\n -> {model_answer} vs {answer}");
}
}
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -26,19 +26,6 @@ cargo run --example quantized --release -- --prompt "The best thing about coding
> The best thing about coding in rust is 1.) that I dont need to worry about memory leaks, 2.) speed and 3.) my program will compile even on old machines.
```
Using the mixtral sparse mixture of expert model:
```bash
$ cargo run --example quantized --release -- --which mixtral --prompt "Lebesgue's integral is superior to Riemann's because "
> avx: true, neon: false, simd128: false, f16c: true
> temp: 0.80 repeat-penalty: 1.10 repeat-last-n: 64
> loaded 995 tensors (26.44GB) in 0.03s
Lebesgue's integral is superior to Riemann's because 1. it is defined for a wider class of functions, those which are absolutely integrable; 2. the definition does not involve limits in two variables---one being computed before the other (which makes some computations more difficult); and 3. interchange of order of integration is easier to establish than with Riemann's integral. On the other hand, Lebesgue's integral applies only for bounded functions defined on finite intervals; it does not provide numerical values for improper integrals. The latter are best evaluated using Cauchy's limit definition.
The reason $f(x) = x^2$ is discontinuous at the ends of its interval of definition, and Riemann's integral requires continuity on the whole of an open interval containing it (see our earlier post), sine no such function exists with this property, is that the endpoints are infinite in measure for Lebesgue's integral.
```
## Command-line flags
Run with `--help` to see all options.

View File

@ -45,28 +45,16 @@ enum Which {
L13bCode,
#[value(name = "32b-code")]
L34bCode,
#[value(name = "7b-leo")]
Leo7b,
#[value(name = "13b-leo")]
Leo13b,
#[value(name = "7b-mistral")]
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
#[value(name = "7b-mistral-instruct-v0.2")]
Mistral7bInstructV02,
#[value(name = "7b-zephyr-a")]
Zephyr7bAlpha,
#[value(name = "7b-zephyr-b")]
Zephyr7bBeta,
#[value(name = "7b-open-chat-3.5")]
OpenChat35,
#[value(name = "7b-starling-a")]
Starling7bAlpha,
#[value(name = "mixtral")]
Mixtral,
#[value(name = "mixtral-instruct")]
MixtralInstruct,
}
impl Which {
@ -80,20 +68,14 @@ impl Which {
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b => false,
| Self::L34bCode => false,
// Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
// same way. Starling is a fine tuned version of OpenChat.
// same way.
Self::OpenChat35
| Self::Starling7bAlpha
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02 => true,
| Self::Mistral7bInstruct => true,
}
}
@ -108,44 +90,14 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::OpenChat35
| Self::Starling7bAlpha => false,
| Self::OpenChat35 => false,
Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
}
}
fn is_open_chat(&self) -> bool {
match self {
Self::L7b
| Self::L13b
| Self::L70b
| Self::L7bChat
| Self::L13bChat
| Self::L70bChat
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode
| Self::Leo7b
| Self::Leo13b
| Self::Mixtral
| Self::MixtralInstruct
| Self::Mistral7b
| Self::Mistral7bInstruct
| Self::Mistral7bInstructV02
| Self::Zephyr7bAlpha
| Self::Zephyr7bBeta => false,
Self::OpenChat35 | Self::Starling7bAlpha => true,
}
}
fn tokenizer_repo(&self) -> &'static str {
match self {
Which::L7b
| Which::L13b
@ -155,18 +107,12 @@ impl Which {
| Which::L70bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => "hf-internal-testing/llama-tokenizer",
Which::Leo7b => "LeoLM/leo-hessianai-7b",
Which::Leo13b => "LeoLM/leo-hessianai-13b",
Which::Mixtral => "mistralai/Mixtral-8x7B-v0.1",
Which::MixtralInstruct => "mistralai/Mixtral-8x7B-Instruct-v0.1",
Which::Mistral7b
| Which::L34bCode
| Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta => "mistralai/Mistral-7B-v0.1",
Which::OpenChat35 => "openchat/openchat_3.5",
Which::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
| Which::Zephyr7bBeta => false,
Which::OpenChat35 => true,
}
}
}
@ -174,7 +120,7 @@ impl Which {
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// GGML/GGUF file to load, typically a .bin/.gguf file generated by the quantize command from llama.cpp
/// GGML file to load, typically a .bin file generated by the quantize command from llama.cpp
#[arg(long)]
model: Option<String>,
@ -235,7 +181,13 @@ impl Args {
Some(config) => std::path::PathBuf::from(config),
None => {
let api = hf_hub::api::sync::Api::new()?;
let repo = self.which.tokenizer_repo();
let repo = if self.which.is_open_chat() {
"openchat/openchat_3.5"
} else if self.which.is_mistral() {
"mistralai/Mistral-7B-v0.1"
} else {
"hf-internal-testing/llama-tokenizer"
};
let api = api.model(repo.to_string());
api.get("tokenizer.json")?
}
@ -266,22 +218,6 @@ impl Args {
Which::L7bCode => ("TheBloke/CodeLlama-7B-GGUF", "codellama-7b.Q8_0.gguf"),
Which::L13bCode => ("TheBloke/CodeLlama-13B-GGUF", "codellama-13b.Q8_0.gguf"),
Which::L34bCode => ("TheBloke/CodeLlama-34B-GGUF", "codellama-34b.Q8_0.gguf"),
Which::Leo7b => (
"TheBloke/leo-hessianai-7B-GGUF",
"leo-hessianai-7b.Q4_K_M.gguf",
),
Which::Leo13b => (
"TheBloke/leo-hessianai-13B-GGUF",
"leo-hessianai-13b.Q4_K_M.gguf",
),
Which::Mixtral => (
"TheBloke/Mixtral-8x7B-v0.1-GGUF",
"mixtral-8x7b-v0.1.Q4_K_M.gguf",
),
Which::MixtralInstruct => (
"TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF",
"mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf",
),
Which::Mistral7b => (
"TheBloke/Mistral-7B-v0.1-GGUF",
"mistral-7b-v0.1.Q4_K_S.gguf",
@ -290,10 +226,6 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
Which::Mistral7bInstructV02 => (
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
"mistral-7b-instruct-v0.2.Q4_K_S.gguf",
),
Which::Zephyr7bAlpha => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
@ -302,10 +234,6 @@ impl Args {
("TheBloke/zephyr-7B-beta-GGUF", "zephyr-7b-beta.Q4_K_M.gguf")
}
Which::OpenChat35 => ("TheBloke/openchat_3.5-GGUF", "openchat_3.5.Q4_K_M.gguf"),
Which::Starling7bAlpha => (
"TheBloke/Starling-LM-7B-alpha-GGUF",
"starling-lm-7b-alpha.Q4_K_M.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -364,7 +292,7 @@ fn main() -> anyhow::Result<()> {
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
Some("gguf") => {
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let model = gguf_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensor_infos.iter() {
let elem_count = tensor.shape.elem_count();
@ -380,7 +308,7 @@ fn main() -> anyhow::Result<()> {
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
let model = ggml_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
@ -401,20 +329,14 @@ fn main() -> anyhow::Result<()> {
| Which::L13bChat
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode
| Which::Leo7b
| Which::Leo13b => 1,
Which::Mixtral
| Which::MixtralInstruct
| Which::Mistral7b
| Which::L34bCode => 1,
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Mistral7bInstructV02
| Which::Zephyr7bAlpha
| Which::Zephyr7bBeta
| Which::L70b
| Which::L70bChat
| Which::OpenChat35
| Which::Starling7bAlpha => 8,
| Which::OpenChat35 => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
@ -447,7 +369,7 @@ fn main() -> anyhow::Result<()> {
}
}
if args.which.is_open_chat() {
format!("GPT4 Correct User: {prompt}<|end_of_turn|>GPT4 Correct Assistant:")
format!("User: {prompt}<|end_of_turn|>Assistant: ")
} else if args.which.is_zephyr() {
if prompt_index == 0 || is_interactive {
format!("<|system|>\n</s>\n<|user|>\n{prompt}</s>\n<|assistant|>",)

View File

@ -78,7 +78,7 @@ class EpisodicLifeEnv(gym.Wrapper):
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in lives == 0 condition for a few frames
# for Qbert somtimes we stay in lives == 0 condtion for a few frames
# so its important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True

View File

@ -8,7 +8,7 @@ XL using Rust and [candle](https://github.com/huggingface/candle).
The `stable-diffusion` example is a conversion of
[diffusers-rs](https://github.com/LaurentMazare/diffusers-rs) using candle
rather than libtorch. This implementation supports Stable Diffusion v1.5, v2.1,
as well as Stable Diffusion XL 1.0, and Turbo.
as well as Stable Diffusion XL 1.0.
## Getting the weights
@ -23,26 +23,16 @@ cargo run --example stable-diffusion --release --features=cuda,cudnn \
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def)"
```
The final image is named `sd_final.png` by default. The Turbo version is much
faster than previous versions, to give it a try add a `--sd-version turbo` flag,
e.g.:
```bash
cargo run --example stable-diffusion --release --features=cuda,cudnn \
-- --prompt "a cosmonaut on a horse (hd, realistic, high-def) --sd-version turbo"
```
The default scheduler for the v1.5, v2.1 and XL 1.0 version is the Denoising
Diffusion Implicit Model scheduler (DDIM). The original paper and some code can
be found in the [associated repo](https://github.com/ermongroup/ddim).
The default scheduler for the XL Turbo version is the Euler Ancestral scheduler.
The final image is named `sd_final.png` by default.
The default scheduler is the Denoising Diffusion Implicit Model scheduler (DDIM). The
original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
### Command-line flags
- `--prompt`: the prompt to be used to generate the image.
- `--uncond-prompt`: the optional unconditional prompt.
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`,
`xl`, or `turbo`.
- `--sd-version`: the Stable Diffusion version to use, can be `v1-5`, `v2-1`, or
`xl`.
- `--cpu`: use the cpu rather than the gpu (much slower).
- `--height`, `--width`: set the height and width for the generated image.
- `--n-steps`: the number of steps to be used in the diffusion process.

View File

@ -11,6 +11,8 @@ use candle::{DType, Device, IndexOp, Module, Tensor, D};
use clap::Parser;
use tokenizers::Tokenizer;
const GUIDANCE_SCALE: f64 = 7.5;
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -61,8 +63,8 @@ struct Args {
sliced_attention_size: Option<usize>,
/// The number of steps to run the diffusion for.
#[arg(long)]
n_steps: Option<usize>,
#[arg(long, default_value_t = 30)]
n_steps: usize,
/// The number of samples to generate.
#[arg(long, default_value_t = 1)]
@ -85,9 +87,6 @@ struct Args {
#[arg(long)]
use_f16: bool,
#[arg(long)]
guidance_scale: Option<f64>,
#[arg(long, value_name = "FILE")]
img2img: Option<String>,
@ -103,7 +102,6 @@ enum StableDiffusionVersion {
V1_5,
V2_1,
Xl,
Turbo,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -122,13 +120,12 @@ impl StableDiffusionVersion {
Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0",
Self::V2_1 => "stabilityai/stable-diffusion-2-1",
Self::V1_5 => "runwayml/stable-diffusion-v1-5",
Self::Turbo => "stabilityai/sdxl-turbo",
}
}
fn unet_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"unet/diffusion_pytorch_model.fp16.safetensors"
} else {
@ -140,7 +137,7 @@ impl StableDiffusionVersion {
fn vae_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"vae/diffusion_pytorch_model.fp16.safetensors"
} else {
@ -152,7 +149,7 @@ impl StableDiffusionVersion {
fn clip_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder/model.fp16.safetensors"
} else {
@ -164,7 +161,7 @@ impl StableDiffusionVersion {
fn clip2_file(&self, use_f16: bool) -> &'static str {
match self {
Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => {
Self::V1_5 | Self::V2_1 | Self::Xl => {
if use_f16 {
"text_encoder_2/model.fp16.safetensors"
} else {
@ -192,7 +189,7 @@ impl ModelFile {
StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => {
"openai/clip-vit-base-patch32"
}
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => {
StableDiffusionVersion::Xl => {
// This seems similar to the patch32 version except some very small
// difference in the split regex.
"openai/clip-vit-large-patch14"
@ -209,11 +206,7 @@ impl ModelFile {
Self::Vae => {
// Override for SDXL when using f16 weights.
// See https://github.com/huggingface/candle/issues/1060
if matches!(
version,
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo,
) && use_f16
{
if version == StableDiffusionVersion::Xl && use_f16 {
(
"madebyollin/sdxl-vae-fp16-fix",
"diffusion_pytorch_model.safetensors",
@ -268,7 +261,6 @@ fn text_embeddings(
use_f16: bool,
device: &Device,
dtype: DType,
use_guide_scale: bool,
first: bool,
) -> Result<Tensor> {
let tokenizer_file = if first {
@ -293,6 +285,16 @@ fn text_embeddings(
}
let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?;
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
println!("Building the Clip transformer.");
let clip_weights_file = if first {
ModelFile::Clip
@ -308,24 +310,8 @@ fn text_embeddings(
let text_model =
stable_diffusion::build_clip_transformer(clip_config, clip_weights, device, DType::F32)?;
let text_embeddings = text_model.forward(&tokens)?;
let text_embeddings = if use_guide_scale {
let mut uncond_tokens = tokenizer
.encode(uncond_prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
while uncond_tokens.len() < sd_config.clip.max_position_embeddings {
uncond_tokens.push(pad_id)
}
let uncond_tokens = Tensor::new(uncond_tokens.as_slice(), device)?.unsqueeze(0)?;
let uncond_embeddings = text_model.forward(&uncond_tokens)?;
Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?
} else {
text_embeddings.to_dtype(dtype)?
};
let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(dtype)?;
Ok(text_embeddings)
}
@ -370,7 +356,6 @@ fn run(args: Args) -> Result<()> {
unet_weights,
tracing,
use_f16,
guidance_scale,
use_flash_attn,
img2img,
img2img_strength,
@ -389,24 +374,6 @@ fn run(args: Args) -> Result<()> {
None
};
let guidance_scale = match guidance_scale {
Some(guidance_scale) => guidance_scale,
None => match sd_version {
StableDiffusionVersion::V1_5
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 7.5,
StableDiffusionVersion::Turbo => 0.,
},
};
let n_steps = match n_steps {
Some(n_steps) => n_steps,
None => match sd_version {
StableDiffusionVersion::V1_5
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 30,
StableDiffusionVersion::Turbo => 1,
},
};
let dtype = if use_f16 { DType::F16 } else { DType::F32 };
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
@ -418,19 +385,13 @@ fn run(args: Args) -> Result<()> {
StableDiffusionVersion::Xl => {
stable_diffusion::StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
}
StableDiffusionVersion::Turbo => stable_diffusion::StableDiffusionConfig::sdxl_turbo(
sliced_attention_size,
height,
width,
),
};
let scheduler = sd_config.build_scheduler(n_steps)?;
let device = candle_examples::device(cpu)?;
let use_guide_scale = guidance_scale > 1.0;
let which = match sd_version {
StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => vec![true, false],
StableDiffusionVersion::Xl => vec![true, false],
_ => vec![true],
};
let text_embeddings = which
@ -446,12 +407,10 @@ fn run(args: Args) -> Result<()> {
use_f16,
&device,
dtype,
use_guide_scale,
*first,
)
})
.collect::<Result<Vec<_>>>()?;
let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?;
println!("{text_embeddings:?}");
@ -475,19 +434,11 @@ fn run(args: Args) -> Result<()> {
0
};
let bsize = 1;
let vae_scale = match sd_version {
StableDiffusionVersion::V1_5
| StableDiffusionVersion::V2_1
| StableDiffusionVersion::Xl => 0.18215,
StableDiffusionVersion::Turbo => 0.13025,
};
for idx in 0..num_samples {
let timesteps = scheduler.timesteps();
let latents = match &init_latent_dist {
Some(init_latent_dist) => {
let latents = (init_latent_dist.sample()? * vae_scale)?.to_device(&device)?;
let latents = (init_latent_dist.sample()? * 0.18215)?.to_device(&device)?;
if t_start < timesteps.len() {
let noise = latents.randn_like(0f64, 1f64)?;
scheduler.add_noise(&latents, noise, timesteps[t_start])?
@ -514,31 +465,21 @@ fn run(args: Args) -> Result<()> {
continue;
}
let start_time = std::time::Instant::now();
let latent_model_input = if use_guide_scale {
Tensor::cat(&[&latents, &latents], 0)?
} else {
latents.clone()
};
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)?;
let noise_pred =
unet.forward(&latent_model_input, timestep as f64, &text_embeddings)?;
let noise_pred = if use_guide_scale {
let noise_pred = noise_pred.chunk(2, 0)?;
let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * guidance_scale)?)?
} else {
noise_pred
};
let noise_pred =
(noise_pred_uncond + ((noise_pred_text - noise_pred_uncond)? * GUIDANCE_SCALE)?)?;
latents = scheduler.step(&noise_pred, timestep, &latents)?;
let dt = start_time.elapsed().as_secs_f32();
println!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
if args.intermediary_images {
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = vae.decode(&(&latents / 0.18215)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename =
@ -552,7 +493,7 @@ fn run(args: Args) -> Result<()> {
idx + 1,
num_samples
);
let image = vae.decode(&(&latents / vae_scale)?)?;
let image = vae.decode(&(&latents / 0.18215)?)?;
let image = ((image / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
let image = (image.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?.i(0)?;
let image_filename = output_filename(&final_image, idx + 1, num_samples, None);

View File

@ -96,9 +96,25 @@ impl T5ModelBuilder {
let api = api.repo(repo);
let config_filename = api.get("config.json")?;
let tokenizer_filename = api.get("tokenizer.json")?;
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
{
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
let weights_filename = if model_id == "google/flan-t5-xxl" {
vec![
api.get("model-00001-of-00005.safetensors")?,
api.get("model-00002-of-00005.safetensors")?,
api.get("model-00003-of-00005.safetensors")?,
api.get("model-00004-of-00005.safetensors")?,
api.get("model-00005-of-00005.safetensors")?,
]
} else if model_id == "google/flan-ul2" {
vec![
api.get("model-00001-of-00008.safetensors")?,
api.get("model-00002-of-00008.safetensors")?,
api.get("model-00003-of-00008.safetensors")?,
api.get("model-00004-of-00008.safetensors")?,
api.get("model-00005-of-00008.safetensors")?,
api.get("model-00006-of-00008.safetensors")?,
api.get("model-00007-of-00008.safetensors")?,
api.get("model-00008-of-00008.safetensors")?,
]
} else {
vec![api.get("model.safetensors")?]
};

View File

@ -218,7 +218,21 @@ fn main() -> Result<()> {
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
None => match args.which {
Which::L6b => vec![
repo.get("model-00001-of-00002.safetensors")?,
repo.get("model-00002-of-00002.safetensors")?,
],
Which::L34b => vec![
repo.get("model-00001-of-00007.safetensors")?,
repo.get("model-00002-of-00007.safetensors")?,
repo.get("model-00003-of-00007.safetensors")?,
repo.get("model-00004-of-00007.safetensors")?,
repo.get("model-00005-of-00007.safetensors")?,
repo.get("model-00006-of-00007.safetensors")?,
repo.get("model-00007-of-00007.safetensors")?,
],
},
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

View File

@ -117,30 +117,3 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
image.save(p).map_err(candle::Error::wrap)?;
Ok(())
}
/// Loads the safetensors files for a model from the hub based on a json index file.
pub fn hub_load_safetensors(
repo: &hf_hub::api::sync::ApiRepo,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
let json_file = std::fs::File::open(json_file)?;
let json: serde_json::Value =
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file.to_string());
}
}
let safetensors_files = safetensors_files
.iter()
.map(|v| repo.get(v).map_err(candle::Error::wrap))
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}

View File

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.3.3"
version = "0.3.1"
edition = "2021"
description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"
[dependencies]
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.3", package = "candle-core" }
candle = { path = "../candle-core", features = ["cuda"], version = "0.3.1", package = "candle-core" }
half = { version = "2.3.1", features = ["num-traits"] }
[build-dependencies]
@ -21,4 +21,4 @@ rayon = "1.7.0"
[dev-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
candle-nn = { path = "../candle-nn", version = "0.3.3", features = ["cuda"] }
candle-nn = { path = "../candle-nn", version = "0.3.1", features = ["cuda"] }

View File

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.3.3"
version = "0.3.1"
edition = "2021"
description = "CUDA kernels for Candle"

View File

@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.3.3"
version = "0.3.1"
edition = "2021"
description = "Metal kernels for Candle"
@ -10,7 +10,7 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { version = "0.27.0", features = ["mps"]}
metal = { version = "0.27.1", features = ["mps"], package="candle-metal" }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -29,7 +29,9 @@ kernel void FN_NAME( \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(float(input[id]) * mul + add); \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[id] * m + a; \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
@ -45,80 +47,15 @@ kernel void FN_NAME##_strided( \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(float(input[get_strided_index(id, num_dims, dims, strides)]) * mul + add); \
}
#define POWF(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \
}
#define ELU(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME x = input[id]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
kernel void FN_NAME##_strided( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant float &mul, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint id [[ thread_position_in_grid ]] \
) { \
if (id >= dim) { \
return; \
} \
const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
const TYPENAME m = TYPENAME(mul); \
const TYPENAME a = TYPENAME(add); \
output[id] = input[get_strided_index(id, num_dims, dims, strides)] * m + a; \
} \
AFFINE(affine_f32, float)
AFFINE(affine_f16, half)
POWF(powf_f32, float)
POWF(powf_f16, half)
ELU(elu_f32, float)
ELU(elu_f16, half)
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bf16, bfloat);
POWF(powf_bf16, bfloat);
ELU(elu_bf16, bfloat);
AFFINE(affine_bfloat, bfloat);
#endif

View File

@ -1,8 +1,5 @@
#include <metal_stdlib>
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
@ -25,15 +22,15 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device OUT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[tid]; \
TYPENAME y = right[tid]; \
output[tid] = OUT_TYPENAME(FN); \
TYPENAME x = left[thread_position_in_grid]; \
TYPENAME y = right[thread_position_in_grid]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -43,48 +40,33 @@ kernel void FN_NAME_STRIDED( \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device OUT_TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
device TYPENAME *output, \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \
output[tid] = OUT_TYPENAME(FN); \
TYPENAME x = left[get_strided_index(thread_position_in_grid, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(thread_position_in_grid, num_dims, dims, right_strides)]; \
output[thread_position_in_grid] = OUT_TYPENAME(FN); \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided);
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided);
#define BINARY_OP_OUT(NAME, FN) \
BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \
BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided);
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
BINARY_OP(MIN(x, y), min)
BINARY_OP(MAX(x, y), max)
BINARY_OP_OUT(eq, x == y)
BINARY_OP_OUT(ne, x != y)
BINARY_OP_OUT(le, x <= y)
BINARY_OP_OUT(lt, x < y)
BINARY_OP_OUT(ge, x >= y)
BINARY_OP_OUT(gt, x > y)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
BFLOAT_BINARY_OP(MIN(x, y), min)
BFLOAT_BINARY_OP(MAX(x, y), max)
#endif

View File

@ -48,7 +48,6 @@ kernel void FN_NAME_STRIDED( \
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)

View File

@ -1,213 +0,0 @@
template <typename T>
METAL_FUNC void im2col(
constant size_t &dst_numel,
constant size_t &h_out,
constant size_t &w_out,
constant size_t &h_k,
constant size_t &w_k,
constant size_t &stride,
constant size_t &padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// dst: (b_size, h_out, w_out, c_in, h_k, w_k)
// src: (b_size, c_in, h_in, w_in)
if (tid >= dst_numel) {
return;
}
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t h_in = src_dims[2];
const size_t w_in = src_dims[3];
const size_t dst_s4 = w_k;
const size_t dst_s3 = h_k * dst_s4;
const size_t dst_s2 = c_in * dst_s3;
const size_t dst_s1 = w_out * dst_s2;
const size_t dst_s0 = h_out * dst_s1;
size_t tmp_tid = tid;
const size_t b_idx = tmp_tid / dst_s0;
tmp_tid -= b_idx * dst_s0;
const size_t h_idx = tmp_tid / dst_s1;
tmp_tid -= h_idx * dst_s1;
const size_t w_idx = tmp_tid / dst_s2;
tmp_tid -= w_idx * dst_s2;
const size_t c_idx = tmp_tid / dst_s3;
tmp_tid -= c_idx * dst_s3;
const size_t h_k_idx = tmp_tid / dst_s4;
tmp_tid -= h_k_idx * dst_s4;
const size_t w_k_idx = tmp_tid;
size_t src_h_idx = h_idx * stride + h_k_idx * dilation;
size_t src_w_idx = w_idx * stride + w_k_idx * dilation;
if (src_h_idx < padding || src_h_idx >= h_in + padding) {
dst[tid] = static_cast<T>(0);
}
else if (src_w_idx < padding || src_w_idx >= w_in + padding) {
dst[tid] = static_cast<T>(0);
}
else {
src_h_idx -= padding;
src_w_idx -= padding;
const size_t src_i =
b_idx * src_strides[0]
+ c_idx * src_strides[1]
+ src_h_idx * src_strides[2]
+ src_w_idx * src_strides[3];
dst[tid] = src[src_i];
}
}
template <typename T>
METAL_FUNC void im2col1d(
constant size_t &dst_numel,
constant size_t &l_out,
constant size_t &l_k,
constant size_t &stride,
constant size_t &padding,
constant size_t &dilation,
constant size_t *src_dims,
constant size_t *src_strides,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// dst: (b_size, l_out, c_in, l_k)
// src: (b_size, c_in, l_in)
if (tid >= dst_numel) {
return;
}
const size_t b_in = src_dims[0];
const size_t c_in = src_dims[1];
const size_t l_in = src_dims[2];
const size_t dst_s2 = l_k;
const size_t dst_s1 = c_in * dst_s2;
const size_t dst_s0 = l_out * dst_s1;
size_t tmp_dst_i = tid;
const size_t b_idx = tmp_dst_i / dst_s0;
tmp_dst_i -= b_idx * dst_s0;
const size_t l_idx = tmp_dst_i / dst_s1;
tmp_dst_i -= l_idx * dst_s1;
const size_t c_idx = tmp_dst_i / dst_s2;
tmp_dst_i -= c_idx * dst_s2;
const size_t l_k_idx = tmp_dst_i;
size_t src_l_idx = l_idx * stride + l_k_idx * dilation;
if (src_l_idx < padding || src_l_idx >= l_in + padding) {
dst[tid] = static_cast<T>(0);
}
else {
src_l_idx -= padding;
const size_t src_i = b_idx * src_strides[0] + c_idx * src_strides[1] + src_l_idx * src_strides[2];
dst[tid] = src[src_i];
}
}
template <typename T>
METAL_FUNC void upsample_nearest2d(
constant size_t &w_out,
constant size_t &h_out,
constant float &w_scale,
constant float &h_scale,
constant size_t *src_dims,
constant size_t *src_s,
device const T *src,
device T *dst,
uint tid [[ thread_position_in_grid ]]
) {
// src: (b_size, c_in, w_in, h_in)
const size_t c = src_dims[1];
const size_t w_in = src_dims[2];
const size_t h_in = src_dims[3];
if (tid >= src_dims[0] * c * w_out * h_out) {
return;
}
// TODO: Improve this.
const size_t b_idx = tid / (w_out * h_out * c);
const size_t c_idx = (tid / (w_out * h_out)) % c;
const size_t dst_w = (tid / h_out) % w_out;
const size_t dst_h = tid % h_out;
size_t src_w = static_cast<size_t>(dst_w * w_scale);
size_t src_h = static_cast<size_t>(dst_h * h_scale);
if (src_w >= w_in) {
src_w = w_in - 1;
}
if (src_h >= h_in) {
src_h = h_in - 1;
}
const size_t src_i = b_idx * src_s[0] + c_idx * src_s[1] + src_w * src_s[2] + src_h * src_s[3];
dst[tid] = src[src_i];
}
#define IM2COL_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &h_out, \
constant size_t &w_out, \
constant size_t &h_k, \
constant size_t &w_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col<T>(dst_numel, h_out, w_out, h_k, w_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
#define IM2COL1D_OP(T, FN_NAME) \
kernel void FN_NAME( \
constant size_t &dst_numel, \
constant size_t &l_out, \
constant size_t &l_k, \
constant size_t &stride, \
constant size_t &padding, \
constant size_t &dilation, \
constant size_t *src_dims, \
constant size_t *src_strides, \
device const T *src, \
device T *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
} \
#define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
constant size_t &w_out, \
constant size_t &h_out, \
constant float &w_scale, \
constant float &h_scale, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *src, \
device TYPENAME *dst, \
uint tid [[ thread_position_in_grid ]] \
) { \
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
} \
IM2COL_OP(float, im2col_f32)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
UPSAMPLE_NEAREST2D_OP(float, upsample_nearest2d_f32)
UPSAMPLE_NEAREST2D_OP(uint8_t, upsample_nearest2d_u8)
UPSAMPLE_NEAREST2D_OP(uint32_t, upsample_nearest2d_u32)

View File

@ -1,34 +1,6 @@
#include <metal_stdlib>
using namespace metal;
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t id_i = (tid / right_size) % ids_size;
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
/*
// Force prevent out of bounds indexing
// since there doesn't seem to be a good way to force crash
// No need to check for zero we're only allowing unsized.
*/
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
output[tid] = input[src_i];
}
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
@ -39,160 +11,93 @@ kernel void NAME( \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint gid [[ thread_position_in_grid ]] \
) { \
index<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
if (gid >= dst_size) { \
return; \
} \
const size_t id_i = (gid / right_size) % ids_size; \
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1)); \
const size_t right_rank_i = gid % right_size; \
const size_t left_rank_i = gid / right_size / ids_size; \
/* \
// Force prevent out of bounds indexing \
// since there doesn't seem to be a good way to force crash \
// No need to check for zero we're only allowing unsized. \
*/ \
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i; \
output[gid] = input[src_i]; \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void gather(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &ids_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint gid [[ thread_position_in_grid ]] \
) {
if (tid >= dst_size) {
if (gid >= left_size * right_size) {
return;
}
const INDEX_TYPENAME input_i = input_ids[tid];
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size / ids_size;
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
output[tid] = input[src_i];
}
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &ids_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < src_dim_size; ++j) {
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const INDEX_TYPENAME idx = input_ids[src_i];
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
scatter_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}
template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void index_add(
constant size_t &dst_size,
constant size_t &left_size,
constant size_t &src_dim_size,
constant size_t &right_size,
constant size_t &dst_dim_size,
constant size_t &ids_dim_size,
const device TYPENAME *input,
const device INDEX_TYPENAME *input_ids,
device TYPENAME *output,
uint tid [[ thread_position_in_grid ]]
) {
if (tid >= dst_size) {
return;
}
const size_t right_rank_i = tid % right_size;
const size_t left_rank_i = tid / right_size;
for (unsigned int j = 0; j < ids_dim_size; ++j) {
const INDEX_TYPENAME idx = input_ids[j];
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
output[dst_i] += input[src_i];
}
}
# define INDEX_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
constant size_t &dst_size, \
constant size_t &left_size, \
constant size_t &src_dim_size, \
constant size_t &right_size, \
constant size_t &dst_dim_size, \
constant size_t &ids_dim_size, \
const device TYPENAME *input, \
const device INDEX_TYPENAME *input_ids, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
index_add<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, ids_dim_size, input, input_ids, output, tid); \
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint gid [[ thread_position_in_grid ]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, gid); } \
INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)
GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)
#if __METAL_VERSION__ >= 310
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
INDEX_ADD_OP(ia_u8_f16, uint8_t, half)
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
INDEX_ADD_OP(ia_i64_u8, int64_t, uint8_t)
INDEX_ADD_OP(ia_i64_i64, int64_t, int64_t)
INDEX_ADD_OP(ia_i64_u32, int64_t, uint32_t)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
INDEX_ADD_OP(ia_u32_f32, uint32_t, float)
INDEX_ADD_OP(ia_u32_u8, uint32_t, uint8_t)
INDEX_ADD_OP(ia_u32_i64, uint32_t, int64_t)
INDEX_ADD_OP(ia_u32_u32, uint32_t, uint32_t)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
INDEX_ADD_OP(ia_u8_f32, uint8_t, float)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_i64, uint8_t, int64_t)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,6 @@
using namespace metal;
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
METAL_FUNC uint get_strided_index(
uint idx,
@ -19,132 +18,11 @@ METAL_FUNC uint get_strided_index(
return strided_i;
}
constant int THREADGROUP_SIZE = 2048;
constant int THREADGROUP_SIZE = 1024;
#define ARGMIN(NAME, T, MAXVALUE) \
# define REDUCE(FN, NAME, T) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MAXVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
bool notset = true; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || src[strided_i] < shared_memory[tid]) { \
shared_memory[tid] = src[strided_i]; \
/* Assume that the reduction takes place over the last dimension which is contiguous. */ \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] < shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define ARGMAX(NAME, T, MINVALUE) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device uint *dst, \
uint id [[ thread_position_in_grid ]], \
uint tid [[ thread_index_in_threadgroup ]], \
uint dst_id [[ threadgroup_position_in_grid ]], \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup uint shared_indices[THREADGROUP_SIZE]; \
\
shared_memory[tid] = MINVALUE; \
shared_indices[tid] = 0xFFFFFFFF; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t idx = start_idx + tid; \
bool notset = true; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
if (notset || shared_memory[tid] < src[strided_i]) { \
shared_memory[tid] = src[strided_i]; \
shared_indices[tid] = idx % dims[num_dims - 1]; \
notset = false; \
} \
idx += block_dim; \
} \
\
threadgroup_barrier(mem_flags::mem_none); \
\
/* \
// reduction in shared memory \
*/ \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s && shared_memory[tid + s] > shared_memory[tid]) { \
shared_indices[tid] = shared_indices[tid + s]; \
shared_memory[tid] = shared_memory[tid + s]; \
} \
threadgroup_barrier(mem_flags::mem_none); \
} \
\
if (tid == 0){ \
dst[dst_id] = shared_indices[0]; \
} \
} \
#define REDUCE(FN, NAME, T, START) \
kernel void NAME( \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
constant size_t &src_numel, \
constant size_t &el_to_sum_per_block, \
device const T *src, \
device T *dst, \
@ -154,23 +32,23 @@ kernel void NAME( \
uint block_dim [[ threads_per_threadgroup ]] \
) { \
\
threadgroup T shared_memory[THREADGROUP_SIZE]; \
threadgroup float shared_memory[THREADGROUP_SIZE]; \
\
shared_memory[tid] = START; \
shared_memory[tid] = 0; \
/* \
// Elements summed in this block range from dst_id * el_to_sum_per_block \
// to (dst_id + 1) * el_to_sum_per_block. \
*/ \
size_t start_idx = dst_id * el_to_sum_per_block; \
size_t stop_idx = start_idx + el_to_sum_per_block; \
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
while (idx < stop_idx) { \
/* \
// TODO: Fast version for the contiguous case. \
// size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
*/ \
size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
T x = shared_memory[tid]; \
T y = src[strided_i]; \
T y = src[idx]; \
shared_memory[tid] = FN; \
idx += block_dim; \
} \
@ -193,6 +71,10 @@ kernel void NAME( \
} \
REDUCE(x + y, fast_sum_float, float)
REDUCE(x * y, fast_mul_float, float)
REDUCE(max(x, y), fast_max_float, float)
#define SOFTMAX(NAME, T) \
kernel void NAME( \
constant size_t &src_numel, \
@ -211,13 +93,12 @@ kernel void NAME(
size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel); \
size_t idx = start_idx + tid; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float tmp = -INFINITY; \
while (idx < stop_idx) { \
tmp = MAX(tmp, float(src[idx])); \
shared_memory[tid] = MAX(shared_memory[tid], src[idx]); \
idx += block_dim; \
} \
shared_memory[tid] = tmp; \
\
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
@ -225,26 +106,21 @@ kernel void NAME(
if (tid < s) { \
shared_memory[tid] = MAX(shared_memory[tid], shared_memory[tid + s]); \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
/* wait for shared_memory[0] to be filled */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
\
float _max = shared_memory[0]; \
\
/* prevent tid=0 from overwriting _max before other threads have written it */ \
threadgroup_barrier(mem_flags::mem_threadgroup); \
shared_memory[tid] = 0; \
\
idx = start_idx + tid; \
while (idx < stop_idx) { \
const float val = exp(float(src[idx]) - _max); \
dst[idx] = T(val); \
const T val = T(exp(src[idx] - _max)); \
dst[idx] = val; \
shared_memory[tid] += val; \
idx += block_dim; \
} \
threadgroup_barrier(mem_flags::mem_threadgroup); \
for (uint s = block_dim / 2; s > 0; s >>= 1) { \
if (tid < s) { \
shared_memory[tid] += shared_memory[tid + s]; \
@ -252,7 +128,7 @@ kernel void NAME(
threadgroup_barrier(mem_flags::mem_threadgroup); \
} \
\
const T inv_acc = T(1.0/shared_memory[0]); \
const T inv_acc = T(1/shared_memory[0]); \
idx = start_idx + tid; \
while (idx < stop_idx) { \
dst[idx] *= inv_acc; \
@ -260,33 +136,8 @@ kernel void NAME(
} \
} \
REDUCE(x + y, fast_sum_f32_strided, float, 0)
REDUCE(x + y, fast_sum_u32_strided, uint, 0)
REDUCE(x + y, fast_sum_f16_strided, half, 0)
REDUCE(x * y, fast_mul_f32_strided, float, 1)
REDUCE(x * y, fast_mul_u32_strided, uint, 1)
REDUCE(x * y, fast_mul_f16_strided, half, 1)
REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
ARGMAX(fast_argmax_f32_strided, float, -HUGE_VALF)
ARGMAX(fast_argmax_f16_strided, half, -HUGE_VALH)
ARGMAX(fast_argmax_u32_strided, uint, 0)
SOFTMAX(softmax_f32, float)
SOFTMAX(softmax_f16, half)
SOFTMAX(softmax_float, float)
SOFTMAX(softmax_half, half)
#if __METAL_VERSION__ >= 310
REDUCE(x + y, fast_sum_bf16, bfloat, 0)
REDUCE(x * y, fast_mul_bf16, bfloat, 1)
REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
SOFTMAX(softmax_bf16, bfloat)
SOFTMAX(softmax_bfloat, bfloat)
#endif

View File

@ -1,13 +1,6 @@
use super::*;
use half::{bf16, f16};
use metal::{Device, MTLResourceOptions};
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
assert!(!ptr.is_null());
let slice = unsafe { std::slice::from_raw_parts(ptr, n) };
slice.to_vec()
}
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
fn new_buffer<T>(device: &Device, data: &[T]) -> Buffer {
let options = MTLResourceOptions::StorageModeManaged;
@ -37,7 +30,6 @@ fn approx_bf16(v: Vec<bf16>, digits: i32) -> Vec<f32> {
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -55,13 +47,12 @@ fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
output.read_to_vec::<T>(v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -81,7 +72,7 @@ fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> V
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, x.len())
output.read_to_vec::<T>(x.len())
}
fn run_strided<T: Clone>(
@ -96,8 +87,7 @@ fn run_strided<T: Clone>(
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let output = new_buffer(&device, v);
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
call_unary_strided(
&device,
command_buffer,
@ -113,7 +103,7 @@ fn run_strided<T: Clone>(
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
output.read_to_vec::<T>(v.len())
}
#[test]
@ -250,8 +240,7 @@ fn binary_add_f32() {
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -272,7 +261,7 @@ fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
output.read_to_vec::<U>(v.len())
}
#[test]
@ -298,8 +287,7 @@ fn cast_u32_f32() {
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -312,7 +300,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
&device,
command_buffer,
&kernels,
"affine_f32",
"affine_float",
size,
&input,
&output,
@ -323,7 +311,7 @@ fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
output.read_to_vec::<T>(v.len())
}
fn run_affine_strided<T: Clone>(
@ -334,8 +322,7 @@ fn run_affine_strided<T: Clone>(
add: f64,
) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
@ -346,7 +333,7 @@ fn run_affine_strided<T: Clone>(
&device,
command_buffer,
&kernels,
"affine_f32_strided",
"affine_float_strided",
shape,
&input,
strides,
@ -360,7 +347,7 @@ fn run_affine_strided<T: Clone>(
command_buffer.wait_until_completed();
let len: usize = shape.iter().product();
read_to_vec(&output, len)
output.read_to_vec::<T>(len)
}
#[test]
@ -463,8 +450,7 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
_ => unimplemented!(),
};
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
call_index_select(
&device,
&command_buffer,
@ -482,7 +468,74 @@ fn run_index_select<T: Clone, I: Clone + std::fmt::Debug>(
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&dst_buffer, dst_el)
dst_buffer.read_to_vec::<T>(dst_el)
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let index_buffer = new_buffer(&device, &index);
let inputs_buffer = new_buffer(&device, &left);
let outputs_buffer = new_buffer(&device, &right);
set_params!(
encoder,
(
&index_buffer,
&inputs_buffer,
&outputs_buffer,
ids_dim_size,
left_size,
dst_dim_size,
right_size
)
);
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
#[test]
@ -499,23 +552,19 @@ fn cos_f16() {
fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let dims = vec![v.len()];
let strides = vec![1];
call_reduce_strided(
call_reduce_contiguous(
&device,
command_buffer,
&kernels,
name,
&dims,
&strides,
v.len(),
out_length,
&input,
0,
@ -525,13 +574,12 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, out_length)
output.read_to_vec::<T>(out_length)
}
fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'static str) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = new_buffer(&device, v);
@ -544,14 +592,13 @@ fn run_softmax<T: Clone + std::fmt::Debug>(v: &[T], last_dim: usize, name: &'sta
v.len(),
last_dim,
&input,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, v.len())
output.read_to_vec::<T>(v.len())
}
#[test]
@ -559,7 +606,7 @@ fn reduce_sum() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 1;
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
let results = run_reduce(&v, out_length, "fast_sum_float");
assert_eq!(approx(results, 4), vec![21.0]);
}
@ -568,7 +615,7 @@ fn reduce_sum2() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let out_length = 2;
let results = run_reduce(&v, out_length, "fast_sum_f32_strided");
let results = run_reduce(&v, out_length, "fast_sum_float");
assert_eq!(approx(results, 4), vec![6.0, 15.0]);
}
@ -576,33 +623,15 @@ fn reduce_sum2() {
fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
);
let last_dim = 4096;
let n = 200;
let mut v = vec![0.0; n * last_dim];
for i in 0..n {
v[i * last_dim] = 20.0;
}
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = approx(results, 4);
println!("{results:?}");
assert_eq!(
results.iter().map(|&s| s.round() as usize).sum::<usize>(),
n
);
assert_eq!(results[0], 1.0);
assert_eq!(results[1], 0.0);
assert_eq!(results[last_dim], 1.0);
assert_eq!(results[2 * last_dim], 1.0);
let v = vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0];
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0858, 0.2331, 0.6337]
@ -610,7 +639,7 @@ fn softmax() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let last_dim = 3;
let results = run_softmax(&v, last_dim, "softmax_f32");
let results = run_softmax(&v, last_dim, "softmax_float");
assert_eq!(
approx(results, 4),
vec![0.0900, 0.2447, 0.6652, 0.0900, 0.2447, 0.6652]
@ -621,7 +650,7 @@ fn softmax() {
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_f16");
let results = run_softmax(&v, last_dim, "softmax_half");
assert_eq!(
approx_f16(results, 4),
vec![0.0043, 0.0116, 0.0316, 0.0858, 0.2332, 0.6338]
@ -632,7 +661,7 @@ fn softmax() {
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();
let last_dim = 6;
let results = run_softmax(&v, last_dim, "softmax_bf16");
let results = run_softmax(&v, last_dim, "softmax_bfloat");
assert_eq!(
approx_bf16(results, 4),
vec![0.0043, 0.0116, 0.0315, 0.0859, 0.2324, 0.6328]
@ -650,8 +679,7 @@ fn run_where_cond<I: Clone, T: Clone>(
name: &'static str,
) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
@ -692,7 +720,7 @@ fn run_where_cond<I: Clone, T: Clone>(
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
output.read_to_vec::<T>(length)
}
#[test]
@ -716,93 +744,3 @@ fn where_cond() {
);
assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]);
}
fn run_gemm<T: Clone>(
(b, m, n, k): (usize, usize, usize, usize),
lhs: &[T],
lhs_stride: Vec<usize>,
lhs_offset: usize,
rhs: &[T],
rhs_stride: Vec<usize>,
rhs_offset: usize,
) -> Vec<T> {
let device = device();
let fence = device.new_fence();
let kernels = Kernels::new(fence);
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let lhs = device.new_buffer_with_data(
lhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(lhs) as u64,
options,
);
let rhs = device.new_buffer_with_data(
rhs.as_ptr() as *const core::ffi::c_void,
std::mem::size_of_val(rhs) as u64,
options,
);
let length = b * m * n;
let output = device.new_buffer((length * core::mem::size_of::<T>()) as u64, options);
call_gemm(
&device,
command_buffer,
&kernels,
"sgemm",
(b, m, n, k),
&lhs_stride,
lhs_offset,
&lhs,
&rhs_stride,
rhs_offset,
&rhs,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, length)
}
#[test]
fn gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
assert_eq!(
approx(results, 4),
vec![
20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0, 344.0, 365.0, 386.0, 407.0, 488.0,
518.0, 548.0, 578.0
]
);
// OFFSET
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
// Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32
let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4);
assert_eq!(
approx(results, 4),
vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0]
);
}

View File

@ -64,12 +64,12 @@ kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[tid] = TYPENAME(FN(float(input[tid]))); \
output[thread_position_in_grid] = TYPENAME(FN(input[thread_position_in_grid])); \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
@ -78,20 +78,20 @@ kernel void FN_NAME_STRIDED( \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
uint thread_position_in_grid [[ thread_position_in_grid ]] \
) { \
if (tid >= dim) { \
if (thread_position_in_grid >= dim) { \
return; \
} \
output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \
output[thread_position_in_grid] = TYPENAME(FN(input[get_strided_index(thread_position_in_grid, num_dims, dims, strides)])); \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \
UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
@ -107,9 +107,8 @@ UNARY_OP(floor)
UNARY_OP(round)
UNARY_OP(gelu_erf)
UNARY_OP(erf)
UNARY_OP(tanh)
UNARY(id, float, copy_f32, copy_f32_strided)
UNARY(id, half, copy_f16, copy_f16_strided)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
UNARY(id, uint32_t, copy_u32, copy_u32_strided)
@ -127,7 +126,6 @@ BFLOAT_UNARY_OP(floor)
BFLOAT_UNARY_OP(round)
BFLOAT_UNARY_OP(gelu_erf)
BFLOAT_UNARY_OP(erf)
BFLOAT_UNARY_OP(tanh)
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(id, bfloat, copy_bfloat, copy_bfloat_strided)
#endif

View File

@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
@ -19,7 +19,6 @@ num-traits = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
metal = { workspace = true, optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
[dev-dependencies]
@ -31,4 +30,4 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@ -1,4 +1,4 @@
use candle::{Result, Tensor};
use candle::Tensor;
use serde::Deserialize;
#[derive(Debug, Clone, Copy, PartialEq, Deserialize, Default)]
@ -21,7 +21,7 @@ pub enum Activation {
}
impl super::Module for Activation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::Gelu => xs.gelu_erf(),
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
@ -40,60 +40,3 @@ impl super::Module for Activation {
}
}
}
#[derive(Clone, Debug)]
pub struct PReLU {
weight: Tensor,
is_scalar: bool,
}
impl PReLU {
pub fn new(weight: Tensor, is_scalar: bool) -> Self {
Self { weight, is_scalar }
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn is_scalar(&self) -> bool {
self.is_scalar
}
}
impl candle::Module for PReLU {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let weight = if self.is_scalar {
self.weight.reshape(())?
} else if xs.rank() >= 2 {
let num_channels = xs.dim(1)?;
let num_weights = self.weight.elem_count();
if num_weights != num_channels {
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
}
let mut s = vec![1; xs.rank()];
s[1] = self.weight.elem_count();
self.weight.reshape(s)?
} else {
self.weight.clone()
};
let zeros = xs.zeros_like()?;
xs.maximum(&zeros)? + xs.minimum(&zeros)?.broadcast_mul(&weight)?
}
}
/// Create or initialize a new PReLU layer.
///
/// This uses some default name for weights, namely `"weight"`.
/// # Arguments
///
/// * `num_channels` - The number of channels. Use `None` to have as single trainable value and
/// `Some` for a 1D vector with the appropriate number of channels. When applying the `forward`
/// function, the input tensor shape `s` should either be one dimension with this number of
/// channels or if `s.len() >= 2` it should have `s[1]` equal to this number.
pub fn prelu(num_channels: Option<usize>, vs: crate::VarBuilder) -> Result<PReLU> {
let init_ws = crate::init::Init::Const(0.25);
// When using a scalar weight, the PyTorch encoding is to use a 1d vector of length 1.
let ws = vs.get_with_hints((num_channels.unwrap_or(1),), "weight", init_ws)?;
Ok(PReLU::new(ws, num_channels.is_none()))
}

View File

@ -15,7 +15,7 @@ pub mod sequential;
pub mod var_builder;
pub mod var_map;
pub use activation::{prelu, Activation, PReLU};
pub use activation::Activation;
pub use batch_norm::{batch_norm, BatchNorm, BatchNormConfig};
pub use conv::{
conv1d, conv2d, conv2d_no_bias, conv_transpose2d, conv_transpose2d_no_bias, Conv1d,

View File

@ -56,7 +56,7 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
/// This uses some default names for weight and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
@ -69,7 +69,6 @@ pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Li
Ok(Linear::new(ws, Some(bs)))
}
/// Create or initialize a new linear layer without biases.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;

View File

@ -210,33 +210,32 @@ impl candle::CustomOp1 for SoftmaxLastDim {
) -> Result<(candle::MetalStorage, Shape)> {
use candle::{backend::BackendStorage, DType};
let device = storage.device();
let command_buffer = device.command_buffer()?;
let command_buffer = device.command_buffer();
let kernels = device.kernels();
let name = match storage.dtype() {
DType::F32 => "softmax_f32",
DType::F16 => "softmax_f16",
DType::BF16 => "softmax_bf16",
DType::F32 => "softmax_float",
DType::F16 => "softmax_half",
DType::BF16 => "softmax_bfloat",
dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
};
let n = layout.stride().len();
if !(layout.is_contiguous() && layout.stride()[n - 1] == 1) {
if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
candle::bail!("Non contiguous softmax-last-dim is not implemented");
}
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
let mut output = device.new_buffer(elem_count, storage.dtype());
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
kernels,
&kernels,
name,
elem_count,
last_dim,
storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(),
&output,
&mut output,
)
.unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());

View File

@ -190,12 +190,4 @@ impl AdamW {
};
Self::new(vars, params)
}
pub fn params(&self) -> &ParamsAdamW {
&self.params
}
pub fn set_params(&mut self, params: ParamsAdamW) {
self.params = params;
}
}

View File

@ -40,7 +40,7 @@ struct TensorData<B: Backend> {
/// A trait that defines how tensor data is retrieved.
///
/// Typically this would use disk storage in some specific format, or random initialization.
/// Note that there is a specialized version of this trait (`SimpleBackend`) that can be used most
/// Note that there is a speciliazed version of this trait (`SimpleBackend`) that can be used most
/// of the time. The main restriction is that it doesn't allow for specific args (besides
/// initialization hints).
pub trait Backend: Send + Sync {
@ -535,18 +535,12 @@ impl Backend for ShardedSafeTensors {
fn get(
&self,
target_shape: Shape, // The size is only checked when the world size is 1.
_target_shape: Shape, // The size is not checked for ShardedTensors
path: &str,
h: Self::Hints,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
if h.world_size == 1 {
// There is no sharding to be applied here so we use the default backend to speed
// things up.
return SimpleBackend::get(&self.0, target_shape, path, Default::default(), dtype, dev);
}
let Shard {
dim,
rank,

View File

@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.3.3"
version = "0.3.1"
edition = "2021"
description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
prost = "0.12.1"
[build-dependencies]

View File

@ -15,9 +15,9 @@ crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle-onnx = {path= "../candle-onnx", version = "0.3.3", optional = true}
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
candle-onnx = {path= "../candle-onnx", version = "0.3.1", optional = true}
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }

View File

@ -4,8 +4,7 @@ try:
from .candle import *
except ImportError as e:
# If we are in development mode, or we did not bundle the DLLs, we try to locate them here
# PyO3 wont give us any information about what DLLs are missing, so we can only try to load
# the DLLs and re-import the module
# PyO3 wont give us any infomration about what DLLs are missing, so we can only try to load the DLLs and re-import the module
logging.warning("DLLs were not bundled with this package. Trying to locate them...")
import os
import platform

View File

@ -363,7 +363,7 @@ class ModuleList(Module):
self.add_module(str(offset + i), module)
return self
# remove forward altogether to fallback on Module's _forward_unimplemented
# remove forward alltogether to fallback on Module's _forward_unimplemented
class ModuleDict(Module):
@ -480,4 +480,4 @@ class ModuleDict(Module):
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
# remove forward altogether to fallback on Module's _forward_unimplemented
# remove forward alltogether to fallback on Module's _forward_unimplemented

View File

@ -212,7 +212,7 @@ trait MapDType {
enum Indexer {
Index(usize),
Slice(usize, usize),
Ellipsis,
Elipsis,
Expand,
IndexSelect(Tensor),
}
@ -568,7 +568,7 @@ impl PyTensor {
"Ellipsis ('...') can only be used at the start of an indexing operation",
));
}
Ok((Indexer::Ellipsis, dims.len() - (index_argument_count - 1)))
Ok((Indexer::Elipsis, dims.len() - (index_argument_count - 1)))
} else if py_indexer.is_none() {
// Handle None e.g. tensor[None, 0]
Ok((Indexer::Expand, current_dim))
@ -616,9 +616,8 @@ impl PyTensor {
current_dim += 1;
out
}
Indexer::Ellipsis => {
// Ellipsis is a special case, it means that all remaining dimensions should be
// selected => advance the current_dim to the last dimension we have indexers for
Indexer::Elipsis => {
// Elipsis is a special case, it means that all remaining dimensions should be selected => advance the current_dim to the last dimension we have indexers for
current_dim += dims.len() - (indexers.len() - 1);
x
}
@ -961,11 +960,11 @@ impl PyTensor {
extraction_result: PyResult<T>,
err_msg: &'static str,
) -> PyResult<()> {
if let Ok(successful_extraction) = extraction_result {
if let Ok(sucessfull_extraction) = extraction_result {
if opt.is_some() {
return Err(PyValueError::new_err(err_msg));
}
*opt = Some(successful_extraction);
*opt = Some(sucessfull_extraction);
}
Ok(())
}
@ -1046,7 +1045,9 @@ impl PyTensor {
.map_err(wrap_err)?,
(Some(device), None) => self.0.to_device(&device.as_device()?).map_err(wrap_err)?,
(None, Some(dtype)) => self.0.to_dtype(dtype.0).map_err(wrap_err)?,
(None, None) => return Err(PyTypeError::new_err("No valid dtype or device specified")),
(None, None) => {
return Err(PyTypeError::new_err("No valide dtype or device specified"))
}
};
Ok(PyTensor(result))

View File

@ -156,7 +156,7 @@ def pyi_file(obj, indent=""):
string += function(obj, indent)
elif inspect.isgetsetdescriptor(obj):
# TODO it would be interesting to add the setter maybe ?
# TODO it would be interesing to add the setter maybe ?
string += f"{indent}@property\n"
string += function(obj, indent, text_signature="(self)")

View File

@ -74,7 +74,7 @@ def test_module_can_load_statedict():
a.load_state_dict(statedict)
def test_module_throws_on_shape_mismatch():
def test_module_throws_on_shape_missmatch():
class A(Module):
def __init__(self):
super().__init__()
@ -121,7 +121,7 @@ def test_module_can_load_quantized_tensors():
assert a.t.ggml_dtype == "Q4_0"
def test_module_dequantizes_tensors_automatically():
def test_module_dequantizes_tensors_automaticaly():
class A(Module):
def __init__(self):
super().__init__()

View File

@ -84,7 +84,7 @@ def assert_bool(t: Tensor, expected: bool):
assert bool(t.values()) == expected
def test_tensor_supports_equality_operations_with_scalars():
def test_tensor_supports_equality_opperations_with_scalars():
t = Tensor(42.0)
assert_bool(t == 42.0, True)
@ -106,7 +106,7 @@ def test_tensor_supports_equality_operations_with_scalars():
assert_bool(t <= 42.0, True)
def test_tensor_supports_equality_operations_with_tensors():
def test_tensor_supports_equality_opperations_with_tensors():
t = Tensor(42.0)
same = Tensor(42.0)
other = Tensor(43.0)
@ -130,7 +130,7 @@ def test_tensor_supports_equality_operations_with_tensors():
assert_bool(t <= other, True)
def test_tensor_equality_operations_can_broadcast():
def test_tensor_equality_opperations_can_broadcast():
# Create a decoder attention mask as a test case
# e.g.
# [[1,0,0]

View File

@ -12,9 +12,9 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.3", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.3", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.3" }
candle = { path = "../candle-core", version = "0.3.1", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.1" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
@ -31,4 +31,3 @@ accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda"]
flash-attn = ["cuda", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
metal = ["candle/metal", "candle-nn/metal"]

View File

@ -7,9 +7,8 @@ pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
enum HiddenAct {
Gelu,
GeluApproximate,
Relu,
}
@ -29,7 +28,6 @@ impl HiddenActLayer {
match self.act {
// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/activations.py#L213
HiddenAct::Gelu => xs.gelu_erf(),
HiddenAct::GeluApproximate => xs.gelu(),
HiddenAct::Relu => xs.relu(),
}
}
@ -50,7 +48,7 @@ pub struct Config {
num_hidden_layers: usize,
num_attention_heads: usize,
intermediate_size: usize,
pub hidden_act: HiddenAct,
hidden_act: HiddenAct,
hidden_dropout_prob: f64,
max_position_embeddings: usize,
type_vocab_size: usize,

View File

@ -21,7 +21,6 @@ pub struct Config {
}
impl Config {
// https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json
pub fn config_7b_v0_1(use_flash_attn: bool) -> Self {
Self {
vocab_size: 32000,
@ -38,25 +37,6 @@ impl Config {
use_flash_attn,
}
}
// https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca/blob/main/config.json
// https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
pub fn config_chat_ml(use_flash_attn: bool) -> Self {
Self {
vocab_size: 32002,
hidden_size: 4096,
intermediate_size: 14336,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 8,
hidden_act: Activation::Silu,
max_position_embeddings: 32768,
rms_norm_eps: 1e-5,
rope_theta: 10_000.,
sliding_window: 4096,
use_flash_attn,
}
}
}
#[derive(Debug, Clone)]
@ -297,10 +277,6 @@ impl Attention {
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
@ -344,10 +320,6 @@ impl DecoderLayer {
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
@ -431,10 +403,4 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}

View File

@ -57,22 +57,6 @@ impl Config {
}
}
pub fn v2() -> Self {
Self {
vocab_size: 51200,
n_positions: 2048,
n_embd: 2560,
n_layer: 32,
n_inner: None,
n_head: 32,
rotary_dim: usize::min(32, 2560 / 32),
activation_function: Activation::Gelu,
layer_norm_epsilon: 1e-5,
tie_word_embeddings: false,
pad_vocab_size_multiple: 64,
}
}
// https://huggingface.co/teknium/Puffin-Phi-v2/blob/main/config.json
pub fn puffin_phi_v2() -> Self {
Self {
@ -388,24 +372,6 @@ pub struct MixFormerSequentialForCausalLM {
}
impl MixFormerSequentialForCausalLM {
pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_head = vb.pp("lm_head");
let vb = vb.pp("transformer");
let embedding = Embedding::new(cfg, vb.pp("embd"))?;
let mut blocks = Vec::new();
for i in 0..cfg.n_layer {
let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
blocks.push(block)
}
let head = CausalLMHead::new(cfg, vb_head)?;
Ok(Self {
embedding,
blocks,
head,
span: tracing::span!(tracing::Level::TRACE, "mixformer"),
})
}
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layers");
let embedding = Embedding::new(cfg, vb.pp(0))?;

View File

@ -1,499 +0,0 @@
use crate::models::with_tracing::{linear_no_bias, Linear};
/// Mixtral Model
/// https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
/// https://mistral.ai/news/mixtral-of-experts/
use candle::{DType, Device, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use serde::Deserialize;
use std::sync::Arc;
/// https://github.com/huggingface/transformers/blob/1a585c1222a56bcaecc070966d558d4a9d862e83/src/transformers/models/mixtral/configuration_mixtral.py#L113
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub(crate) vocab_size: usize,
pub(crate) hidden_size: usize,
pub(crate) intermediate_size: usize,
pub(crate) num_hidden_layers: usize,
pub(crate) num_attention_heads: usize,
pub(crate) num_key_value_heads: usize,
pub(crate) hidden_act: Activation,
pub(crate) max_position_embeddings: usize,
pub(crate) rms_norm_eps: f64,
pub(crate) rope_theta: f64,
pub(crate) sliding_window: usize,
pub(crate) num_experts_per_tok: usize,
pub(crate) num_local_experts: usize,
pub(crate) use_flash_attn: bool,
}
impl Config {
/// https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json
pub fn v0_1_8x7b(use_flash_attn: bool) -> Self {
Self {
vocab_size: 32000,
hidden_size: 4096,
intermediate_size: 14336,
num_hidden_layers: 32,
num_attention_heads: 32,
num_key_value_heads: 8,
hidden_act: Activation::Silu,
max_position_embeddings: 32768,
rms_norm_eps: 1e-5,
rope_theta: 1e6,
sliding_window: 4096,
num_experts_per_tok: 2,
num_local_experts: 8,
use_flash_attn,
}
}
}
#[derive(Debug, Clone)]
struct RmsNorm {
inner: candle_nn::RmsNorm,
span: tracing::Span,
}
impl RmsNorm {
fn new(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
let inner = candle_nn::rms_norm(size, eps, vb)?;
Ok(Self { inner, span })
}
}
impl Module for RmsNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Debug, Clone)]
struct RotaryEmbedding {
sin: Tensor,
cos: Tensor,
}
fn rotate_half(xs: &Tensor) -> Result<Tensor> {
let last_dim = xs.dim(D::Minus1)?;
let xs1 = xs.narrow(D::Minus1, 0, last_dim / 2)?;
let xs2 = xs.narrow(D::Minus1, last_dim / 2, last_dim - last_dim / 2)?;
Tensor::cat(&[&xs2.neg()?, &xs1], D::Minus1)
}
impl RotaryEmbedding {
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
let dim = cfg.hidden_size / cfg.num_attention_heads;
let max_seq_len = cfg.max_position_embeddings;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / (cfg.rope_theta as f32).powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
.to_dtype(dtype)?
.reshape((max_seq_len, 1))?;
let freqs = t.matmul(&inv_freq)?;
let freqs = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}
fn apply_rotary_emb_qkv(
&self,
q: &Tensor,
k: &Tensor,
seqlen_offset: usize,
) -> Result<(Tensor, Tensor)> {
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
let cos = cos.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let sin = sin.unsqueeze(0)?.unsqueeze(0)?; // (1, 1, seq_len, dim)
let q_embed = (q.broadcast_mul(&cos)? + rotate_half(q)?.broadcast_mul(&sin))?;
let k_embed = (k.broadcast_mul(&cos)? + rotate_half(k)?.broadcast_mul(&sin))?;
Ok((q_embed, k_embed))
}
}
#[cfg(feature = "flash-attn")]
fn flash_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
softmax_scale: f32,
causal: bool,
) -> Result<Tensor> {
candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal)
}
#[cfg(not(feature = "flash-attn"))]
fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Tensor> {
unimplemented!("compile with '--features flash-attn'")
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
num_heads: usize,
num_kv_heads: usize,
num_kv_groups: usize,
head_dim: usize,
hidden_size: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: Option<(Tensor, Tensor)>,
use_flash_attn: bool,
}
impl Attention {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
let num_kv_heads = cfg.num_key_value_heads;
let num_kv_groups = num_heads / num_kv_heads;
let head_dim = hidden_sz / num_heads;
let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?;
let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?;
let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?;
let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
num_heads,
num_kv_heads,
num_kv_groups,
head_dim,
hidden_size: hidden_sz,
rotary_emb,
kv_cache: None,
use_flash_attn: cfg.use_flash_attn,
})
}
fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
let key_states = self.k_proj.forward(xs)?;
let value_states = self.v_proj.forward(xs)?;
let query_states = query_states
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
.transpose(1, 2)?;
let key_states = key_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let value_states = value_states
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
.transpose(1, 2)?;
let (query_states, key_states) =
self.rotary_emb
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
let (key_states, value_states) = match &self.kv_cache {
None => (key_states, value_states),
Some((prev_k, prev_v)) => {
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));
let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
let q = query_states.transpose(1, 2)?;
let k = key_states.transpose(1, 2)?;
let v = value_states.transpose(1, 2)?;
let softmax_scale = 1f32 / (self.head_dim as f32).sqrt();
flash_attn(&q, &k, &v, softmax_scale, q_len > 1)?.transpose(1, 2)?
} else {
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
let attn_weights = match attention_mask {
None => attn_weights,
Some(mask) => attn_weights.broadcast_add(mask)?,
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
attn_weights.matmul(&value_states)?
};
attn_output
.transpose(1, 2)?
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
}
#[derive(Debug, Clone)]
struct BlockSparseTop2MLP {
w1: Linear,
w2: Linear,
w3: Linear,
act_fn: Activation,
}
impl BlockSparseTop2MLP {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let intermediate_sz = cfg.intermediate_size;
let w1 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w1"))?;
let w2 = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("w2"))?;
let w3 = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("w3"))?;
Ok(Self {
w1,
w2,
w3,
act_fn: cfg.hidden_act,
})
}
}
impl Module for BlockSparseTop2MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let lhs = xs.apply(&self.w1)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.w3)?;
(lhs * rhs)?.apply(&self.w2)
}
}
#[derive(Debug, Clone)]
struct SparseMoeBlock {
gate: Linear,
experts: Vec<BlockSparseTop2MLP>,
num_experts_per_tok: usize,
}
impl SparseMoeBlock {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let gate = linear_no_bias(cfg.hidden_size, cfg.num_local_experts, vb.pp("gate"))?;
let mut experts = Vec::with_capacity(cfg.num_local_experts);
let vb = vb.pp("experts");
for idx in 0..cfg.num_local_experts {
let expert = BlockSparseTop2MLP::new(cfg, vb.pp(idx))?;
experts.push(expert)
}
Ok(SparseMoeBlock {
gate,
experts,
num_experts_per_tok: cfg.num_experts_per_tok,
})
}
}
impl Module for SparseMoeBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
let xs = xs.reshape(((), hidden_dim))?;
let router_logits = xs.apply(&self.gate)?;
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
// In order to extract topk, we extract the data from the tensor and manipulate it
// directly. Maybe we will want to use some custom ops instead at some point.
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
// top_x contains the row indexes to evaluate for each expert.
let mut top_x = vec![vec![]; self.experts.len()];
let mut selected_rws = vec![vec![]; self.experts.len()];
for (row_idx, rw) in routing_weights.iter().enumerate() {
let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
let mut sum_routing_weights = 0f32;
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
let expert_idx = expert_idx as usize;
let routing_weight = rw[expert_idx];
sum_routing_weights += routing_weight;
top_x[expert_idx].push(row_idx as u32);
}
for &expert_idx in dst.iter().take(self.num_experts_per_tok) {
let expert_idx = expert_idx as usize;
let routing_weight = rw[expert_idx];
selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
}
}
// routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
// expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
let mut ys = xs.zeros_like()?;
for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
let top_x = &top_x[expert_idx];
if top_x.is_empty() {
continue;
}
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
let selected_rws =
Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?.reshape(((), 1))?;
// Index the correct hidden states and compute the expert hidden state for
// the current expert. We need to make sure to multiply the output hidden
// states by `routing_weights` on the corresponding tokens (top-1 and top-2)
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
// current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
let current_hidden_states = expert_layer.forward(&current_state)?;
let current_hidden_states = current_hidden_states.broadcast_mul(&selected_rws)?;
ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
}
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
Ok(ys)
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
block_sparse_moe: SparseMoeBlock,
input_layernorm: RmsNorm,
post_attention_layernorm: RmsNorm,
}
impl DecoderLayer {
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
let block_sparse_moe = SparseMoeBlock::new(cfg, vb.pp("block_sparse_moe"))?;
let input_layernorm =
RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm = RmsNorm::new(
cfg.hidden_size,
cfg.rms_norm_eps,
vb.pp("post_attention_layernorm"),
)?;
Ok(Self {
self_attn,
block_sparse_moe,
input_layernorm,
post_attention_layernorm,
})
}
fn forward(
&mut self,
xs: &Tensor,
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = xs
.apply(&self.post_attention_layernorm)?
.apply(&self.block_sparse_moe)?;
residual + xs
}
}
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
}
impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_m = vb.pp("model");
let embed_tokens =
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
for layer_idx in 0..cfg.num_hidden_layers {
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}
fn prepare_decoder_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
seqlen_offset: usize,
) -> Result<Tensor> {
// Sliding window mask?
let mask: Vec<_> = (0..tgt_len)
.flat_map(|i| {
(0..tgt_len).map(move |j| {
if i < j || j + self.sliding_window < i {
f32::NEG_INFINITY
} else {
0.
}
})
})
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
};
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
.to_dtype(self.dtype)
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
}
}

View File

@ -14,7 +14,6 @@ pub mod llama2_c_weights;
pub mod marian;
pub mod mistral;
pub mod mixformer;
pub mod mixtral;
pub mod mpt;
pub mod persimmon;
pub mod quantized_blip;

View File

@ -47,102 +47,6 @@ impl QMatMul {
}
}
#[derive(Debug, Clone)]
struct Mlp {
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let w1 = self.feed_forward_w1.forward(xs)?;
let w3 = self.feed_forward_w3.forward(xs)?;
self.feed_forward_w2
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)
}
}
#[derive(Debug, Clone)]
enum MlpOrMoe {
Mlp(Mlp),
MoE {
n_expert_used: usize,
feed_forward_gate_inp: QMatMul,
experts: Vec<Mlp>,
},
}
impl Module for MlpOrMoe {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::MoE {
feed_forward_gate_inp,
experts,
n_expert_used,
} => {
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
let xs = xs.reshape(((), hidden_dim))?;
let router_logits = feed_forward_gate_inp.forward(&xs)?;
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
// In order to extract topk, we extract the data from the tensor and manipulate it
// directly. Maybe we will want to use some custom ops instead at some point.
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
// routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
// top_x contains the row indexes to evaluate for each expert.
let mut top_x = vec![vec![]; experts.len()];
let mut selected_rws = vec![vec![]; experts.len()];
for (row_idx, rw) in routing_weights.iter().enumerate() {
let mut dst = (0..rw.len() as u32).collect::<Vec<u32>>();
dst.sort_by(|&i, &j| rw[j as usize].total_cmp(&rw[i as usize]));
let mut sum_routing_weights = 0f32;
for &expert_idx in dst.iter().take(*n_expert_used) {
let expert_idx = expert_idx as usize;
let routing_weight = rw[expert_idx];
sum_routing_weights += routing_weight;
top_x[expert_idx].push(row_idx as u32);
}
for &expert_idx in dst.iter().take(*n_expert_used) {
let expert_idx = expert_idx as usize;
let routing_weight = rw[expert_idx];
selected_rws[expert_idx].push(routing_weight / sum_routing_weights)
}
}
// routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
// expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
let mut ys = xs.zeros_like()?;
for (expert_idx, expert_layer) in experts.iter().enumerate() {
let top_x = &top_x[expert_idx];
if top_x.is_empty() {
continue;
}
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
let selected_rws =
Tensor::new(selected_rws[expert_idx].as_slice(), xs.device())?
.reshape(((), 1))?;
// Index the correct hidden states and compute the expert hidden state for
// the current expert. We need to make sure to multiply the output hidden
// states by `routing_weights` on the corresponding tokens (top-1 and top-2)
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
// current_hidden_states = expert_layer(current_state, routing_weights[top_x_list, idx_list, None])
let current_hidden_states = expert_layer.forward(&current_state)?;
let current_hidden_states =
current_hidden_states.broadcast_mul(&selected_rws)?;
ys = ys.index_add(&top_x, &current_hidden_states, 0)?;
}
let ys = ys.reshape((b_size, seq_len, hidden_dim))?;
Ok(ys)
}
Self::Mlp(mlp) => mlp.forward(xs),
}
}
}
#[derive(Debug, Clone)]
struct LayerWeights {
attention_wq: QMatMul,
@ -150,7 +54,9 @@ struct LayerWeights {
attention_wv: QMatMul,
attention_wo: QMatMul,
attention_norm: RmsNorm,
mlp_or_moe: MlpOrMoe,
feed_forward_w1: QMatMul,
feed_forward_w2: QMatMul,
feed_forward_w3: QMatMul,
ffn_norm: RmsNorm,
n_head: usize,
n_kv_head: usize,
@ -306,16 +212,9 @@ impl ModelWeights {
let attention_wk = ct.remove(&format!("{prefix}.attention.wk.weight"))?;
let attention_wv = ct.remove(&format!("{prefix}.attention.wv.weight"))?;
let attention_wo = ct.remove(&format!("{prefix}.attention.wo.weight"))?;
let mlp_or_moe = {
let feed_forward_w1 = ct.remove(&format!("{prefix}.feed_forward.w1.weight"))?;
let feed_forward_w2 = ct.remove(&format!("{prefix}.feed_forward.w2.weight"))?;
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
MlpOrMoe::Mlp(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
})
};
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@ -327,7 +226,9 @@ impl ModelWeights {
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, 1e-5)?,
mlp_or_moe,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
ffn_norm: RmsNorm::new(ffn_norm, 1e-5)?,
n_head: ct.hparams.n_head as usize,
n_kv_head: ct.hparams.n_head as usize / gqa,
@ -364,12 +265,6 @@ impl ModelWeights {
};
// Parameter extraction from metadata.
let n_expert = md_get("llama.expert_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let n_expert_used = md_get("llama.expert_used_count")
.and_then(|v| v.to_u32())
.unwrap_or(0) as usize;
let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
let block_count = md_get("llama.block_count")?.to_u32()? as usize;
@ -394,38 +289,9 @@ impl ModelWeights {
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"))?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"))?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"))?;
let mlp_or_moe = if n_expert <= 1 {
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"))?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"))?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"))?;
MlpOrMoe::Mlp(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
})
} else {
let feed_forward_gate_inp =
ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"))?;
let mut experts = Vec::with_capacity(n_expert);
for i in 0..n_expert {
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.{i}.weight"))?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.{i}.weight"))?;
let feed_forward_w3 =
ct.tensor(reader, &format!("{prefix}.ffn_up.{i}.weight"))?;
experts.push(Mlp {
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
})
}
MlpOrMoe::MoE {
n_expert_used,
feed_forward_gate_inp: QMatMul::from_qtensor(feed_forward_gate_inp)?,
experts,
}
};
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"))?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"))?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
@ -437,7 +303,9 @@ impl ModelWeights {
attention_wv: QMatMul::from_qtensor(attention_wv)?,
attention_wo: QMatMul::from_qtensor(attention_wo)?,
attention_norm: RmsNorm::new(attention_norm, rms_norm_eps)?,
mlp_or_moe,
feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
ffn_norm: RmsNorm::new(ffn_norm, rms_norm_eps)?,
n_head: head_count,
n_kv_head: head_count_kv,
@ -492,9 +360,12 @@ impl ModelWeights {
let _enter = layer.span_mlp.enter();
let residual = &x;
let x = layer.ffn_norm.forward(&x)?;
let x = layer.mlp_or_moe.forward(&x)?;
let x = (x + residual)?;
layer_in = x
let w1 = layer.feed_forward_w1.forward(&x)?;
let w3 = layer.feed_forward_w3.forward(&x)?;
let mlp = layer
.feed_forward_w2
.forward(&(candle_nn::ops::silu(&w1)? * w3)?)?;
layer_in = (mlp + residual)?;
}
let x = self.norm.forward(&layer_in)?;
let x = x.i((.., seq_len - 1, ..))?;

View File

@ -198,10 +198,6 @@ impl Attention {
.reshape((b_sz, q_len, self.hidden_size))?
.apply(&self.o_proj)
}
fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
@ -245,10 +241,6 @@ impl DecoderLayer {
let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?;
residual + xs
}
fn clear_kv_cache(&mut self) {
self.self_attn.clear_kv_cache()
}
}
#[derive(Debug, Clone)]
@ -330,10 +322,4 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}
}

View File

@ -287,24 +287,6 @@ pub struct MixFormerSequentialForCausalLM {
}
impl MixFormerSequentialForCausalLM {
pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb_head = vb.pp("lm_head");
let vb = vb.pp("transformer");
let embedding = Embedding::new(cfg, vb.pp("embd"))?;
let mut blocks = Vec::new();
for i in 0..cfg.n_layer {
let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
blocks.push(block)
}
let head = CausalLMHead::new(cfg, vb_head)?;
Ok(Self {
embedding,
blocks,
head,
span: tracing::span!(tracing::Level::TRACE, "mixformer"),
})
}
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let vb = vb.pp("layers");
let embedding = Embedding::new(cfg, vb.pp(0))?;

View File

@ -182,7 +182,7 @@ impl MaskDecoder {
sparse_prompt_embeddings: &Tensor,
dense_prompt_embeddings: &Tensor,
) -> Result<(Tensor, Tensor)> {
// Concatenate output tokens.
// Concatenate ouput tokens.
let output_tokens = Tensor::cat(
&[self.iou_token.embeddings(), self.mask_tokens.embeddings()],
0,

View File

@ -2,11 +2,11 @@ use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
#[derive(Debug)]
struct PositionEmbeddingRandom {
struct PostionEmbeddingRandom {
positional_encoding_gaussian_matrix: Tensor,
}
impl PositionEmbeddingRandom {
impl PostionEmbeddingRandom {
fn new(num_pos_feats: usize, vb: VarBuilder) -> Result<Self> {
let positional_encoding_gaussian_matrix =
vb.get((2, num_pos_feats), "positional_encoding_gaussian_matrix")?;
@ -52,7 +52,7 @@ impl PositionEmbeddingRandom {
#[derive(Debug)]
pub struct PromptEncoder {
pe_layer: PositionEmbeddingRandom,
pe_layer: PostionEmbeddingRandom,
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
@ -76,7 +76,7 @@ impl PromptEncoder {
vb: VarBuilder,
) -> Result<Self> {
let num_points_embeddings = 4;
let pe_layer = PositionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
let pe_layer = PostionEmbeddingRandom::new(embed_dim / 2, vb.pp("pe_layer"))?;
let not_a_point_embed = candle_nn::embedding(1, embed_dim, vb.pp("not_a_point_embed"))?;
let no_mask_embed = candle_nn::embedding(1, embed_dim, vb.pp("no_mask_embed"))?;
let cfg = candle_nn::Conv2dConfig {

View File

@ -7,9 +7,7 @@
//!
//! Denoising Diffusion Implicit Models, J. Song et al, 2020.
//! https://arxiv.org/abs/2010.02502
use super::schedulers::{
betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig, TimestepSpacing,
};
use super::schedulers::{betas_for_alpha_bar, BetaSchedule, PredictionType};
use candle::{Result, Tensor};
/// The configuration for the DDIM scheduler.
@ -31,8 +29,6 @@ pub struct DDIMSchedulerConfig {
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
/// time step spacing for the diffusion process
pub timestep_spacing: TimestepSpacing,
}
impl Default for DDIMSchedulerConfig {
@ -45,17 +41,10 @@ impl Default for DDIMSchedulerConfig {
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
timestep_spacing: TimestepSpacing::Leading,
}
}
}
impl SchedulerConfig for DDIMSchedulerConfig {
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
Ok(Box::new(DDIMScheduler::new(inference_steps, *self)?))
}
}
/// The DDIM scheduler.
#[derive(Debug, Clone)]
pub struct DDIMScheduler {
@ -71,32 +60,12 @@ impl DDIMScheduler {
/// Creates a new DDIM scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = match config.timestep_spacing {
TimestepSpacing::Leading => (0..(inference_steps))
let timesteps: Vec<usize> = (0..(inference_steps))
.map(|s| s * step_ratio + config.steps_offset)
.rev()
.collect(),
TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
if *n > step_ratio {
Some(n - step_ratio)
} else {
None
}
})
.map(|n| n - 1)
.collect(),
TimestepSpacing::Linspace => {
super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
.to_vec1::<f64>()?
.iter()
.map(|&f| f as usize)
.rev()
.collect()
}
};
.collect();
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
@ -123,11 +92,19 @@ impl DDIMScheduler {
config,
})
}
}
impl Scheduler for DDIMScheduler {
pub fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
Ok(sample)
}
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@ -186,17 +163,7 @@ impl Scheduler for DDIMScheduler {
}
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor> {
Ok(sample)
}
fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let timestep = if timestep >= self.alphas_cumprod.len() {
timestep - 1
} else {
@ -207,7 +174,7 @@ impl Scheduler for DDIMScheduler {
(original * sqrt_alpha_prod)? + (noise * sqrt_one_minus_alpha_prod)?
}
fn init_noise_sigma(&self) -> f64 {
pub fn init_noise_sigma(&self) -> f64 {
self.init_noise_sigma
}
}

View File

@ -1,235 +0,0 @@
//! Ancestral sampling with Euler method steps.
//!
//! Reference implementation in Rust:
//!
//! https://github.com/pykeio/diffusers/blob/250b9ad1898af41e76a74c0d8d4292652823338a/src/schedulers/euler_ancestral_discrete.rs
//!
//! Based on the original [`k-diffusion` implementation by Katherine Crowson][kd].
///
/// [kd]: https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
use super::{
schedulers::{
betas_for_alpha_bar, BetaSchedule, PredictionType, Scheduler, SchedulerConfig,
TimestepSpacing,
},
utils::interp,
};
use candle::{bail, Error, Result, Tensor};
/// The configuration for the EulerAncestral Discrete scheduler.
#[derive(Debug, Clone, Copy)]
pub struct EulerAncestralDiscreteSchedulerConfig {
/// The value of beta at the beginning of training.n
pub beta_start: f64,
/// The value of beta at the end of training.
pub beta_end: f64,
/// How beta evolved during training.
pub beta_schedule: BetaSchedule,
/// Adjust the indexes of the inference schedule by this value.
pub steps_offset: usize,
/// prediction type of the scheduler function, one of `epsilon` (predicting
/// the noise of the diffusion process), `sample` (directly predicting the noisy sample`)
/// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf)
pub prediction_type: PredictionType,
/// number of diffusion steps used to train the model
pub train_timesteps: usize,
/// time step spacing for the diffusion process
pub timestep_spacing: TimestepSpacing,
}
impl Default for EulerAncestralDiscreteSchedulerConfig {
fn default() -> Self {
Self {
beta_start: 0.00085f64,
beta_end: 0.012f64,
beta_schedule: BetaSchedule::ScaledLinear,
steps_offset: 1,
prediction_type: PredictionType::Epsilon,
train_timesteps: 1000,
timestep_spacing: TimestepSpacing::Leading,
}
}
}
impl SchedulerConfig for EulerAncestralDiscreteSchedulerConfig {
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>> {
Ok(Box::new(EulerAncestralDiscreteScheduler::new(
inference_steps,
*self,
)?))
}
}
/// The EulerAncestral Discrete scheduler.
#[derive(Debug, Clone)]
pub struct EulerAncestralDiscreteScheduler {
timesteps: Vec<usize>,
sigmas: Vec<f64>,
init_noise_sigma: f64,
pub config: EulerAncestralDiscreteSchedulerConfig,
}
// clip_sample: False, set_alpha_to_one: False
impl EulerAncestralDiscreteScheduler {
/// Creates a new EulerAncestral Discrete scheduler given the number of steps to be
/// used for inference as well as the number of steps that was used
/// during training.
pub fn new(
inference_steps: usize,
config: EulerAncestralDiscreteSchedulerConfig,
) -> Result<Self> {
let step_ratio = config.train_timesteps / inference_steps;
let timesteps: Vec<usize> = match config.timestep_spacing {
TimestepSpacing::Leading => (0..(inference_steps))
.map(|s| s * step_ratio + config.steps_offset)
.rev()
.collect(),
TimestepSpacing::Trailing => std::iter::successors(Some(config.train_timesteps), |n| {
if *n > step_ratio {
Some(n - step_ratio)
} else {
None
}
})
.map(|n| n - 1)
.collect(),
TimestepSpacing::Linspace => {
super::utils::linspace(0.0, (config.train_timesteps - 1) as f64, inference_steps)?
.to_vec1::<f64>()?
.iter()
.map(|&f| f as usize)
.rev()
.collect()
}
};
let betas = match config.beta_schedule {
BetaSchedule::ScaledLinear => super::utils::linspace(
config.beta_start.sqrt(),
config.beta_end.sqrt(),
config.train_timesteps,
)?
.sqr()?,
BetaSchedule::Linear => {
super::utils::linspace(config.beta_start, config.beta_end, config.train_timesteps)?
}
BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999)?,
};
let betas = betas.to_vec1::<f64>()?;
let mut alphas_cumprod = Vec::with_capacity(betas.len());
for &beta in betas.iter() {
let alpha = 1.0 - beta;
alphas_cumprod.push(alpha * *alphas_cumprod.last().unwrap_or(&1f64))
}
let sigmas: Vec<f64> = alphas_cumprod
.iter()
.map(|&f| ((1. - f) / f).sqrt())
.collect();
let sigmas_xa: Vec<_> = (0..sigmas.len()).map(|i| i as f64).collect();
let mut sigmas_int = interp(
&timesteps.iter().map(|&t| t as f64).collect::<Vec<_>>(),
&sigmas_xa,
&sigmas,
);
sigmas_int.push(0.0);
// standard deviation of the initial noise distribution
// f64 does not implement Ord such that there is no `max`, so we need to use this workaround
let init_noise_sigma = *sigmas_int
.iter()
.chain(std::iter::once(&0.0))
.reduce(|a, b| if a > b { a } else { b })
.expect("init_noise_sigma could not be reduced from sigmas - this should never happen");
Ok(Self {
sigmas: sigmas_int,
timesteps,
init_noise_sigma,
config,
})
}
}
impl Scheduler for EulerAncestralDiscreteScheduler {
fn timesteps(&self) -> &[usize] {
self.timesteps.as_slice()
}
/// Ensures interchangeability with schedulers that need to scale the denoising model input
/// depending on the current timestep.
///
/// Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm
fn scale_model_input(&self, sample: Tensor, timestep: usize) -> Result<Tensor> {
let step_index = match self.timesteps.iter().position(|&t| t == timestep) {
Some(i) => i,
None => bail!("timestep out of this schedulers bounds: {timestep}"),
};
let sigma = self
.sigmas
.get(step_index)
.expect("step_index out of sigma bounds - this shouldn't happen");
sample / ((sigma.powi(2) + 1.).sqrt())
}
/// Performs a backward step during inference.
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
.position(|&p| p == timestep)
.ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?;
let sigma_from = &self.sigmas[step_index];
let sigma_to = &self.sigmas[step_index + 1];
// 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
let pred_original_sample = match self.config.prediction_type {
PredictionType::Epsilon => (sample - (model_output * *sigma_from))?,
PredictionType::VPrediction => {
((model_output * (-sigma_from / (sigma_from.powi(2) + 1.0).sqrt()))?
+ (sample / (sigma_from.powi(2) + 1.0))?)?
}
PredictionType::Sample => bail!("prediction_type not implemented yet: sample"),
};
let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2))
/ sigma_from.powi(2))
.sqrt();
let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt();
// 2. convert to a ODE derivative
let derivative = ((sample - pred_original_sample)? / *sigma_from)?;
let dt = sigma_down - *sigma_from;
let prev_sample = (sample + derivative * dt)?;
let noise = prev_sample.randn_like(0.0, 1.0)?;
prev_sample + noise * sigma_up
}
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor> {
let step_index = self
.timesteps
.iter()
.position(|&p| p == timestep)
.ok_or_else(|| Error::Msg("timestep out of this schedulers bounds".to_string()))?;
let sigma = self
.sigmas
.get(step_index)
.expect("step_index out of sigma bounds - this shouldn't happen");
original + (noise * *sigma)?
}
fn init_noise_sigma(&self) -> f64 {
match self.config.timestep_spacing {
TimestepSpacing::Trailing | TimestepSpacing::Linspace => self.init_noise_sigma,
TimestepSpacing::Leading => (self.init_noise_sigma.powi(2) + 1.0).sqrt(),
}
}
}

View File

@ -3,7 +3,6 @@ pub mod clip;
pub mod ddim;
pub mod ddpm;
pub mod embeddings;
pub mod euler_ancestral_discrete;
pub mod resnet;
pub mod schedulers;
pub mod unet_2d;
@ -11,13 +10,9 @@ pub mod unet_2d_blocks;
pub mod utils;
pub mod vae;
use std::sync::Arc;
use candle::{DType, Device, Result};
use candle_nn as nn;
use self::schedulers::{Scheduler, SchedulerConfig};
#[derive(Clone, Debug)]
pub struct StableDiffusionConfig {
pub width: usize,
@ -26,7 +21,7 @@ pub struct StableDiffusionConfig {
pub clip2: Option<clip::Config>,
autoencoder: vae::AutoEncoderKLConfig,
unet: unet_2d::UNet2DConditionModelConfig,
scheduler: Arc<dyn SchedulerConfig>,
scheduler: ddim::DDIMSchedulerConfig,
}
impl StableDiffusionConfig {
@ -80,18 +75,13 @@ impl StableDiffusionConfig {
512
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
prediction_type: schedulers::PredictionType::Epsilon,
..Default::default()
});
StableDiffusionConfig {
Self {
width,
height,
clip: clip::Config::v1_5(),
clip2: None,
autoencoder,
scheduler,
scheduler: Default::default(),
unet,
}
}
@ -134,10 +124,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
});
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -153,7 +143,7 @@ impl StableDiffusionConfig {
768
};
StableDiffusionConfig {
Self {
width,
height,
clip: clip::Config::v2_1(),
@ -215,10 +205,10 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
let scheduler = ddim::DDIMSchedulerConfig {
prediction_type,
..Default::default()
});
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -234,76 +224,6 @@ impl StableDiffusionConfig {
1024
};
StableDiffusionConfig {
width,
height,
clip: clip::Config::sdxl(),
clip2: Some(clip::Config::sdxl2()),
autoencoder,
scheduler,
unet,
}
}
fn sdxl_turbo_(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
prediction_type: schedulers::PredictionType,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = Arc::new(
euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig {
prediction_type,
timestep_spacing: schedulers::TimestepSpacing::Trailing,
..Default::default()
},
);
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
512
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
512
};
Self {
width,
height,
@ -329,20 +249,6 @@ impl StableDiffusionConfig {
)
}
pub fn sdxl_turbo(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
Self::sdxl_turbo_(
sliced_attention_size,
height,
width,
// https://huggingface.co/stabilityai/sdxl-turbo/blob/main/scheduler/scheduler_config.json
schedulers::PredictionType::Epsilon,
)
}
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
@ -379,9 +285,9 @@ impl StableDiffusionConfig {
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = Arc::new(ddim::DDIMSchedulerConfig {
let scheduler = ddim::DDIMSchedulerConfig {
..Default::default()
});
};
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
@ -441,8 +347,8 @@ impl StableDiffusionConfig {
Ok(unet)
}
pub fn build_scheduler(&self, n_steps: usize) -> Result<Box<dyn Scheduler>> {
self.scheduler.build(n_steps)
pub fn build_scheduler(&self, n_steps: usize) -> Result<ddim::DDIMScheduler> {
ddim::DDIMScheduler::new(n_steps, self.scheduler)
}
}

View File

@ -3,25 +3,9 @@
//!
//! Noise schedulers can be used to set the trade-off between
//! inference speed and quality.
use candle::{Result, Tensor};
pub trait SchedulerConfig: std::fmt::Debug {
fn build(&self, inference_steps: usize) -> Result<Box<dyn Scheduler>>;
}
/// This trait represents a scheduler for the diffusion process.
pub trait Scheduler {
fn timesteps(&self) -> &[usize];
fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Result<Tensor>;
fn init_noise_sigma(&self) -> f64;
fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Result<Tensor>;
fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Result<Tensor>;
}
/// This represents how beta ranges from its minimum value to the maximum
/// during training.
#[derive(Debug, Clone, Copy)]
@ -41,22 +25,6 @@ pub enum PredictionType {
Sample,
}
/// Time step spacing for the diffusion process.
///
/// "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
#[derive(Debug, Clone, Copy)]
pub enum TimestepSpacing {
Leading,
Linspace,
Trailing,
}
impl Default for TimestepSpacing {
fn default() -> Self {
Self::Leading
}
}
/// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
/// `(1-beta)` over time from `t = [0,1]`.
///

View File

@ -13,49 +13,3 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
Tensor::from_vec(vs, steps, &Device::Cpu)
}
}
/// A linear interpolator for a sorted array of x and y values.
struct LinearInterpolator<'x, 'y> {
xp: &'x [f64],
fp: &'y [f64],
cache: usize,
}
impl<'x, 'y> LinearInterpolator<'x, 'y> {
fn accel_find(&mut self, x: f64) -> usize {
let xidx = self.cache;
if x < self.xp[xidx] {
self.cache = self.xp[0..xidx].partition_point(|o| *o < x);
self.cache = self.cache.saturating_sub(1);
} else if x >= self.xp[xidx + 1] {
self.cache = self.xp[xidx..self.xp.len()].partition_point(|o| *o < x) + xidx;
self.cache = self.cache.saturating_sub(1);
}
self.cache
}
fn eval(&mut self, x: f64) -> f64 {
if x < self.xp[0] || x > self.xp[self.xp.len() - 1] {
return f64::NAN;
}
let idx = self.accel_find(x);
let x_l = self.xp[idx];
let x_h = self.xp[idx + 1];
let y_l = self.fp[idx];
let y_h = self.fp[idx + 1];
let dx = x_h - x_l;
if dx > 0.0 {
y_l + (x - x_l) / dx * (y_h - y_l)
} else {
f64::NAN
}
}
}
pub fn interp(x: &[f64], xp: &[f64], fp: &[f64]) -> Vec<f64> {
let mut interpolator = LinearInterpolator { xp, fp, cache: 0 };
x.iter().map(|&x| interpolator.eval(x)).collect()
}

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -108,7 +108,7 @@ impl Component for App {
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
Msg::SetModel(md) => {
self.status = "weights loaded successfully!".to_string();
self.status = "weights loaded succesfully!".to_string();
self.loaded = true;
console_log!("loaded weights");
self.worker.send(WorkerInput::ModelData(md));

View File

@ -24,7 +24,7 @@ macro_rules! console_log {
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transferred via the following structure.
// on the main thread and transfered via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
pub tokenizer: Vec<u8>,

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }

View File

@ -1,7 +1,7 @@
<html>
<head>
<meta content="text/html;charset=utf-8" http-equiv="Content-Type" />
<title>Candle Phi 1.5 / Phi 2.0 Rust/WASM</title>
<title>Candle Phi 1.5 Rust/WASM</title>
</head>
<body></body>
</html>
@ -39,7 +39,7 @@
import hljs from "https://cdn.skypack.dev/highlight.js";
// models base url
const MODELS = {
phi_1_5_q4k: {
phi_1_5_quantized: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-q4k.gguf",
@ -49,7 +49,7 @@
seq_len: 2048,
size: "800 MB",
},
phi_1_5_q80: {
phi_1_5_quantized_2: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-q80.gguf",
@ -59,21 +59,7 @@
seq_len: 2048,
size: "1.51 GB",
},
phi_2_0_q4k: {
base_url:
"https://huggingface.co/radames/phi-2-quantized/resolve/main/",
model: [
"model-v2-q4k.gguf_aa.part",
"model-v2-q4k.gguf_ab.part",
"model-v2-q4k.gguf_ac.part",
],
tokenizer: "tokenizer.json",
config: "config.json",
quantized: true,
seq_len: 2048,
size: "1.57GB",
},
puffin_phi_v2_q4k: {
puffin_phi_v2_quantized: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-puffin-phi-v2-q4k.gguf",
@ -83,7 +69,7 @@
seq_len: 2048,
size: "798 MB",
},
puffin_phi_v2_q80: {
puffin_phi_v2_quantized_2: {
base_url:
"https://huggingface.co/lmz/candle-quantized-phi/resolve/main/",
model: "model-puffin-phi-v2-q80.gguf",
@ -120,8 +106,8 @@ Lets think step by step.`,
},
{
title: "Question answering",
prompt: `Instruct: What is the capital of France?
Output:`,
prompt: `What is the capital of France?
Answer:`,
},
{
title: "Chat mode",
@ -162,10 +148,7 @@ Very polite review:`,
const getValue = (id) => document.querySelector(`#${id}`).value;
const modelID = getValue("model");
const model = MODELS[modelID];
const weightsURL =
model.model instanceof Array
? model.model.map((m) => model.base_url + m)
: model.base_url + model.model;
const weightsURL = model.base_url + model.model;
const tokenizerURL = model.base_url + model.tokenizer;
const configURL = model.base_url + model.config;
@ -263,13 +246,6 @@ Very polite review:`,
option.innerText = `${id} (${model.size})`;
modelSelect.appendChild(option);
}
const query = new URLSearchParams(window.location.search);
const modelID = query.get("model");
if (modelID) {
modelSelect.value = modelID;
} else {
modelSelect.value = "phi_1_5_q4k";
}
for (const [i, { title, prompt }] of TEMPLATES.entries()) {
const div = document.createElement("div");
@ -294,18 +270,8 @@ Very polite review:`,
prompt.value = template;
prompt.style.height = "auto";
prompt.style.height = prompt.scrollHeight + "px";
runBtn.disabled = false;
clearBtn.classList.remove("invisible");
});
modelSelect.addEventListener("change", (e) => {
const query = new URLSearchParams(window.location.search);
query.set("model", e.target.value);
window.history.replaceState(
{},
"",
`${window.location.pathname}?${query}`
);
window.parent.postMessage({ queryString: "?" + query }, "*");
const model = MODELS[e.target.value];
document.querySelector("#max-seq").max = model.seq_len;
document.querySelector("#max-seq").nextElementSibling.value = 200;
@ -354,7 +320,7 @@ Very polite review:`,
<main class="grid grid-cols-1 gap-8 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle Phi 1.5 / Phi 2.0</h1>
<h1 class="text-5xl font-bold">Candle Phi 1.5</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
The
@ -364,17 +330,10 @@ Very polite review:`,
target="_blank"
>Phi-1.5</a
>
and
<a
href="https://huggingface.co/microsoft/phi-2"
class="link"
target="_blank"
>Phi-2</a
>
models achieve state-of-the-art performance with only 1.3 billion and
2.7 billion parameters, compared to larger models with up to 13
billion parameters. Here you can try the quantized versions.
Additional prompt examples are available in the
model achieves state-of-the-art performance with only 1.3 billion
parameters, compared to models with up to 10 billion. You can try the
quantized version of the model here. Additional prompt examples are
available in the
<a
href="https://arxiv.org/pdf/2309.05463.pdf#page=8"
class="link"
@ -391,7 +350,7 @@ Very polite review:`,
target="_blank"
>Puffin-Phi V2
</a>
quantized version, a fine-tuned version of Phi-1.5 on the
quantized version model, a fine-tuned version of Phi-1.5 on the
<a
href="https://huggingface.co/datasets/LDJnr/Puffin"
class="link"
@ -404,7 +363,7 @@ Very polite review:`,
<p class="text-xs italic max-w-lg">
<b>Note:</b>
When first run, the app will download and cache the model, which could
take a few minutes. The models are <b>~800MB</b> or <b>~1.57GB</b> in
take a few minutes. The models are <b>~800MB</b> or <b>~1.51GB</b> in
size.
</p>
</div>
@ -416,13 +375,8 @@ Very polite review:`,
></select>
</div>
<div>
<details>
<summary class="font-medium cursor-pointer">Prompt Templates</summary>
<form
id="prompt-templates"
class="grid grid-cols-1 sm:grid-cols-2 gap-1 my-2"
></form>
</details>
<h3 class="font-medium">Prompt Templates</h3>
<form id="prompt-templates" class="flex flex-col gap-1 my-2"></form>
</div>
<form
id="form"
@ -432,12 +386,12 @@ Very polite review:`,
<textarea
type="text"
id="prompt"
class="font-light text-lg w-full px-3 py-2 mx-1 resize-none outline-none"
class="font-light w-full px-3 py-2 mx-1 resize-none outline-none"
oninput="this.style.height = 0;this.style.height = this.scrollHeight + 'px'"
placeholder="Add your prompt here..."
>
Instruct: Write a detailed analogy between mathematics and a lighthouse.
Output:</textarea
Write a detailed analogy between mathematics and a lighthouse.
Answer:</textarea
>
<button id="clear-btn">
<svg
@ -563,9 +517,9 @@ Output:</textarea
<div
id="output-counter"
hidden
class="ml-auto font-semibold grid-rows-1"
class="ml-auto font-semibold grid-rows-1 text-sm"
></div>
<p hidden id="output-generation" class="grid-rows-2 text-lg"></p>
<p hidden id="output-generation" class="grid-rows-2"></p>
<span id="output-status" class="m-auto font-light"
>No output yet</span
>

View File

@ -12,20 +12,6 @@ async function fetchArrayBuffer(url) {
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
async function concatenateArrayBuffers(urls) {
const arrayBuffers = await Promise.all(urls.map(url => fetchArrayBuffer(url)));
let totalLength = arrayBuffers.reduce((acc, arrayBuffer) => acc + arrayBuffer.byteLength, 0);
let concatenatedBuffer = new Uint8Array(totalLength);
let offset = 0;
arrayBuffers.forEach(buffer => {
concatenatedBuffer.set(new Uint8Array(buffer), offset);
offset += buffer.byteLength;
});
return concatenatedBuffer;
}
class Phi {
static instance = {};
@ -41,9 +27,10 @@ class Phi {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
await Promise.all([
weightsURL instanceof Array ? concatenateArrayBuffers(weightsURL) : fetchArrayBuffer(weightsURL),
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);

View File

@ -5,7 +5,6 @@ use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausa
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle_wasm_example_phi::console_log;
use js_sys::Date;
use serde::Deserialize;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
@ -24,12 +23,6 @@ pub struct Model {
repeat_last_n: usize,
}
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct ModelName {
pub _name_or_path: String,
}
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
@ -41,25 +34,15 @@ impl Model {
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let name: ModelName = serde_json::from_slice(&config)?;
let config: Config = serde_json::from_slice(&config)?;
console_log!("config loaded {:?}", name);
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let start = Date::now();
console_log!("weights len: {:?}", weights.len());
let model = if quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf_buffer(&weights)?;
console_log!("weights loaded");
if name._name_or_path == "microsoft/phi-2" {
let model = QMixFormer::new_v2(&config, vb)?;
SelectedModel::Quantized(model)
} else {
let model = QMixFormer::new(&config, vb)?;
SelectedModel::Quantized(model)
}
} else {
let device = &Device::Cpu;
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, device)?;

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
# App crates.

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -9,9 +9,9 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.1" }
num-traits = { workspace = true }
tokenizers = { workspace = true, features = ["unstable_wasm"] }

View File

@ -145,7 +145,7 @@ impl Component for App {
fn update(&mut self, ctx: &Context<Self>, msg: Self::Message) -> bool {
match msg {
Msg::SetDecoder(md) => {
self.status = "weights loaded successfully!".to_string();
self.status = "weights loaded succesfully!".to_string();
self.loaded = true;
console_log!("loaded weights");
self.worker.send(WorkerInput::ModelData(md));

View File

@ -414,7 +414,7 @@ pub enum Task {
}
// Communication to the worker happens through bincode, the model weights and configs are fetched
// on the main thread and transferred via the following structure.
// on the main thread and transfered via the following structure.
#[derive(Serialize, Deserialize)]
pub struct ModelData {
pub weights: Vec<u8>,

View File

@ -9,8 +9,8 @@ categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.3", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.3" }
candle = { path = "../../candle-core", version = "0.3.1", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.1" }
num-traits = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

Some files were not shown because too many files have changed in this diff Show More