Compare commits

..

41 Commits

SHA1 Message Date
b97463098c llama2-c wasm fix. 2023-11-02 10:31:47 +01:00
fbd69f952c Lazy detach. (#1242) 2023-11-02 07:33:48 +00:00
6c990a33ea Remove the unused pragma for marian. (#1236) 2023-11-01 20:04:52 +00:00
1704f1b3ae Consolidate the with-tracing usage. (#1234) 2023-11-01 18:21:36 +00:00
693fad511c Preliminary support for ssd1b. (#1233) 2023-11-01 14:37:52 +00:00
36fb84f038 Add a hack for generating random uniform/normal for f16/bf16. (#1228) 2023-10-31 20:27:59 +00:00
c12ad45562 Add a KV cache to marian decoding. (#1226) 2023-10-31 08:47:44 +00:00
7d0202710b Instructions for generating the tokenizer configs for marian-mt. (#1225) 2023-10-31 07:56:26 +01:00
392a00a147 Add support for the marian base model. (#1221) 2023-10-30 19:20:36 +00:00
4c967b9184 Use the hub files for the marian example. (#1220)
* Use the hub files for the marian example.

* Use the secondary decoder.

* Add a readme.

* More readme.
2023-10-30 17:29:36 +00:00
c05c0a8213 PyO3: Add equal and __richcmp__ to candle.Tensor (#1099)
* add `equal` to tensor

* add `__richcmp__` support for tensors and scalars

* typo

* more typos

* Add `abs` + `candle.testing`

* remove duplicated `broadcast_shape_binary_op`

* `candle.i16` => `candle.i64`

* `tensor.nelements` -> `tensor.nelement`

* Cleanup `abs`
2023-10-30 15:17:28 +00:00
969960847a Bugfixes for marian-mt. (#1219)
* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
2023-10-30 11:44:19 +00:00
5fc66bd4ba Support negative steps in arange. (#1218) 2023-10-30 07:40:54 +00:00
174b208052 PyO3: Better shape handling (#1143)
* Negative and `*args` shape handling

* Rename to `PyShapeWithHole` + validate that only one hole exists

* Regenerate stubs

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-29 15:41:44 +00:00
154c674a79 Add i64-abs. (#1216) 2023-10-29 15:28:53 +00:00
7bbde55c61 Marian MT model (#1210)
* Skeleton files for the marian MT model.

* Marian initialization.

* Implement the attention forward method.

* Forward pass for the encoder side.

* Expose the encoder and decoder.

* Start plugging the decoder.

* Forward pass for the decoder layer.

* Set up the marian example.

* Add some missing backtraces.

* Bugfix.
2023-10-29 15:12:22 +00:00
c3f2676d49 PyO3: Add CI to build & upload wheels as artifacts. (#1215)
* Add maturin ci

* fix paths

* Change sdist path
2023-10-29 13:44:05 +00:00
46d6566c99 Fix the conv2d gradient computation. (#1214) 2023-10-29 09:50:04 +00:00
55bc3382cf Allow for different behavior between training and eval (#1213)
* Forward with training.

* Do not use dropout on vgg evaluation.
2023-10-29 07:53:09 +01:00
dece37c6f4 feat: implement VGG13, VGG16 and VGG19 (#1211)
* feat: implement VGG13, VGG16 and VGG19

* Cosmetic fixes.

* More cosmetic tweaks + avoid re-loading the weights on each final layer.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-29 06:10:23 +00:00
498c50348c Add DDPG and fix Gym wrapper (#1207)
* Fix Gym wrapper
- It was returning things in the wrong order
- Gym now differentiates between terminated and truncated

* Add DDPG

* Apply fixes

* Remove Result annotations

* Also remove Vec annotation

* rustfmt

* Various small improvements (avoid cloning, mutability, get clippy to pass, ...)

---------

Co-authored-by: Travis Hammond <travis.hammond@alexanderthamm.com>
Co-authored-by: Laurent <laurent.mazare@gmail.com>
2023-10-28 19:53:34 +01:00
012ae0090e Infer the config for llama2-c. (#1208) 2023-10-28 19:00:39 +01:00
95a857cf57 Move the llama2-c model in transformers. (#1205) 2023-10-28 16:51:19 +01:00
612f5b8156 Make more models cloneable. (#1203) 2023-10-28 07:43:08 +01:00
ef33df7ae2 No need for the even constraint on vecdot-q40-q80. (#1202) 2023-10-28 07:23:59 +01:00
c8face3f95 Add the relu2 and relu6 activations. (#1201) 2023-10-27 20:51:16 +01:00
85bea43e5b Make the whisper model cloneable (#1200)
* Add a quantized variant of llama2.c

* Clippy fixes.

* Make the whisper model cloneable.
2023-10-27 16:59:19 +01:00
b3181455d5 Add fuse-conv-bn method for Conv2d (#1196)
* Add fuse-conv-bn method for Conv2d

* no unwrap

* run rustfmt and clippy (a sketch of the fusion math follows this commit list)
2023-10-27 15:56:50 +01:00
e2826e70b3 Add a quantized variant of llama2.c (#1197)
* Add a quantized variant of llama2.c

* Clippy fixes.
2023-10-27 15:34:06 +01:00
916619f70b Minor cleanup (#1194)
* Add some missing backtraces.

* Small cleanup.
2023-10-27 14:08:29 +01:00
9b1158b315 Add some missing backtraces. (#1193) 2023-10-27 06:09:11 +01:00
70d06ab4b0 Add support for the phi-hermes finetuned model. (#1192) 2023-10-27 05:57:08 +01:00
0ec5ebcec4 Use the hub model file when possible. (#1190)
* Use the hub model file when possible.

* And add a mention in the main readme.
2023-10-26 20:00:50 +01:00
c8e197f68c Fixes for jina-bert. (#1189) 2023-10-26 18:52:30 +01:00
5f20697918 Add the jina-bert embeddings model. (#1187)
* Add the jina-bert model.

* Use alibi.

* Remove the unused pragma.

* Recompute the alibi embeddings.

* Generate the token type ids.

* Use the module trait.

* Add the jina-bert example.

* DType fix.

* Get the inference to work.
2023-10-26 16:54:36 +01:00
e37b487767 Add Blip to online demos README.md (#1184)
* Add Blip to online demos README.md

* Punctuation.

---------

Co-authored-by: Laurent Mazare <laurent.mazare@gmail.com>
2023-10-26 11:07:01 +01:00
e5dc8cb4f4 [Wasm] BLIP Example (#1183)
* blip wasm start

* fix dependency issue, move token stream here

* vanilla js worker

* roll back vscode

* spell
2023-10-26 07:24:02 +01:00
e7b886d56f Add a link to the optimisers crate. (#1180) 2023-10-25 21:51:45 +01:00
6a446d9d73 convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor

* separate tests for convert pytorch tensor
2023-10-25 19:39:14 +01:00
0acd16751d Expose the fields from batch-norm. (#1176) 2023-10-25 15:35:32 +01:00
c698e17619 Enable the test for meshgrid + fix the implementation. (#1175) 2023-10-25 13:47:54 +01:00
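
As referenced in the fuse-conv-bn commit above (#1196), the standard trick is to fold the batch-norm statistics into the convolution weights so that a single conv reproduces conv-then-BN at inference time. A sketch of the usual derivation, with $\gamma, \beta$ the learned BN scale and shift, $\mu, \sigma^2$ the running mean and variance, $\varepsilon$ the stabilizer, and $b$ the conv bias (the exact candle-nn API may differ):

$$
W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
b' = \beta + (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}
$$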
87 changed files with 5837 additions and 371 deletions

.github/workflows/maturin.yml vendored Normal file

View File

@ -7,13 +7,7 @@ members = [
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-examples/*",
"candle-wasm-tests",
]
exclude = ["candle-flash-attn", "candle-kernels"]

View File

@ -56,6 +56,7 @@ These online demos run entirely in your browser:
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
We also provide some command line based examples using state-of-the-art models:
@ -95,12 +96,15 @@ We also provide a some command line based examples using state of the art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model, generates the translated text from the input text.
Run them using commands like:
```
cargo run --example quantized --release
```
@ -135,6 +139,8 @@ And then head over to
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.
@ -170,6 +176,8 @@ If you have an addition to this list, please submit a pull request.
- Wurstchen v2.
- Image to text.
- BLIP.
- Text to text.
- Marian MT (Machine Translation).
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.

View File

@ -238,6 +238,13 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
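// For some stride/padding combinations the kernel gradient computed above
// ends up spatially larger than the kernel itself; narrow it back down to
// the kernel's dimensions (the shape mismatch reported in issue #1212).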
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {

View File

@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let src = match src_l.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" })?,
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
};
let dim = self.dim;
let n_ids = match self.ids_l.dims() {
@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" })?,
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
};
for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" })?,
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
}
}

View File

@ -185,8 +185,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_uniform(shape, DType::F32, lo, up)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}
}
@ -213,8 +219,14 @@ impl Device {
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
// TODO: Remove the special case if we start supporting generating f16/bf16 directly.
if dtype == DType::F16 || dtype == DType::BF16 {
let storage = device.rand_normal(shape, DType::F32, mean, std)?;
Storage::Cuda(storage).to_dtype(&crate::Layout::contiguous(shape), dtype)
} else {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}
}
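
With this fallback in place, requesting an f16/bf16 random tensor on a CUDA device generates the values in f32 and then casts. A minimal sketch, assuming the `cuda` feature is enabled and the `half` crate is in scope:

```rust
use candle::{DType, Device, Tensor};
use half::f16;

fn main() -> candle::Result<()> {
    let device = Device::cuda_if_available(0)?;
    // On CUDA this goes through f32 generation followed by a dtype cast.
    let t = Tensor::rand(f16::from_f32(0.0), f16::from_f32(1.0), (2, 3), &device)?;
    assert_eq!(t.dtype(), DType::F16);
    Ok(())
}
```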

View File

@ -125,3 +125,15 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
self(xs)
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}
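
A minimal sketch of a layer that branches on the `train` flag; the `TrainScale` type below is hypothetical, and plain `Module`s keep working through the blanket impl above:

```rust
use candle::{ModuleT, Result, Tensor};

// Hypothetical layer: scales activations during training, identity at eval.
struct TrainScale {
    scale: f64,
}

impl ModuleT for TrainScale {
    fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
        if train {
            xs * self.scale
        } else {
            Ok(xs.clone())
        }
    }
}
```

Call sites can then use `xs.apply_t(&layer, train)?` via the `apply_t` helper added to `Tensor` further down.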

View File

@ -536,7 +536,6 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@ -666,6 +665,40 @@ impl UnaryOpT for Erf {
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";
@ -887,6 +920,10 @@ impl BackpropOp {
};
Self(op)
}
pub(crate) fn is_none(&self) -> bool {
self.0.is_none()
}
}
impl std::ops::Deref for BackpropOp {

View File

@ -50,14 +50,9 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {

View File

@ -236,14 +236,9 @@ impl GgmlType for BlockQ4_0 {
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {

View File

@ -19,42 +19,29 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
for i in 0..nb {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
let v0_1 = vld1q_u8(x1.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
let v0_1ls = vsubq_s8(v0_1l, s8b);
let v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let v1_1l = vld1q_s8(y1.qs.as_ptr());
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@ -62,28 +49,16 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
Ok(vaddvq_f32(sumv0))
}
}
@ -94,28 +69,18 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
for i in 0..nb {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
let x1_0 = vld1q_s8(x1.qs.as_ptr());
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let y1_0 = vld1q_s8(y1.qs.as_ptr());
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Use dotprod once the intrinsics are available.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@ -123,28 +88,16 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(p2, p3)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
Ok(vaddvq_f32(sumv0))
}
}

View File

@ -11,10 +11,6 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
@ -61,10 +57,6 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {

View File

@ -203,7 +203,7 @@ impl Shape {
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
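
Making this method public lets callers compute the broadcasted shape directly. A small sketch of the numpy-style semantics (dimensions align from the right and size-1 dimensions stretch):

```rust
use candle::Shape;

fn main() -> candle::Result<()> {
    let lhs = Shape::from((3, 1, 5));
    let rhs = Shape::from((4, 5));
    // (3, 1, 5) broadcast against (4, 5) yields (3, 4, 5).
    let out = lhs.broadcast_shape_binary_op(&rhs, "add")?;
    assert_eq!(out.dims(), [3, 4, 5]);
    Ok(())
}
```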

View File

@ -385,11 +385,21 @@ impl Tensor {
step: D,
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
crate::bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
while current < end {
data.push(current);
current += step;
if step >= D::zero() {
while current < end {
data.push(current);
current += step;
}
} else {
while current > end {
data.push(current);
current += step;
}
}
let len = data.len();
Self::from_vec_impl(data, len, device, false)
@ -1186,14 +1196,16 @@ impl Tensor {
op: "scatter-add (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
let storage = self.storage().scatter_add(
self.layout(),
@ -1265,7 +1277,8 @@ impl Tensor {
op: "slice-scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
})?
}
.bt())?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
@ -1299,7 +1312,8 @@ impl Tensor {
op: "index-add (self, source)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
// The number of element in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from
@ -1310,7 +1324,8 @@ impl Tensor {
op: "index-add (ids, source))",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
})?
}
.bt())?
}
let storage = self.storage().index_add(
self.layout(),
@ -1358,7 +1373,8 @@ impl Tensor {
op: "gather",
lhs: self.shape().clone(),
rhs: indexes.shape().clone(),
})?
}
.bt())?
}
let storage =
self.storage()
@ -1791,17 +1807,23 @@ impl Tensor {
/// Returns a new tensor detached from the current graph, gradient are not propagated through
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: self.storage.clone(),
layout: self.layout.clone(),
op: BackpropOp::none(),
is_variable: false,
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
}
}
/// If the target device is the same as the tensor device, only a shallow copy is performed.
@ -2265,6 +2287,11 @@ impl Tensor {
m.forward(self)
}
/// Run the `forward` method of `m` on `self`.
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
m.forward_t(self, train)
}
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}
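
A small sketch of the lazy behavior introduced above: detaching an already-detached tensor now returns the tensor itself instead of allocating a new node.

```rust
use candle::{Device, Var};

fn main() -> candle::Result<()> {
    let x = Var::new(&[1f32, 2., 3.], &Device::Cpu)?;
    let y = x.as_tensor().sqr()?; // carries a backprop op
    let y_det = y.detach()?;      // new node, storage shared, op cleared
    // y_det has no op and is not a variable, so this second detach is a
    // no-op that returns the same tensor.
    let y_det2 = y_det.detach()?;
    assert_eq!(y_det2.to_vec1::<f32>()?, vec![1., 4., 9.]);
    Ok(())
}
```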

View File

@ -479,6 +479,71 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
]
);
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 7.87, 0.0, 0.0],
[-1.8, -7.82, 5.9, 0.0, 0.0],
[-3.12, 4.49, 5.52, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[21.73, 3.39, 4.77, 0.0, 0.0],
[8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-8.98, 9.91, -7.15, 0.0, 0.0],
[4.93, -0.33, 4.56, 0.0, 0.0],
[-6.7, -5.76, -8.05, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[23.54, 6.98, -10.0, 0.0, 0.0],
[9.65, 6.18, 18.72, 0.0, 0.0],
[3.29, -5.27, 0.79, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[-3.47, 7.44, 0.66],
[12.89, -3.4, -9.29],
[-14.16, -0.83, 7.14]
],
[
[-3.23, 5.37, -3.02],
[-2.12, -11.24, 1.94],
[6.97, 7.2, 2.99]
],
[
[-4.04, -3.31, 4.87],
[-6.68, -5.68, 1.73],
[-5.54, 4.32, 0.52]
],
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
]
);
Ok(())
}

View File

@ -29,7 +29,26 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}
@ -1037,6 +1056,7 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@ -1089,3 +1109,11 @@ fn pad_with_same() -> Result<()> {
);
Ok(())
}
#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}

View File

@ -21,7 +21,7 @@ half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }

View File

@ -149,6 +149,6 @@ pub fn main() -> anyhow::Result<()> {
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -0,0 +1,45 @@
# candle-jina-bert
Jina-Bert is a general large language model with a context size of 8192, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"
> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
> [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
> [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
> ...
> [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
> [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
> [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```
## Similarities
In this example, Jina-Bert is used to compute the sentence embeddings for a set of
sentences (hardcoded in the example). Cosine similarities are then computed for
each sentence pair and reported in decreasing order, so the first reported pair
contains the two sentences with the highest similarity score.
The sentence embeddings are computed by average pooling over all the sentence
tokens, including any padding.
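
The reported score for a pair of pooled embeddings $u, v$ is the standard cosine similarity, matching the computation in the example's `main.rs`:

$$
\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}
$$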
```bash
cargo run --example jina-bert --release
> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```

View File

@ -0,0 +1,180 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::jina_bert::{BertModel, Config};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
model: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model = match &self.model {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
RepoType::Model,
))
.get("model.safetensors")?,
};
let tokenizer = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
.repo(Repo::new(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
RepoType::Model,
))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
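/// L2-normalizes each embedding row (v / ||v||) so that the cosine similarity
/// of two normalized rows reduces to their dot product.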
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

View File

@ -6,9 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
mod model;
use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training;
mod weights;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
@ -19,6 +20,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
#[derive(Parser, Debug, Clone)]
@ -152,6 +154,20 @@ fn main() -> anyhow::Result<()> {
Ok(())
}
enum Model {
Llama(Llama),
QLlama(QLlama),
}
impl Model {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
}
}
}
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
use std::io::BufRead;
@ -241,24 +257,66 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (vb, config) = if is_safetensors {
let config = Config::tiny();
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
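// The embedding dimension recorded in the gguf checkpoint identifies which
// llama2.c configuration the weights were trained with.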
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
.dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
.dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
("freq_cis_real".to_string(), freq_cis_real),
("freq_cis_imag".to_string(), freq_cis_imag),
]
.into_iter()
.collect(),
candle::DType::F32,
&candle::Device::Cpu,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny_15m();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
(vb, config)
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
(vb, config)
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
};
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@ -273,7 +331,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let start_gen = std::time::Instant::now();
for index in 0.. {
if tokens.len() >= model.config.seq_len {
if tokens.len() >= config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };

View File

@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny();
let config = Config::tiny_15m();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

View File

@ -0,0 +1,38 @@
# candle-marian-mt
`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.
## Running an example
```bash
cargo run --example marian-mt --release -- \
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```
```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```
## Generating the tokenizer.json files
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. It requires the `tokenizers` and `sentencepiece`
packages to be installed and uses the `convert_slow_tokenizer.py` script from
this directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```

File diff suppressed because it is too large

View File

@ -0,0 +1,152 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Big,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
tokenizer_dec: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "big")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
/// Text to be translated
#[arg(long)]
text: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let tokenizer_dec = {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
let mut model = marian::MTModel::new(&config, vb)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let encoder_xs = {
let mut tokens = tokenizer
.encode(args.text, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.push(config.eos_token_id);
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
model.encoder().forward(&tokens, 0)?
};
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
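// Thanks to the decoder KV cache, only the most recent token has to be fed
// after the first step; earlier positions are already cached.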
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
}
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@ -95,7 +95,7 @@ impl ConvNet {
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?;
self.dropout.forward(&xs, train)?.apply(&self.fc2)
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
}
}

View File

@ -124,6 +124,7 @@ enum WhichModel {
#[value(name = "1.5")]
V1_5,
PuffinPhiV2,
PhiHermes,
}
#[derive(Parser, Debug)]
@ -224,7 +225,9 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::PuffinPhiV2 => "lmz/candle-quantized-phi".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
}
}
}
@ -238,7 +241,7 @@ fn main() -> Result<()> {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::PuffinPhiV2 => "main".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
}
}
}
@ -248,7 +251,9 @@ fn main() -> Result<()> {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 => repo.get("tokenizer-puffin-phi-v2.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
},
};
let filename = match args.weight_file {
@ -259,11 +264,13 @@ fn main() -> Result<()> {
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
}
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
}
}
}
@ -276,6 +283,7 @@ fn main() -> Result<()> {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;

View File

@ -0,0 +1,451 @@
use std::collections::VecDeque;
use std::fmt::Display;
use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
pub struct OuNoise {
mu: f64,
theta: f64,
sigma: f64,
state: Tensor,
}
impl OuNoise {
pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
Ok(Self {
mu,
theta,
sigma,
state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
})
}
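// Ornstein-Uhlenbeck update: dx = theta * (mu - x) + sigma * N(0, 1).
// The state mean-reverts towards `mu` while adding temporally correlated
// exploration noise.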
pub fn sample(&mut self) -> Result<Tensor> {
let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
self.state = (&self.state + dx)?;
Ok(self.state.clone())
}
}
#[derive(Clone)]
struct Transition {
state: Tensor,
action: Tensor,
reward: Tensor,
next_state: Tensor,
terminated: bool,
truncated: bool,
}
impl Transition {
fn new(
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) -> Self {
Self {
state: state.clone(),
action: action.clone(),
reward: reward.clone(),
next_state: next_state.clone(),
terminated,
truncated,
}
}
}
pub struct ReplayBuffer {
buffer: VecDeque<Transition>,
capacity: usize,
size: usize,
}
impl ReplayBuffer {
pub fn new(capacity: usize) -> Self {
Self {
buffer: VecDeque::with_capacity(capacity),
capacity,
size: 0,
}
}
pub fn push(
&mut self,
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) {
if self.size == self.capacity {
self.buffer.pop_front();
} else {
self.size += 1;
}
self.buffer.push_back(Transition::new(
state, action, reward, next_state, terminated, truncated,
));
}
#[allow(clippy::type_complexity)]
pub fn random_batch(
&self,
batch_size: usize,
) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
if self.size < batch_size {
Ok(None)
} else {
let transitions: Vec<&Transition> = thread_rng()
.sample_iter(Uniform::from(0..self.size))
.take(batch_size)
.map(|i| self.buffer.get(i).unwrap())
.collect();
let states: Vec<Tensor> = transitions
.iter()
.map(|t| t.state.unsqueeze(0))
.collect::<Result<_>>()?;
let actions: Vec<Tensor> = transitions
.iter()
.map(|t| t.action.unsqueeze(0))
.collect::<Result<_>>()?;
let rewards: Vec<Tensor> = transitions
.iter()
.map(|t| t.reward.unsqueeze(0))
.collect::<Result<_>>()?;
let next_states: Vec<Tensor> = transitions
.iter()
.map(|t| t.next_state.unsqueeze(0))
.collect::<Result<_>>()?;
let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();
Ok(Some((
Tensor::cat(&states, 0)?,
Tensor::cat(&actions, 0)?,
Tensor::cat(&rewards, 0)?,
Tensor::cat(&next_states, 0)?,
terminateds,
truncateds,
)))
}
}
}
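// Soft ("Polyak") update of the target parameters:
// target <- tau * network + (1 - tau) * target.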
fn track(
varmap: &mut VarMap,
vb: &VarBuilder,
target_prefix: &str,
network_prefix: &str,
dims: &[(usize, usize)],
tau: f64,
) -> Result<()> {
for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
varmap.set_one(
format!("{target_prefix}-fc{i}.weight"),
((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
)?;
let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
varmap.set_one(
format!("{target_prefix}-fc{i}.bias"),
((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
)?;
}
Ok(())
}
struct Actor<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
network: Sequential,
target_network: Sequential,
size_state: usize,
size_action: usize,
dims: Vec<(usize, usize)>,
}
impl Actor<'_> {
fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
let mut varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, device);
let dims = vec![(size_state, 400), (400, 300), (300, size_action)];
let make_network = |prefix: &str| {
let seq = seq()
.add(linear(
dims[0].0,
dims[0].1,
vb.pp(format!("{prefix}-fc0")),
)?)
.add(Activation::Relu)
.add(linear(
dims[1].0,
dims[1].1,
vb.pp(format!("{prefix}-fc1")),
)?)
.add(Activation::Relu)
.add(linear(
dims[2].0,
dims[2].1,
vb.pp(format!("{prefix}-fc2")),
)?)
.add(func(|xs| xs.tanh()));
Ok::<Sequential, Error>(seq)
};
let network = make_network("actor")?;
let target_network = make_network("target-actor")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
Ok(Self {
varmap,
vb,
network,
target_network,
size_state,
size_action,
dims,
})
}
fn forward(&self, state: &Tensor) -> Result<Tensor> {
self.network.forward(state)
}
fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
self.target_network.forward(state)
}
fn track(&mut self, tau: f64) -> Result<()> {
track(
&mut self.varmap,
&self.vb,
"target-actor",
"actor",
&self.dims,
tau,
)
}
}
struct Critic<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
network: Sequential,
target_network: Sequential,
size_state: usize,
size_action: usize,
dims: Vec<(usize, usize)>,
}
impl Critic<'_> {
fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
let mut varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, device);
let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];
let make_network = |prefix: &str| {
let seq = seq()
.add(linear(
dims[0].0,
dims[0].1,
vb.pp(format!("{prefix}-fc0")),
)?)
.add(Activation::Relu)
.add(linear(
dims[1].0,
dims[1].1,
vb.pp(format!("{prefix}-fc1")),
)?)
.add(Activation::Relu)
.add(linear(
dims[2].0,
dims[2].1,
vb.pp(format!("{prefix}-fc2")),
)?);
Ok::<Sequential, Error>(seq)
};
let network = make_network("critic")?;
let target_network = make_network("target-critic")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
Ok(Self {
varmap,
vb,
network,
target_network,
size_state,
size_action,
dims,
})
}
fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
let xs = Tensor::cat(&[action, state], 1)?;
self.network.forward(&xs)
}
fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
let xs = Tensor::cat(&[action, state], 1)?;
self.target_network.forward(&xs)
}
fn track(&mut self, tau: f64) -> Result<()> {
track(
&mut self.varmap,
&self.vb,
"target-critic",
"critic",
&self.dims,
tau,
)
}
}
#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
actor: Actor<'a>,
actor_optim: AdamW,
critic: Critic<'a>,
critic_optim: AdamW,
gamma: f64,
tau: f64,
replay_buffer: ReplayBuffer,
ou_noise: OuNoise,
size_state: usize,
size_action: usize,
pub train: bool,
}
impl DDPG<'_> {
#[allow(clippy::too_many_arguments)]
pub fn new(
device: &Device,
size_state: usize,
size_action: usize,
train: bool,
actor_lr: f64,
critic_lr: f64,
gamma: f64,
tau: f64,
buffer_capacity: usize,
ou_noise: OuNoise,
) -> Result<Self> {
let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
varmap
.data()
.lock()
.unwrap()
.iter()
.filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
.collect::<Vec<Var>>()
};
let actor = Actor::new(device, DType::F32, size_state, size_action)?;
let actor_optim = AdamW::new(
filter_by_prefix(&actor.varmap, "actor"),
ParamsAdamW {
lr: actor_lr,
..Default::default()
},
)?;
let critic = Critic::new(device, DType::F32, size_state, size_action)?;
let critic_optim = AdamW::new(
filter_by_prefix(&critic.varmap, "critic"),
ParamsAdamW {
lr: critic_lr,
..Default::default()
},
)?;
Ok(Self {
actor,
actor_optim,
critic,
critic_optim,
gamma,
tau,
replay_buffer: ReplayBuffer::new(buffer_capacity),
ou_noise,
size_state,
size_action,
train,
})
}
pub fn remember(
&mut self,
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) {
self.replay_buffer
.push(state, action, reward, next_state, terminated, truncated)
}
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
.forward(&state.detach()?.unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?
} else {
actions
};
actions.squeeze(0)?.to_scalar::<f32>()
}
pub fn train(&mut self, batch_size: usize) -> Result<()> {
let (states, actions, rewards, next_states, _, _) =
match self.replay_buffer.random_batch(batch_size)? {
Some(v) => v,
_ => return Ok(()),
};
let q_target = self
.critic
.target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
let q_target = (rewards + (self.gamma * q_target)?.detach())?;
let q = self.critic.forward(&states, &actions)?;
let diff = (q_target - q)?;
let critic_loss = diff.sqr()?.mean_all()?;
self.critic_optim.backward_step(&critic_loss)?;
let actor_loss = self
.critic
.forward(&states, &self.actor.forward(&states)?)?
.mean_all()?
.neg()?;
self.actor_optim.backward_step(&actor_loss)?;
self.critic.track(self.tau)?;
self.actor.track(self.tau)?;
Ok(())
}
}

View File

@ -7,20 +7,22 @@ use pyo3::types::PyDict;
/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
pub obs: Tensor,
pub state: Tensor,
pub action: A,
pub reward: f64,
pub is_done: bool,
pub terminated: bool,
pub truncated: bool,
}
impl<A: Copy> Step<A> {
/// Returns a copy of this step changing the observation tensor.
pub fn copy_with_obs(&self, obs: &Tensor) -> Step<A> {
pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
Step {
obs: obs.clone(),
state: state.clone(),
action: self.action,
reward: self.reward,
is_done: self.is_done,
terminated: self.terminated,
truncated: self.truncated,
}
}
}
@ -63,14 +65,14 @@ impl GymEnv {
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let obs: Vec<f32> = Python::with_gil(|py| {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("seed", seed)?;
let obs = self.env.call_method(py, "reset", (), Some(kwargs))?;
obs.as_ref(py).get_item(0)?.extract()
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(obs, &Device::Cpu)
Tensor::new(state, &Device::Cpu)
}
/// Applies an environment step using the specified action.
@ -78,21 +80,23 @@ impl GymEnv {
&self,
action: A,
) -> Result<Step<A>> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let obs: Vec<f32> = step.get_item(0)?.extract()?;
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let is_done: bool = step.get_item(2)?.extract()?;
Ok((obs, reward, is_done))
let terminated: bool = step.get_item(2)?.extract()?;
let truncated: bool = step.get_item(3)?.extract()?;
Ok((state, reward, terminated, truncated))
})
.map_err(w)?;
let obs = Tensor::new(obs, &Device::Cpu)?;
let state = Tensor::new(state, &Device::Cpu)?;
Ok(Step {
obs,
reward,
is_done,
state,
action,
reward,
terminated,
truncated,
})
}

View File

@ -9,14 +9,34 @@ extern crate accelerate_src;
mod gym_env;
mod vec_gym_env;
use candle::Result;
mod ddpg;
use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;
// The discount factor: how much the next state's q value contributes to the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;
// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;
const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;
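
The MU/THETA/SIGMA constants above parameterize an Ornstein-Uhlenbeck exploration process. A minimal, self-contained sketch of one discretized OU step (a scalar toy using the `rand`/`rand_distr` crates, not the actual `OuNoise` in `ddpg.rs`):

```rust
use rand::Rng;
use rand_distr::{Distribution, Normal};

struct OuNoise {
    mu: f64,
    theta: f64,
    sigma: f64,
    state: f64,
}

impl OuNoise {
    fn new(mu: f64, theta: f64, sigma: f64) -> Self {
        Self { mu, theta, sigma, state: mu }
    }

    // One Euler-discretized OU step: x += theta * (mu - x) + sigma * N(0, 1).
    fn sample(&mut self, rng: &mut impl Rng) -> f64 {
        let normal = Normal::new(0.0, 1.0).unwrap();
        self.state += self.theta * (self.mu - self.state) + self.sigma * normal.sample(rng);
        self.state
    }
}

fn main() {
    let mut rng = rand::thread_rng();
    // Same parameters as the constants above.
    let mut noise = OuNoise::new(0.0, 0.15, 0.1);
    for _ in 0..5 {
        println!("{:.4}", noise.sample(&mut rng));
    }
}
```

The mean-reverting drift keeps successive exploration samples temporally correlated, which suits continuous-control tasks better than independent Gaussian noise.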
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
@ -48,28 +68,77 @@ fn main() -> Result<()> {
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let _num_obs = env.observation_space().iter().product::<usize>();
let _num_actions = env.action_space();
let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();
let mut agent = ddpg::DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
let mut obs = env.reset(episode as u64)?;
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let actions = rng.gen_range(-2.0..2.0);
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![actions])?;
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.is_done {
agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);
if step.terminated || step.truncated {
break;
}
obs = step.obs;
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}
println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}

View File

@ -0,0 +1,13 @@
## VGG Model Implementation
This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library.
The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image.
You can run the example with the following command:
```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).

View File

@ -0,0 +1,77 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{ModuleT, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Vgg13,
Vgg16,
Vgg19,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Variant of the model to use.
#[arg(value_enum, long, default_value_t = Which::Vgg13)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let api = hf_hub::api::sync::Api::new()?;
let repo = match args.which {
Which::Vgg13 => "timm/vgg13.tv_in1k",
Which::Vgg16 => "timm/vgg16.tv_in1k",
Which::Vgg19 => "timm/vgg19.tv_in1k",
};
let api = api.model(repo.into());
let filename = "model.safetensors";
let model_file = api.get(filename)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = match args.which {
Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?,
Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
};
let logits = model.forward_t(&image, /*train=*/ false)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -1,7 +1,5 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Multiples {
@ -76,7 +74,6 @@ impl Module for Upsample {
#[derive(Debug)]
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
span: tracing::Span,
}
@ -96,11 +93,10 @@ impl ConvBlock {
groups: 1,
dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
Ok(Self {
conv,
bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
}
@ -110,7 +106,6 @@ impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
}
}

View File

@ -9,8 +9,11 @@ pub enum Activation {
#[serde(rename = "gated-gelu")]
NewGelu,
Relu,
Relu2,
Relu6,
Silu,
Sigmoid,
Swish,
Elu(f64),
LeakyRelu(f64),
}
@ -22,8 +25,11 @@ impl super::Module for Activation {
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
Self::NewGelu => xs.gelu(),
Self::Relu => xs.relu(),
Self::Relu2 => xs.relu()?.sqr(),
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
Self::Swish => xs * crate::ops::sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}

View File

@ -100,9 +100,23 @@ impl BatchNorm {
num_features,
})
}
}
impl BatchNorm {
pub fn running_mean(&self) -> &Tensor {
&self.running_mean
}
pub fn running_var(&self) -> &Tensor {
&self.running_var
}
pub fn eps(&self) -> f64 {
self.eps
}
pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
}
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {

View File

@ -1,4 +1,5 @@
//! Convolution Layers.
use crate::BatchNorm;
use candle::{Result, Tensor};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -115,6 +116,26 @@ impl Conv2d {
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
let weight = self
.weight()
.broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
let bias = match &self.bias {
None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
};
Ok(Self {
weight,
bias: Some(bias),
config: self.config,
})
} else {
candle::bail!("batch norm does not have weight_and_bias")
}
}
}
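
`absorb_bn` folds a frozen batch norm into the preceding convolution so inference needs a single conv. With γ/β the batch-norm weight and bias, μ/σ² the running statistics, and W/b the conv parameters, the rewrite above computes:

```latex
s = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
W' = s \cdot W \;\;\text{(scaled per output channel)}, \qquad
b' = \beta + s \cdot (b - \mu)
```

When the conv has no bias (the `None` branch), this reduces to b' = β − s·μ.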
impl crate::Module for Conv2d {

View File

@ -36,3 +36,38 @@ impl<'a> Func<'a> {
Self { f: Arc::new(f) }
}
}
/// A layer defined by a simple closure.
#[derive(Clone)]
pub struct FuncT<'a> {
#[allow(clippy::type_complexity)]
f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
}
impl<'a> std::fmt::Debug for FuncT<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "func")
}
}
pub fn func_t<'a, F>(f: F) -> FuncT<'a>
where
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
{
FuncT { f: Arc::new(f) }
}
impl<'a> super::ModuleT for FuncT<'a> {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
(*self.f)(xs, train)
}
}
impl<'a> FuncT<'a> {
pub fn new<F>(f: F) -> Self
where
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
{
Self { f: Arc::new(f) }
}
}
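
A quick sketch of how `func_t` can wire up a train-aware layer inline (assuming `candle_nn::ops::dropout` is available, as elsewhere in this crate):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::{func_t, ops, ModuleT};

fn main() -> Result<()> {
    // A closure-defined layer that only applies dropout in training mode.
    let layer = func_t(|xs: &Tensor, train: bool| {
        if train {
            ops::dropout(xs, 0.5)
        } else {
            Ok(xs.clone())
        }
    });
    let xs = Tensor::ones((2, 4), DType::F32, &Device::Cpu)?;
    let train_out = layer.forward_t(&xs, true)?; // some activations zeroed
    let eval_out = layer.forward_t(&xs, false)?; // identity
    println!("train:\n{train_out}\neval:\n{eval_out}");
    Ok(())
}
```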

View File

@ -22,7 +22,7 @@ pub use conv::{
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, Func};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
@ -34,4 +34,4 @@ pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
pub use candle::Module;
pub use candle::{Module, ModuleT};

View File

@ -84,6 +84,12 @@ impl Dropout {
}
}
impl candle::ModuleT for Dropout {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
self.forward(xs, train)
}
}
struct SoftmaxLastDim;
impl candle::CustomOp1 for SoftmaxLastDim {

View File

@ -19,10 +19,10 @@ candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
[build-dependencies]
pyo3-build-config = "0.19"
pyo3-build-config = "0.20"
[features]
default = []

View File

@ -53,3 +53,39 @@ class Tensor:
Return a slice of a tensor.
"""
pass
def __eq__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ne__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __lt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __le__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __gt__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ge__(self, rhs: Union["Tensor", "Scalar"]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
class bf16(DType):
pass
@ -26,21 +26,21 @@ class i64(DType):
pass
@staticmethod
def ones(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
def ones(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with ones.
"""
pass
@staticmethod
def rand(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
def rand(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values.
"""
pass
@staticmethod
def randn(shape: Sequence[int], device: Optional[Device] = None) -> Tensor:
def randn(*shape: Shape, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor with random values from a normal distribution.
"""
@ -67,7 +67,7 @@ class u8(DType):
pass
@staticmethod
def zeros(shape: Sequence[int], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
def zeros(*shape: Shape, dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
"""
Creates a new tensor filled with zeros.
"""
@ -124,16 +124,46 @@ class Tensor:
Add a scalar to a tensor or two tensors together.
"""
pass
def __eq__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __ge__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __getitem__(self, index: Union[Index, Tensor, Sequence[Index]]) -> "Tensor":
"""
Return a slice of a tensor.
"""
pass
def __gt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __le__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __lt__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __mul__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Multiply a tensor by a scalar or one tensor by another.
"""
pass
def __ne__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Compare a tensor with a scalar or one tensor with another.
"""
pass
def __radd__(self, rhs: Union[Tensor, Scalar]) -> "Tensor":
"""
Add a scalar to a tensor or two tensors together.
@ -159,6 +189,11 @@ class Tensor:
Divide a tensor by a scalar or one tensor by another.
"""
pass
def abs(self) -> Tensor:
"""
Performs the `abs` operation on the tensor.
"""
pass
def argmax_keepdim(self, dim: int) -> Tensor:
"""
Returns the indices of the maximum value(s) across the selected dimension.
@ -174,7 +209,7 @@ class Tensor:
Adds the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
def broadcast_as(self, shape: Sequence[int]) -> Tensor:
def broadcast_as(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape.
"""
@ -184,7 +219,7 @@ class Tensor:
Divides the two tensors, while broadcasting the right-hand-side tensor to match the shape of the left-hand-side tensor.
"""
pass
def broadcast_left(self, shape: Sequence[int]) -> Tensor:
def broadcast_left(self, *shape: Shape) -> Tensor:
"""
Broadcasts the tensor to the given shape, adding new dimensions on the left.
"""
@ -308,6 +343,12 @@ class Tensor:
ranges from `start` to `start + len`.
"""
pass
@property
def nelement(self) -> int:
"""
Gets the tensor's element count.
"""
pass
def powf(self, p: float) -> Tensor:
"""
Performs the `pow` operation on the tensor with the given exponent.
@ -329,7 +370,7 @@ class Tensor:
Get the `recip` of the tensor.
"""
pass
def reshape(self, shape: Sequence[int]) -> Tensor:
def reshape(self, *shape: Shape) -> Tensor:
"""
Reshapes the tensor to the given shape.
"""
@ -396,6 +437,11 @@ class Tensor:
Convert the tensor to a new dtype.
"""
pass
def to_torch(self) -> torch.Tensor:
"""
Converts candle's tensor to pytorch's tensor
"""
pass
def transpose(self, dim1: int, dim2: int) -> Tensor:
"""
Returns a tensor that is a transposed version of the input, the given dimensions are swapped.

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
@staticmethod

View File

@ -0,0 +1,70 @@
import candle
from candle import Tensor
_UNSIGNED_DTYPES = set([str(candle.u8), str(candle.u32)])
def _assert_tensor_metadata(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
if check_device:
assert actual.device == expected.device, f"Device mismatch: {actual.device} != {expected.device}"
if check_dtype:
assert str(actual.dtype) == str(expected.dtype), f"Dtype mismatch: {actual.dtype} != {expected.dtype}"
if check_layout:
assert actual.shape == expected.shape, f"Shape mismatch: {actual.shape} != {expected.shape}"
if check_stride:
assert actual.stride == expected.stride, f"Stride mismatch: {actual.stride} != {expected.stride}"
def assert_equal(
actual: Tensor,
expected: Tensor,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts that two tensors are exact equals.
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
assert (actual - expected).abs().sum_all().values() == 0, f"Tensors mismatch: {actual} != {expected}"
def assert_almost_equal(
actual: Tensor,
expected: Tensor,
rtol=1e-05,
atol=1e-08,
check_device: bool = True,
check_dtype: bool = True,
check_layout: bool = True,
check_stride: bool = False,
):
"""
Asserts that two tensors are almost equal by performing an element-wise comparison with a tolerance.
Computes: |actual - expected| ≤ atol + rtol × |expected|
"""
_assert_tensor_metadata(actual, expected, check_device, check_dtype, check_layout, check_stride)
# Guard against unsigned underflow when subtracting u8/u32 tensors
if str(actual.dtype) in _UNSIGNED_DTYPES or str(expected.dtype) in _UNSIGNED_DTYPES:
actual = actual.to(candle.i64)
expected = expected.to(candle.i64)
diff = (actual - expected).abs()
threshold = (expected.abs().to_dtype(candle.f32) * rtol + atol).to(expected)
assert (diff <= threshold).sum_all().values() == actual.nelement, "Difference between tensors was too great"

View File

@ -18,3 +18,5 @@ Device = TypeVar("Device", CPU, CUDA)
Scalar = Union[int, float]
Index = Union[int, slice, None, "Ellipsis"]
Shape = Union[int, Sequence[int]]

View File

@ -1,7 +1,7 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor
@staticmethod

View File

@ -1,8 +1,11 @@
#![allow(clippy::redundant_closure_call)]
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::pyclass::CompareOp;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::os::raw::c_long;
use std::sync::Arc;
@ -16,26 +19,13 @@ extern crate accelerate_src;
use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};
mod shape;
use shape::{PyShape, PyShapeWithHole};
pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[derive(Clone, Debug)]
struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let dims: Vec<usize> = pyo3::FromPyObject::extract(ob)?;
Ok(PyShape(dims))
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
/// A `candle` tensor.
@ -145,9 +135,10 @@ macro_rules! pydtype {
}
};
}
pydtype!(i64, |v| v);
pydtype!(u8, |v| v);
pydtype!(u32, |v| v);
pydtype!(i64, |v| v);
pydtype!(f16, f32::from);
pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
@ -211,6 +202,16 @@ enum Indexer {
IndexSelect(Tensor),
}
#[derive(Clone, Debug)]
struct TorchTensor(PyObject);
impl<'source> pyo3::FromPyObject<'source> for TorchTensor {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?;
Ok(TorchTensor(numpy_value))
}
}
#[pymethods]
impl PyTensor {
#[new]
@ -246,6 +247,8 @@ impl PyTensor {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(vs) = data.extract::<Vec<Vec<Vec<f32>>>>(py) {
Tensor::new(vs, &Cpu).map_err(wrap_err)?
} else if let Ok(TorchTensor(numpy)) = data.extract::<TorchTensor>(py) {
return PyTensor::new(py, numpy);
} else {
let ty = data.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
@ -299,6 +302,18 @@ impl PyTensor {
M(py).map(self)
}
/// Converts candle's tensor to pytorch's tensor
/// &RETURNS&: torch.Tensor
fn to_torch(&self, py: Python<'_>) -> PyResult<PyObject> {
let candle_values = self.values(py)?;
let torch_tensor: PyObject = py
.import("torch")?
.getattr("tensor")?
.call1((candle_values,))?
.extract()?;
Ok(torch_tensor)
}
#[getter]
/// Gets the tensor's shape.
/// &RETURNS&: Tuple[int]
@ -306,6 +321,13 @@ impl PyTensor {
PyTuple::new(py, self.0.dims()).to_object(py)
}
#[getter]
/// Gets the tensor's element count.
/// &RETURNS&: int
fn nelement(&self) -> usize {
self.0.elem_count()
}
#[getter]
/// Gets the tensor's strides.
/// &RETURNS&: Tuple[int]
@ -342,6 +364,12 @@ impl PyTensor {
self.__repr__()
}
/// Performs the `abs` operation on the tensor.
/// &RETURNS&: Tensor
fn abs(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.abs().map_err(wrap_err)?))
}
/// Performs the `sin` operation on the tensor.
/// &RETURNS&: Tensor
fn sin(&self) -> PyResult<Self> {
@ -659,26 +687,90 @@ impl PyTensor {
};
Ok(Self(tensor))
}
/// Rich-compare two tensors.
/// &RETURNS&: Tensor
fn __richcmp__(&self, rhs: &PyAny, op: CompareOp) -> PyResult<Self> {
let compare = |lhs: &Tensor, rhs: &Tensor| {
let t = match op {
CompareOp::Eq => lhs.eq(rhs),
CompareOp::Ne => lhs.ne(rhs),
CompareOp::Lt => lhs.lt(rhs),
CompareOp::Le => lhs.le(rhs),
CompareOp::Gt => lhs.gt(rhs),
CompareOp::Ge => lhs.ge(rhs),
};
Ok(PyTensor(t.map_err(wrap_err)?))
};
if let Ok(rhs) = rhs.extract::<PyTensor>() {
if self.0.shape() == rhs.0.shape() {
compare(&self.0, &rhs.0)
} else {
// We broadcast manually here because `candle.cmp` does not support automatic broadcasting
let broadcast_shape = self
.0
.shape()
.broadcast_shape_binary_op(rhs.0.shape(), "cmp")
.map_err(wrap_err)?;
let broadcasted_lhs = self.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
let broadcasted_rhs = rhs.0.broadcast_as(&broadcast_shape).map_err(wrap_err)?;
compare(&broadcasted_lhs, &broadcasted_rhs)
}
} else if let Ok(rhs) = rhs.extract::<f64>() {
let scalar_tensor = Tensor::new(rhs, self.0.device())
.map_err(wrap_err)?
.to_dtype(self.0.dtype())
.map_err(wrap_err)?
.broadcast_as(self.0.shape())
.map_err(wrap_err)?;
compare(&self.0, &scalar_tensor)
} else {
return Err(PyTypeError::new_err("unsupported rhs for __richcmp__"));
}
}
fn __hash__(&self) -> u64 {
// we have overridden __richcmp__ => pyo3 wants us to also override __hash__
// we simply hash the address of the tensor
let mut hasher = DefaultHasher::new();
let pointer = &self.0 as *const Tensor;
let address = pointer as usize;
address.hash(&mut hasher);
hasher.finish()
}
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Reshapes the tensor to the given shape.
/// &RETURNS&: Tensor
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
fn reshape(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(
self.0
.reshape(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape.
/// &RETURNS&: Tensor
fn broadcast_as(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
fn broadcast_as(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(
self.0
.broadcast_as(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}
#[pyo3(text_signature = "(self, shape:Sequence[int])")]
#[pyo3(signature=(*shape), text_signature = "(self, *shape:Shape)")]
/// Broadcasts the tensor to the given shape, adding new dimensions on the left.
/// &RETURNS&: Tensor
fn broadcast_left(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
fn broadcast_left(&self, shape: PyShapeWithHole) -> PyResult<Self> {
Ok(PyTensor(
self.0
.broadcast_left(shape.to_absolute(&self.0)?)
.map_err(wrap_err)?,
))
}
#[pyo3(text_signature = "(self, dim:int)")]
@ -891,21 +983,21 @@ impl PyTensor {
}
if let Some(kwargs) = kwargs {
if let Some(any) = kwargs.get_item("dtype") {
if let Ok(Some(any)) = kwargs.get_item("dtype") {
handle_duplicates(
&mut dtype,
any.extract::<PyDType>(),
"cannot specify multiple dtypes",
)?;
}
if let Some(any) = kwargs.get_item("device") {
if let Ok(Some(any)) = kwargs.get_item("device") {
handle_duplicates(
&mut device,
any.extract::<PyDevice>(),
"cannot specify multiple devices",
)?;
}
if let Some(any) = kwargs.get_item("other") {
if let Ok(Some(any)) = kwargs.get_item("other") {
handle_duplicates(
&mut other,
any.extract::<PyTensor>(),
@ -1025,27 +1117,27 @@ fn tensor(py: Python<'_>, data: PyObject) -> PyResult<PyTensor> {
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values.
/// &RETURNS&: Tensor
fn rand(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::rand(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
let tensor = Tensor::rand(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, device=None), text_signature = "(shape:Sequence[int], device:Optional[Device]=None)")]
#[pyo3(signature = (*shape,device=None), text_signature = "(*shape:Shape, device:Optional[Device]=None)")]
/// Creates a new tensor with random values from a normal distribution.
/// &RETURNS&: Tensor
fn randn(_py: Python<'_>, shape: PyShape, device: Option<PyDevice>) -> PyResult<PyTensor> {
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::randn(0f32, 1f32, shape.0, &device).map_err(wrap_err)?;
let tensor = Tensor::randn(0f32, 1f32, shape, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None),text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
#[pyo3(signature = (*shape, dtype=None, device=None),text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with ones.
/// &RETURNS&: Tensor
fn ones(
@ -1059,12 +1151,12 @@ fn ones(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::ones(shape.0, dtype, &device).map_err(wrap_err)?;
let tensor = Tensor::ones(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
#[pyfunction]
#[pyo3(signature = (shape, *, dtype=None, device=None), text_signature = "(shape:Sequence[int], dtype:Optional[DType]=None, device:Optional[Device]=None)")]
#[pyo3(signature = (*shape, dtype=None, device=None), text_signature = "(*shape:Shape, dtype:Optional[DType]=None, device:Optional[Device]=None)")]
/// Creates a new tensor filled with zeros.
/// &RETURNS&: Tensor
fn zeros(
@ -1078,7 +1170,7 @@ fn zeros(
Some(dtype) => PyDType::from_pyobject(dtype, py)?.0,
};
let device = device.unwrap_or(PyDevice::Cpu).as_device()?;
let tensor = Tensor::zeros(shape.0, dtype, &device).map_err(wrap_err)?;
let tensor = Tensor::zeros(shape, dtype, &device).map_err(wrap_err)?;
Ok(PyTensor(tensor))
}
@ -1480,7 +1572,7 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_class::<PyDType>()?;
m.add("u8", PyDType(DType::U8))?;
m.add("u32", PyDType(DType::U32))?;
m.add("i16", PyDType(DType::I64))?;
m.add("i64", PyDType(DType::I64))?;
m.add("bf16", PyDType(DType::BF16))?;
m.add("f16", PyDType(DType::F16))?;
m.add("f32", PyDType(DType::F32))?;

View File

@ -0,0 +1,99 @@
use ::candle::Tensor;
use pyo3::prelude::*;
#[derive(Clone, Debug)]
/// Represents an absolute shape e.g. (1, 2, 3)
pub struct PyShape(Vec<usize>);
impl<'source> pyo3::FromPyObject<'source> for PyShape {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
let dims: Vec<usize> = pyo3::FromPyObject::extract(first_element)?;
Ok(PyShape(dims))
} else {
let dims: Vec<usize> = pyo3::FromPyObject::extract(tuple)?;
Ok(PyShape(dims))
}
}
}
impl From<PyShape> for ::candle::Shape {
fn from(val: PyShape) -> Self {
val.0.into()
}
}
#[derive(Clone, Debug)]
/// Represents a shape with a hole in it e.g. (1, -1, 3)
pub struct PyShapeWithHole(Vec<isize>);
impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if ob.is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"Shape cannot be None",
));
}
let tuple = ob.downcast::<pyo3::types::PyTuple>()?;
let dims: Vec<isize> = if tuple.len() == 1 {
let first_element = tuple.get_item(0)?;
pyo3::FromPyObject::extract(first_element)?
} else {
pyo3::FromPyObject::extract(tuple)?
};
// Ensure we have only positive numbers and at most one "hole" (-1)
let negative_ones = dims.iter().filter(|&&x| x == -1).count();
let any_invalid_dimensions = dims.iter().any(|&x| x < -1 || x == 0);
if negative_ones > 1 || any_invalid_dimensions {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {:?}",
dims
)));
}
Ok(PyShapeWithHole(dims))
}
}
impl PyShapeWithHole {
/// Returns `true` if the shape is absolute e.g. (1, 2, 3)
pub fn is_absolute(&self) -> bool {
self.0.iter().all(|x| *x > 0)
}
/// Convert a relative shape to an absolute shape e.g. (1, -1) -> (1, 12)
pub fn to_absolute(&self, t: &Tensor) -> PyResult<PyShape> {
if self.is_absolute() {
return Ok(PyShape(
self.0.iter().map(|x| *x as usize).collect::<Vec<usize>>(),
));
}
let mut elements = t.elem_count();
let mut new_dims: Vec<usize> = vec![];
for dim in self.0.iter() {
if *dim > 0 {
new_dims.push(*dim as usize);
elements /= *dim as usize;
} else if *dim == -1 {
new_dims.push(elements);
} else {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Invalid dimension in shape: {}",
dim
)));
}
}
Ok(PyShape(new_dims))
}
}
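
To make the arithmetic concrete, here is a standalone sketch of resolving a shape with one `-1` hole against a total element count (a simplified two-pass variant — validate, take the product of the known dims, then fill the hole — independent of pyo3 and the `Tensor` type):

```rust
/// Resolve a shape containing at most one -1 "hole" against a total element
/// count, e.g. resolve(&[2, -1, 3], 24) -> Ok(vec![2, 4, 3]).
fn resolve(dims: &[isize], elem_count: usize) -> Result<Vec<usize>, String> {
    let holes = dims.iter().filter(|&&d| d == -1).count();
    if holes > 1 || dims.iter().any(|&d| d < -1 || d == 0) {
        return Err(format!("invalid dimension in shape: {dims:?}"));
    }
    // Product of the explicitly given dimensions.
    let known: usize = dims.iter().filter(|&&d| d > 0).map(|&d| d as usize).product();
    if elem_count % known != 0 {
        return Err(format!("{elem_count} elements do not fit shape {dims:?}"));
    }
    Ok(dims
        .iter()
        .map(|&d| if d == -1 { elem_count / known } else { d as usize })
        .collect())
}

fn main() {
    assert_eq!(resolve(&[1, -1], 200), Ok(vec![1, 200]));
    assert_eq!(resolve(&[-1, 2, 5], 30), Ok(vec![3, 2, 5]));
    assert!(resolve(&[-1, -1], 4).is_err());
    println!("shape-with-hole resolution ok");
}
```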

View File

@ -13,7 +13,7 @@ GENERATED_COMMENT = "# Generated content DO NOT EDIT\n"
TYPING = """from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
"""
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index\n"
CANDLE_SPECIFIC_TYPING = "from candle.typing import _ArrayLike, Device, Scalar, Index, Shape\n"
CANDLE_TENSOR_IMPORTS = "from candle import Tensor,DType,QTensor\n"
RETURN_TYPE_MARKER = "&RETURNS&: "
ADDITIONAL_TYPEHINTS = {}

View File

@ -0,0 +1,14 @@
import candle
import torch
# convert from candle tensor to torch tensor
t = candle.randn((3, 512, 512))
torch_tensor = t.to_torch()
print(torch_tensor)
print(type(torch_tensor))
# convert from torch tensor to candle tensor
t = torch.randn((3, 512, 512))
candle_tensor = candle.Tensor(t)
print(candle_tensor)
print(type(candle_tensor))

View File

@ -0,0 +1,33 @@
import candle
from candle import Tensor
from candle.testing import assert_equal, assert_almost_equal
import pytest
@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_equal(a, b)
with pytest.raises(AssertionError):
assert_equal(a, b + 1)
@pytest.mark.parametrize("dtype", [candle.f32, candle.f64, candle.f16, candle.u32, candle.u8, candle.i64])
def test_assert_almost_equal_asserts_correctly(dtype: candle.DType):
a = Tensor([1, 2, 3]).to(dtype)
b = Tensor([1, 2, 3]).to(dtype)
assert_almost_equal(a, b)
with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1)
assert_almost_equal(a, b + 1, atol=20)
assert_almost_equal(a, b + 1, rtol=20)
with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, atol=0.9)
with pytest.raises(AssertionError):
assert_almost_equal(a, b + 1, rtol=0.1)

View File

@ -0,0 +1,31 @@
from candle import Tensor
from candle import rand
import pytest
def test_absolute_shapes_are_valid():
a = rand((10, 20))
assert a.shape == (10, 20)
b = rand(10, 20)
assert b.shape == (10, 20)
pytest.raises(OverflowError, lambda: rand((10, 20, -1)))
pytest.raises(OverflowError, lambda: rand(-1, 20))
pytest.raises(TypeError, lambda: rand("foo", True))
def test_relative_shapes_are_valid():
a = rand(10, 20)
a = a.reshape((1, -1))
assert a.shape == (1, 200)
b = rand(10, 20)
b = b.reshape(-1, 1)
assert b.shape == (200, 1)
c = rand(10, 20)
pytest.raises(TypeError, lambda: c.reshape(1, "foo"))
pytest.raises(ValueError, lambda: c.reshape(1, -2))
pytest.raises(ValueError, lambda: c.reshape((-2, 1)))
pytest.raises(ValueError, lambda: c.reshape((0, 1)))
pytest.raises(ValueError, lambda: c.reshape((1, -1, -1)))

View File

@ -1,6 +1,7 @@
import candle
from candle import Tensor
from candle.utils import cuda_is_available
from candle.testing import assert_equal
import pytest
@ -77,6 +78,78 @@ def test_tensor_can_be_scliced_3d():
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
def assert_bool(t: Tensor, expected: bool):
assert t.shape == ()
assert str(t.dtype) == str(candle.u8)
assert bool(t.values()) == expected
def test_tensor_supports_equality_opperations_with_scalars():
t = Tensor(42.0)
assert_bool(t == 42.0, True)
assert_bool(t == 43.0, False)
assert_bool(t != 42.0, False)
assert_bool(t != 43.0, True)
assert_bool(t > 41.0, True)
assert_bool(t > 42.0, False)
assert_bool(t >= 41.0, True)
assert_bool(t >= 42.0, True)
assert_bool(t < 43.0, True)
assert_bool(t < 42.0, False)
assert_bool(t <= 43.0, True)
assert_bool(t <= 42.0, True)
def test_tensor_supports_equality_opperations_with_tensors():
t = Tensor(42.0)
same = Tensor(42.0)
other = Tensor(43.0)
assert_bool(t == same, True)
assert_bool(t == other, False)
assert_bool(t != same, False)
assert_bool(t != other, True)
assert_bool(t > same, False)
assert_bool(t > other, False)
assert_bool(t >= same, True)
assert_bool(t >= other, False)
assert_bool(t < same, False)
assert_bool(t < other, True)
assert_bool(t <= same, True)
assert_bool(t <= other, True)
def test_tensor_equality_opperations_can_broadcast():
# Create a decoder attention mask as a test case
# e.g.
# [[1,0,0]
# [1,1,0]
# [1,1,1]]
mask_cond = candle.Tensor([0, 1, 2])
mask = mask_cond < (mask_cond + 1).reshape((3, 1))
assert mask.shape == (3, 3)
assert_equal(mask, Tensor([[1, 0, 0], [1, 1, 0], [1, 1, 1]]).to_dtype(candle.u8))
def test_tensor_can_be_hashed():
t = Tensor(42.0)
other = Tensor(42.0)
# Hash should represent a unique tensor
assert hash(t) != hash(other)
assert hash(t) == hash(t)
def test_tensor_can_be_expanded_with_none():
t = candle.rand((12, 12))

View File

@ -11,6 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" }

View File

@ -1,3 +1,4 @@
use super::with_tracing::{linear, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@ -32,33 +33,6 @@ impl HiddenActLayer {
}
}
#[derive(Debug)]
pub struct Linear {
weight: Tensor,
bias: Option<Tensor>,
span: tracing::Span,
}
impl Linear {
pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
let span = tracing::span!(tracing::Level::TRACE, "linear");
Self { weight, bias, span }
}
pub fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let _enter = self.span.enter();
let w = match x.dims() {
&[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
_ => self.weight.t()?,
};
let x = x.matmul(&w)?;
match &self.bias {
None => Ok(x),
Some(bias) => x.broadcast_add(bias),
}
}
}
#[derive(Debug)]
pub struct LayerNorm {
weight: Tensor,
@ -77,8 +51,10 @@ impl LayerNorm {
span,
}
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
@ -180,12 +156,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
Ok(Embedding::new(embeddings, hidden_size))
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
let bias = vb.get(size2, "bias")?;
Ok(Linear::new(weight, Some(bias)))
}
struct Dropout {
#[allow(dead_code)]
pr: f64,
@ -195,7 +165,9 @@ impl Dropout {
fn new(pr: f64) -> Self {
Self { pr }
}
}
impl Module for Dropout {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
// TODO
Ok(x.clone())
@ -316,7 +288,9 @@ impl BertSelfAttention {
let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?;
xs.contiguous()
}
}
impl Module for BertSelfAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(hidden_states)?;
@ -391,7 +365,9 @@ impl BertAttention {
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
}
impl Module for BertAttention {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(hidden_states)?;
@ -416,7 +392,9 @@ impl BertIntermediate {
span: tracing::span!(tracing::Level::TRACE, "inter"),
})
}
}
impl Module for BertIntermediate {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states = self.dense.forward(hidden_states)?;
@ -478,7 +456,9 @@ impl BertLayer {
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
}
impl Module for BertLayer {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let attention_output = self.attention.forward(hidden_states)?;
@ -507,7 +487,9 @@ impl BertEncoder {
let span = tracing::span!(tracing::Level::TRACE, "encoder");
Ok(BertEncoder { layers, span })
}
}
impl Module for BertEncoder {
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut hidden_states = hidden_states.clone();

View File

@ -2,8 +2,9 @@ use super::blip_text;
use super::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, Conv2dConfig, LayerNorm, VarBuilder};
use serde::Deserialize;
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct VisionConfig {
pub hidden_size: usize,
pub intermediate_size: usize,
@ -16,7 +17,7 @@ pub struct VisionConfig {
pub layer_norm_eps: f64,
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
pub text_config: blip_text::Config,
pub vision_config: VisionConfig,
@ -299,4 +300,8 @@ impl BlipForConditionalGeneration {
pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
&mut self.text_decoder
}
pub fn reset_kv_cache(&mut self) {
self.text_decoder.reset_kv_cache();
}
}

View File

@ -1,8 +1,9 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
use serde::Deserialize;
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,

View File

@ -0,0 +1,369 @@
use super::with_tracing::{linear, linear_no_bias, Embedding, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
Absolute,
Alibi,
}
// https://huggingface.co/jinaai/jina-bert-implementation/blob/main/configuration_bert.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub intermediate_size: usize,
pub hidden_act: candle_nn::Activation,
pub max_position_embeddings: usize,
pub type_vocab_size: usize,
pub initializer_range: f64,
pub layer_norm_eps: f64,
pub pad_token_id: usize,
pub position_embedding_type: PositionEmbeddingType,
}
impl Config {
pub fn v2_base() -> Self {
// https://huggingface.co/jinaai/jina-embeddings-v2-base-en/blob/main/config.json
Self {
vocab_size: 30528,
hidden_size: 768,
num_hidden_layers: 12,
num_attention_heads: 12,
intermediate_size: 3072,
hidden_act: candle_nn::Activation::Gelu,
max_position_embeddings: 8192,
type_vocab_size: 2,
initializer_range: 0.02,
layer_norm_eps: 1e-12,
pad_token_id: 0,
position_embedding_type: PositionEmbeddingType::Alibi,
}
}
}
#[derive(Clone, Debug)]
struct BertEmbeddings {
word_embeddings: Embedding,
// no position_embeddings as we only support alibi.
token_type_embeddings: Embedding,
layer_norm: LayerNorm,
span: tracing::Span,
}
impl BertEmbeddings {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let word_embeddings =
Embedding::new(cfg.vocab_size, cfg.hidden_size, vb.pp("word_embeddings"))?;
let token_type_embeddings = Embedding::new(
cfg.type_vocab_size,
cfg.hidden_size,
vb.pp("token_type_embeddings"),
)?;
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self {
word_embeddings,
token_type_embeddings,
layer_norm,
span: tracing::span!(tracing::Level::TRACE, "embeddings"),
})
}
}
impl Module for BertEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_size, seq_len) = input_ids.dims2()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = Tensor::zeros(seq_len, DType::U32, input_ids.device())?
.broadcast_left(b_size)?
.apply(&self.token_type_embeddings)?;
let embeddings = (&input_embeddings + token_type_embeddings)?;
let embeddings = self.layer_norm.forward(&embeddings)?;
Ok(embeddings)
}
}
#[derive(Clone, Debug)]
struct BertSelfAttention {
query: Linear,
key: Linear,
value: Linear,
num_attention_heads: usize,
attention_head_size: usize,
span: tracing::Span,
span_softmax: tracing::Span,
}
impl BertSelfAttention {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attention_head_size = cfg.hidden_size / cfg.num_attention_heads;
let all_head_size = cfg.num_attention_heads * attention_head_size;
let hidden_size = cfg.hidden_size;
let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
Ok(Self {
query,
key,
value,
num_attention_heads: cfg.num_attention_heads,
attention_head_size,
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
})
}
fn transpose_for_scores(&self, xs: &Tensor) -> Result<Tensor> {
let mut x_shape = xs.dims().to_vec();
x_shape.pop();
x_shape.push(self.num_attention_heads);
x_shape.push(self.attention_head_size);
xs.reshape(x_shape)?.transpose(1, 2)?.contiguous()
}
fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let query_layer = self.query.forward(xs)?;
let key_layer = self.key.forward(xs)?;
let value_layer = self.value.forward(xs)?;
let query_layer = self.transpose_for_scores(&query_layer)?;
let key_layer = self.transpose_for_scores(&key_layer)?;
let value_layer = self.transpose_for_scores(&value_layer)?;
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
let attention_scores = attention_scores.broadcast_add(bias)?;
let attention_probs = {
let _enter_sm = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attention_scores)?
};
let context_layer = attention_probs.matmul(&value_layer)?;
let context_layer = context_layer.transpose(1, 2)?.contiguous()?;
let context_layer = context_layer.flatten_from(D::Minus2)?;
Ok(context_layer)
}
}
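
The forward pass above is standard scaled dot-product attention with an additive bias (here the ALiBi bias computed below, broadcast over the batch):

```latex
\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\text{head}}}} + B\right) V
```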
#[derive(Clone, Debug)]
struct BertSelfOutput {
dense: Linear,
layer_norm: LayerNorm,
span: tracing::Span,
}
impl BertSelfOutput {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let dense = linear(cfg.hidden_size, cfg.hidden_size, vb.pp("dense"))?;
let layer_norm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("LayerNorm"))?;
Ok(Self {
dense,
layer_norm,
span: tracing::span!(tracing::Level::TRACE, "self-out"),
})
}
fn forward(&self, xs: &Tensor, input_tensor: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.dense.forward(xs)?;
self.layer_norm.forward(&(xs + input_tensor)?)
}
}
#[derive(Clone, Debug)]
struct BertAttention {
self_attention: BertSelfAttention,
self_output: BertSelfOutput,
span: tracing::Span,
}
impl BertAttention {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let self_attention = BertSelfAttention::new(vb.pp("self"), cfg)?;
let self_output = BertSelfOutput::new(vb.pp("output"), cfg)?;
Ok(Self {
self_attention,
self_output,
span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let self_outputs = self.self_attention.forward(xs, bias)?;
let attention_output = self.self_output.forward(&self_outputs, xs)?;
Ok(attention_output)
}
}
#[derive(Clone, Debug)]
struct BertGLUMLP {
gated_layers: Linear,
act: candle_nn::Activation,
wo: Linear,
layernorm: LayerNorm,
intermediate_size: usize,
}
impl BertGLUMLP {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let gated_layers = linear_no_bias(
cfg.hidden_size,
cfg.intermediate_size * 2,
vb.pp("gated_layers"),
)?;
let act = candle_nn::Activation::Gelu; // geglu
let wo = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("wo"))?;
let layernorm = layer_norm(cfg.hidden_size, cfg.layer_norm_eps, vb.pp("layernorm"))?;
Ok(Self {
gated_layers,
act,
wo,
layernorm,
intermediate_size: cfg.intermediate_size,
})
}
}
impl Module for BertGLUMLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = xs.apply(&self.gated_layers)?;
let gated = xs.narrow(D::Minus1, 0, self.intermediate_size)?;
let non_gated = xs.narrow(D::Minus1, self.intermediate_size, self.intermediate_size)?;
let xs = (gated.apply(&self.act) * non_gated)?.apply(&self.wo);
(xs + residual)?.apply(&self.layernorm)
}
}
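
`BertGLUMLP` is a GeGLU feed-forward block: a single matmul produces 2·d_ff channels that are split into a gated half and a value half, with a residual connection and layer norm wrapped around the result:

```latex
\mathrm{GeGLU}(x) = \big(\mathrm{GELU}(x W_g) \odot x W_v\big)\, W_o, \qquad
[\,x W_g \;\|\; x W_v\,] = x\, W_{\text{gated\_layers}}
```

so the layer output is LayerNorm(x + GeGLU(x)).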
#[derive(Clone, Debug)]
struct BertLayer {
attention: BertAttention,
mlp: BertGLUMLP,
span: tracing::Span,
}
impl BertLayer {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let attention = BertAttention::new(vb.pp("attention"), cfg)?;
let mlp = BertGLUMLP::new(vb.pp("mlp"), cfg)?;
Ok(Self {
attention,
mlp,
span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
fn forward(&self, xs: &Tensor, bias: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.attention.forward(xs, bias)?.apply(&self.mlp)
}
}
fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
let n_heads = cfg.num_attention_heads;
let seq_len = cfg.max_position_embeddings;
let alibi_bias = Tensor::arange(0, seq_len as i64, &Device::Cpu)?.to_dtype(DType::F32)?;
let alibi_bias = {
let a1 = alibi_bias.reshape((1, seq_len))?;
let a2 = alibi_bias.reshape((seq_len, 1))?;
a1.broadcast_sub(&a2)?.abs()?.broadcast_left(n_heads)?
};
let mut n_heads2 = 1;
while n_heads2 < n_heads {
n_heads2 *= 2
}
let slopes = (1..=n_heads2)
.map(|v| -1f32 / 2f32.powf((v * 8) as f32 / n_heads2 as f32))
.collect::<Vec<_>>();
let slopes = if n_heads2 == n_heads {
slopes
} else {
slopes
.iter()
.skip(1)
.step_by(2)
.chain(slopes.iter().step_by(2))
.take(n_heads)
.cloned()
.collect::<Vec<f32>>()
};
let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
}
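
`build_alibi_bias` precomputes the ALiBi attention bias: a per-head linear penalty on the query/key distance, with head slopes on a geometric schedule:

```latex
B_h(i, j) = -\,m_h\,\lvert i - j \rvert, \qquad m_h = 2^{-8h/n}, \quad h = 1, \dots, n
```

When the head count is not a power of two, the slopes are drawn from the next power of two and interleaved (the `skip(1).step_by(2)` / `step_by(2)` chain above), as in the original ALiBi reference implementation.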
#[derive(Clone, Debug)]
struct BertEncoder {
alibi: Tensor,
layers: Vec<BertLayer>,
span: tracing::Span,
}
impl BertEncoder {
fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
if cfg.position_embedding_type != PositionEmbeddingType::Alibi {
candle::bail!("only alibi is supported as a position-embedding-type")
}
let layers = (0..cfg.num_hidden_layers)
.map(|index| BertLayer::new(vb.pp(&format!("layer.{index}")), cfg))
.collect::<Result<Vec<_>>>()?;
let span = tracing::span!(tracing::Level::TRACE, "encoder");
let alibi = build_alibi_bias(cfg)?.to_device(vb.device())?;
Ok(Self {
alibi,
layers,
span,
})
}
}
impl Module for BertEncoder {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let seq_len = xs.dim(1)?;
let alibi_bias = self.alibi.i((.., .., ..seq_len, ..seq_len))?;
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, &alibi_bias)?
}
Ok(xs)
}
}
#[derive(Clone, Debug)]
pub struct BertModel {
embeddings: BertEmbeddings,
encoder: BertEncoder,
pub device: Device,
span: tracing::Span,
}
impl BertModel {
pub fn new(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let embeddings = BertEmbeddings::new(vb.pp("embeddings"), cfg)?;
let encoder = BertEncoder::new(vb.pp("encoder"), cfg)?;
Ok(Self {
embeddings,
encoder,
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
}
impl Module for BertModel {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let embedding_output = self.embeddings.forward(input_ids)?;
let sequence_output = self.encoder.forward(&embedding_output)?;
Ok(sequence_output)
}
}

View File

@ -1,3 +1,4 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@ -81,21 +82,6 @@ impl Config {
}
}
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@ -150,12 +136,6 @@ impl Cache {
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))

View File

@ -17,7 +17,20 @@ pub struct Config {
}
impl Config {
pub fn tiny() -> Self {
pub fn tiny_260k() -> Self {
Self {
dim: 64,
hidden_dim: 768,
n_layers: 5,
n_heads: 8,
n_kv_heads: 4,
vocab_size: 32000,
seq_len: 512,
norm_eps: 1e-5,
}
}
pub fn tiny_15m() -> Self {
Self {
dim: 288,
hidden_dim: 768,
@ -29,6 +42,32 @@ impl Config {
norm_eps: 1e-5,
}
}
pub fn tiny_42m() -> Self {
Self {
dim: 512,
hidden_dim: 768,
n_layers: 8,
n_heads: 8,
n_kv_heads: 8,
vocab_size: 32000,
seq_len: 1024,
norm_eps: 1e-5,
}
}
pub fn tiny_110m() -> Self {
Self {
dim: 768,
hidden_dim: 768,
n_layers: 12,
n_heads: 12,
n_kv_heads: 12,
vocab_size: 32000,
seq_len: 1024,
norm_eps: 1e-5,
}
}
}
#[derive(Clone)]
@ -36,9 +75,9 @@ pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub cos: Tensor,
pub sin: Tensor,
device: Device,
}
@ -75,7 +114,7 @@ impl Cache {
})
}
fn mask(&self, t: usize) -> Result<Tensor> {
pub fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())

View File

@ -1,9 +1,8 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder;
use crate::model::Config;
use super::llama2_c::Config;
pub struct TransformerWeights {
// token embedding table

View File

@ -0,0 +1,520 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
#[derive(Debug, Clone)]
pub struct Config {
pub vocab_size: usize,
pub decoder_vocab_size: Option<usize>,
pub max_position_embeddings: usize,
pub encoder_layers: usize,
pub encoder_ffn_dim: usize,
pub encoder_attention_heads: usize,
pub decoder_layers: usize,
pub decoder_ffn_dim: usize,
pub decoder_attention_heads: usize,
pub use_cache: bool,
pub is_encoder_decoder: bool,
pub activation_function: candle_nn::Activation,
pub d_model: usize,
pub decoder_start_token_id: u32,
pub scale_embedding: bool,
pub pad_token_id: u32,
pub eos_token_id: u32,
pub forced_eos_token_id: u32,
pub share_encoder_decoder_embeddings: bool,
}
impl Config {
// https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en/blob/main/config.json
pub fn opus_mt_tc_big_fr_en() -> Self {
Self {
activation_function: candle_nn::Activation::Relu,
d_model: 1024,
decoder_attention_heads: 16,
decoder_ffn_dim: 4096,
decoder_layers: 6,
decoder_start_token_id: 53016,
decoder_vocab_size: Some(53017),
encoder_attention_heads: 16,
encoder_ffn_dim: 4096,
encoder_layers: 6,
eos_token_id: 43311,
forced_eos_token_id: 43311,
is_encoder_decoder: true,
max_position_embeddings: 1024,
pad_token_id: 53016,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 53017,
}
}
// https://huggingface.co/Helsinki-NLP/opus-mt-fr-en/blob/main/config.json
pub fn opus_mt_fr_en() -> Self {
Self {
activation_function: candle_nn::Activation::Swish,
d_model: 512,
decoder_attention_heads: 8,
decoder_ffn_dim: 2048,
decoder_layers: 6,
decoder_start_token_id: 59513,
decoder_vocab_size: Some(59514),
encoder_attention_heads: 8,
encoder_ffn_dim: 2048,
encoder_layers: 6,
eos_token_id: 0,
forced_eos_token_id: 0,
is_encoder_decoder: true,
max_position_embeddings: 512,
pad_token_id: 59513,
scale_embedding: true,
share_encoder_decoder_embeddings: true,
use_cache: true,
vocab_size: 59514,
}
}
}
#[derive(Debug, Clone)]
struct SinusoidalPositionalEmbedding {
emb: Embedding,
}
impl SinusoidalPositionalEmbedding {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let dev = vb.device();
let dtype = vb.dtype();
let num_positions = cfg.max_position_embeddings;
let dim = cfg.d_model;
let inv_freq: Vec<_> = (0..dim)
.step_by(2)
.map(|i| 1f32 / 10000f32.powf(i as f32 / dim as f32))
.collect();
let inv_freq_len = inv_freq.len();
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
let t = Tensor::arange(0u32, num_positions as u32, dev)?
.to_dtype(dtype)?
.reshape((num_positions, 1))?;
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?;
let cos = freqs.cos()?;
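// The sin and cos halves are concatenated along the feature dimension
// (first half sin, second half cos) rather than interleaved.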
let weights = Tensor::cat(&[&sin, &cos], 1)?.contiguous()?;
let emb = Embedding::from_weights(weights)?;
Ok(Self { emb })
}
fn forward(&self, input_ids: &Tensor, past_kv_len: usize) -> Result<Tensor> {
let seq_len = input_ids.dim(1)?;
Tensor::arange(
past_kv_len as u32,
(past_kv_len + seq_len) as u32,
input_ids.device(),
)?
.apply(&self.emb)
}
}
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
out_proj: Linear,
scaling: f64,
num_heads: usize,
head_dim: usize,
kv_cache: Option<(Tensor, Tensor)>,
is_decoder: bool,
}
impl Attention {
fn new(cfg: &Config, is_decoder: bool, vb: VarBuilder) -> Result<Self> {
let num_heads = if is_decoder {
cfg.decoder_attention_heads
} else {
cfg.encoder_attention_heads
};
let embed_dim = cfg.d_model;
let head_dim = embed_dim / num_heads;
let scaling = (head_dim as f64).powf(-0.5);
let q_proj = linear(embed_dim, embed_dim, vb.pp("q_proj"))?;
let k_proj = linear(embed_dim, embed_dim, vb.pp("k_proj"))?;
let v_proj = linear(embed_dim, embed_dim, vb.pp("v_proj"))?;
let out_proj = linear(embed_dim, embed_dim, vb.pp("out_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
out_proj,
scaling,
num_heads,
head_dim,
kv_cache: None,
is_decoder,
})
}
fn _shape(&self, tensor: &Tensor, bsz: usize) -> Result<Tensor> {
tensor
.reshape((bsz, (), self.num_heads, self.head_dim))?
.transpose(1, 2)?
.contiguous()
}
fn forward(
&mut self,
xs: &Tensor,
kv_states: Option<&Tensor>,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_sz, tgt_len, _) = xs.dims3()?;
let query_states = (xs.apply(&self.q_proj)? * self.scaling)?;
let (key_states, value_states) = match kv_states {
None => {
let key_states = self._shape(&xs.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&xs.apply(&self.v_proj)?, b_sz)?;
if self.is_decoder {
let kv_states = match &self.kv_cache {
None => (key_states, value_states),
Some((p_key_states, p_value_states)) => {
let key_states = Tensor::cat(&[p_key_states, &key_states], 2)?;
let value_states = Tensor::cat(&[p_value_states, &value_states], 2)?;
(key_states, value_states)
}
};
self.kv_cache = Some(kv_states.clone());
kv_states
} else {
(key_states, value_states)
}
}
Some(kv_states) => {
let key_states = self._shape(&kv_states.apply(&self.k_proj)?, b_sz)?;
let value_states = self._shape(&kv_states.apply(&self.v_proj)?, b_sz)?;
(key_states, value_states)
}
};
let proj_shape = (b_sz * self.num_heads, (), self.head_dim);
let query_states = self._shape(&query_states, b_sz)?.reshape(proj_shape)?;
let key_states = key_states.reshape(proj_shape)?;
let value_states = value_states.reshape(proj_shape)?;
let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?;
let attn_weights = match attn_mask {
None => attn_weights,
Some(attn_mask) => attn_weights.broadcast_add(attn_mask)?,
};
let attn_probs = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_probs.matmul(&value_states)?;
attn_output
.reshape((b_sz, self.num_heads, tgt_len, self.head_dim))?
.transpose(1, 2)?
.reshape((b_sz, tgt_len, self.head_dim * self.num_heads))?
.apply(&self.out_proj)
}
fn reset_kv_cache(&mut self) {
self.kv_cache = None
}
}
#[derive(Debug, Clone)]
struct EncoderLayer {
self_attn: Attention,
self_attn_layer_norm: LayerNorm,
activation_fn: candle_nn::Activation,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
}
impl EncoderLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
let fc1 = linear(cfg.d_model, cfg.encoder_ffn_dim, vb.pp("fc1"))?;
let fc2 = linear(cfg.encoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
self_attn_layer_norm,
activation_fn: cfg.activation_function,
fc1,
fc2,
final_layer_norm,
})
}
fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = (self.self_attn.forward(xs, None, None)? + residual)?
.apply(&self.self_attn_layer_norm)?;
let residual = &xs;
let xs = xs
.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)?;
(xs + residual)?.apply(&self.final_layer_norm)
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache()
}
}
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
self_attn_layer_norm: LayerNorm,
activation_fn: candle_nn::Activation,
encoder_attn: Attention,
encoder_attn_layer_norm: LayerNorm,
fc1: Linear,
fc2: Linear,
final_layer_norm: LayerNorm,
}
impl DecoderLayer {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
let encoder_attn_layer_norm =
layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
Ok(Self {
self_attn,
self_attn_layer_norm,
activation_fn: cfg.activation_function,
encoder_attn,
encoder_attn_layer_norm,
fc1,
fc2,
final_layer_norm,
})
}
fn forward(
&mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
attn_mask: &Tensor,
) -> Result<Tensor> {
let residual = xs;
let xs = (self.self_attn.forward(xs, None, Some(attn_mask))? + residual)?
.apply(&self.self_attn_layer_norm)?;
let xs = match encoder_xs {
None => xs,
Some(encoder_xs) => {
let residual = &xs;
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs), None)?;
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
}
};
let residual = &xs;
let xs = xs
.apply(&self.fc1)?
.apply(&self.activation_fn)?
.apply(&self.fc2)?;
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
Ok(xs)
}
fn reset_kv_cache(&mut self) {
self.self_attn.reset_kv_cache();
self.encoder_attn.reset_kv_cache()
}
}
#[derive(Debug, Clone)]
pub struct Encoder {
embed_tokens: Embedding,
embed_positions: SinusoidalPositionalEmbedding,
layers: Vec<EncoderLayer>,
embed_scale: Option<f64>,
}
impl Encoder {
fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
let mut layers = Vec::with_capacity(cfg.encoder_layers);
let vb_l = vb.pp("layers");
for idx in 0..cfg.encoder_layers {
let layer = EncoderLayer::new(cfg, vb_l.pp(idx))?;
layers.push(layer)
}
let embed_scale = if cfg.scale_embedding {
Some((cfg.d_model as f64).sqrt())
} else {
None
};
Ok(Self {
embed_tokens: embed_tokens.clone(),
embed_positions,
layers,
embed_scale,
})
}
pub fn forward(&mut self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
let xs = xs.apply(&self.embed_tokens)?;
let xs = match self.embed_scale {
None => xs,
Some(scale) => (xs * scale)?,
};
let embed_pos = self
.embed_positions
.forward(&xs, past_kv_len)?
.unsqueeze(0)?;
let mut xs = xs.broadcast_add(&embed_pos)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs)?
}
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
}
#[derive(Debug, Clone)]
pub struct Decoder {
embed_tokens: Embedding,
embed_positions: SinusoidalPositionalEmbedding,
layers: Vec<DecoderLayer>,
embed_scale: Option<f64>,
}
impl Decoder {
fn new(cfg: &Config, embed_tokens: &Embedding, vb: VarBuilder) -> Result<Self> {
let embed_positions = SinusoidalPositionalEmbedding::new(cfg, vb.pp("embed_positions"))?;
let mut layers = Vec::with_capacity(cfg.decoder_layers);
let vb_l = vb.pp("layers");
for idx in 0..cfg.decoder_layers {
let layer = DecoderLayer::new(cfg, vb_l.pp(idx))?;
layers.push(layer)
}
let embed_scale = if cfg.scale_embedding {
Some((cfg.d_model as f64).sqrt())
} else {
None
};
Ok(Self {
embed_tokens: embed_tokens.clone(),
embed_positions,
layers,
embed_scale,
})
}
pub fn forward(
&mut self,
xs: &Tensor,
encoder_xs: Option<&Tensor>,
past_kv_len: usize,
attn_mask: &Tensor,
) -> Result<Tensor> {
let xs = xs.apply(&self.embed_tokens)?;
let xs = match self.embed_scale {
None => xs,
Some(scale) => (xs * scale)?,
};
let embed_pos = self
.embed_positions
.forward(&xs, past_kv_len)?
.unsqueeze(0)?;
let mut xs = xs.broadcast_add(&embed_pos)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, encoder_xs, attn_mask)?;
}
Ok(xs)
}
pub fn reset_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.reset_kv_cache()
}
}
}
#[derive(Debug, Clone)]
struct Model {
shared: Embedding,
encoder: Encoder,
decoder: Decoder,
}
impl Model {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("shared"))?;
let encoder = Encoder::new(cfg, &shared, vb.pp("encoder"))?;
let decoder = Decoder::new(cfg, &shared, vb.pp("decoder"))?;
Ok(Self {
shared,
encoder,
decoder,
})
}
fn reset_kv_cache(&mut self) {
self.encoder.reset_kv_cache();
self.decoder.reset_kv_cache();
}
}
#[derive(Debug, Clone)]
pub struct MTModel {
model: Model,
lm_head: Linear,
final_logits_bias: Tensor,
}
impl MTModel {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
let model = Model::new(cfg, vb.pp("model"))?;
let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
Ok(Self {
model,
lm_head,
final_logits_bias,
})
}
pub fn encoder(&mut self) -> &mut Encoder {
&mut self.model.encoder
}
pub fn decoder(&mut self) -> &mut Decoder {
&mut self.model.decoder
}
pub fn decode(
&mut self,
xs: &Tensor,
encoder_xs: &Tensor,
past_kv_len: usize,
) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
let mask = Tensor::from_vec(mask, (seq_len, seq_len), xs.device())?;
self.model
.decoder
.forward(xs, Some(encoder_xs), past_kv_len, &mask)?
.apply(&self.lm_head)?
.broadcast_add(&self.final_logits_bias)
}
pub fn reset_kv_cache(&mut self) {
self.model.reset_kv_cache();
}
}
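To make the decoding entry points above concrete, here is a hedged sketch of greedy translation with `MTModel`; the `VarBuilder` `vb`, the tokenized `src_token_ids` tensor, `device`, and the `argmax_last` helper are all assumptions, and only methods shown above are used:

```rust
// Sketch only, not the shipped example binary.
let cfg = Config::opus_mt_fr_en();
let mut model = MTModel::new(&cfg, vb)?;
let encoder_xs = model.encoder().forward(&src_token_ids, 0)?;
let mut token_ids = vec![cfg.decoder_start_token_id];
for past_kv_len in 0..128 {
    // With the per-layer KV cache, only the newest token is fed at each step.
    let last = [*token_ids.last().unwrap()];
    let xs = Tensor::new(&last[..], &device)?.unsqueeze(0)?;
    let logits = model.decode(&xs, &encoder_xs, past_kv_len)?;
    let next = argmax_last(&logits)?; // hypothetical helper: argmax over the final position
    token_ids.push(next);
    if next == cfg.eos_token_id {
        break;
    }
}
model.reset_kv_cache(); // clear cached keys/values before the next sentence
```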

View File

@ -73,6 +73,23 @@ impl Config {
pad_vocab_size_multiple: 64,
}
}
// https://huggingface.co/teknium/Phi-Hermes-1.3B/blob/main/config.json
pub fn phi_hermes_1_3b() -> Self {
Self {
vocab_size: 50304,
n_positions: 2048,
n_embd: 2048,
n_layer: 24,
n_inner: None,
n_head: 32,
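// head_dim = n_embd / n_head = 2048 / 32 = 64, so rotary embeddings
// cover min(32, 64) = 32 dims of each head.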
rotary_dim: usize::min(32, 2048 / 32),
activation_function: Activation::NewGelu,
layer_norm_epsilon: 1e-5,
tie_word_embeddings: false,
pad_vocab_size_multiple: 64,
}
}
}
#[derive(Debug, Clone)]

View File

@ -6,13 +6,19 @@ pub mod convmixer;
pub mod dinov2;
pub mod efficientnet;
pub mod falcon;
pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod marian;
pub mod mistral;
pub mod mixformer;
pub mod mpt;
pub mod persimmon;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
@ -23,6 +29,7 @@ pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
pub mod t5;
pub mod vgg;
pub mod vit;
pub mod whisper;
pub mod with_tracing;

View File

@ -0,0 +1,56 @@
use candle::DType;
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PositionEmbeddingType {
Absolute,
Alibi,
}
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/persimmon/configuration_persimmon.py
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_hidden_layers: usize,
pub num_attention_heads: usize,
pub num_key_value_heads: usize,
pub hidden_act: candle_nn::Activation,
pub max_position_embeddings: usize,
pub initializer_range: f64,
pub layer_norm_eps: f64,
pub rms_norm_eps: f64,
pub use_cache: bool,
pub tie_word_embeddings: bool,
pub rope_theta: f64,
pub qk_layernorm: bool,
pub partial_rotary_factor: f64,
}
impl Config {
pub fn base_8b() -> Self {
// https://huggingface.co/adept/persimmon-8b-base/blob/main/config.json
Self {
hidden_act: candle_nn::Activation::Relu,
hidden_size: 4096,
initializer_range: 0.02,
intermediate_size: 16384,
layer_norm_eps: 1e-05,
max_position_embeddings: 16384,
num_attention_heads: 64,
num_hidden_layers: 36,
num_key_value_heads: 64,
qk_layernorm: true,
rms_norm_eps: 1e-06,
rope_theta: 25000.0,
tie_word_embeddings: false,
use_cache: true,
vocab_size: 262144,
partial_rotary_factor: 0.5,
}
}
}
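A small worked check of what `partial_rotary_factor` means for this configuration (a sketch, assuming the fields above stay public):

```rust
let cfg = Config::base_8b();
let head_dim = cfg.hidden_size / cfg.num_attention_heads; // 4096 / 64 = 64
// Only the first half of each head is rotated: 0.5 * 64 = 32 dims.
let rotary_dim = (cfg.partial_rotary_factor * head_dim as f64) as usize;
assert_eq!((head_dim, rotary_dim), (64, 32));
```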

View File

@ -255,4 +255,7 @@ impl BlipForConditionalGeneration {
pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
&mut self.text_decoder
}
pub fn reset_kv_cache(&mut self) {
self.text_decoder.reset_kv_cache();
}
}

View File

@ -0,0 +1,227 @@
use super::llama2_c::{Cache, Config};
use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Module, Result, Tensor, D};
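// Manual SiLU: x / (1 + e^(-x)) = x * sigmoid(x).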
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
v_proj: Linear,
o_proj: Linear,
n_head: usize,
n_key_value_head: usize,
head_dim: usize,
cache: Cache,
}
impl CausalSelfAttention {
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (b_sz, seq_len, h, n_embd) = x.dims4()?;
let cos = self.cache.cos.i(index_pos..index_pos + seq_len)?;
let sin = self.cache.sin.i(index_pos..index_pos + seq_len)?;
let cos = cos.unsqueeze(1)?;
let sin = sin.unsqueeze(1)?;
let cos = cos.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let sin = sin.broadcast_as((b_sz, seq_len, 1, n_embd / 2, 1))?;
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
Ok(rope)
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
let q = q.reshape((b_sz, seq_len, self.n_head, self.head_dim))?;
let k = k.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {
let mut cache = self.cache.kvs.lock().unwrap();
if let Some((cache_k, cache_v)) = &cache[block_idx] {
k = Tensor::cat(&[cache_k, &k], 1)?.contiguous()?;
v = Tensor::cat(&[cache_v, &v], 1)?.contiguous()?;
}
cache[block_idx] = Some((k.clone(), v.clone()))
}
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let q = q.transpose(1, 2)?.contiguous()?;
let k = k.transpose(1, 2)?.contiguous()?;
let v = v.transpose(1, 2)?.contiguous()?;
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let mask = self.cache.mask(seq_len)?.broadcast_as(att.shape())?;
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
let att = candle_nn::ops::softmax(&att, D::Minus1)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let y = att.matmul(&v.contiguous()?)?;
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
let y = self.o_proj.forward(&y)?;
Ok(y)
}
fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_key_value_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, seq_len, n_kv_head, head_dim) = x.dims4()?;
let x = x
.unsqueeze(3)?
.expand((b_sz, seq_len, n_kv_head, n_rep, head_dim))?
.reshape((b_sz, seq_len, n_kv_head * n_rep, head_dim))?;
Ok(x)
}
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let size_in = cfg.dim;
let size_q = (cfg.dim / cfg.n_heads) * cfg.n_heads;
let size_kv = (cfg.dim / cfg.n_heads) * cfg.n_kv_heads;
let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?;
let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?;
let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?;
let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?;
Ok(Self {
q_proj,
k_proj,
v_proj,
o_proj,
n_head: cfg.n_heads,
n_key_value_head: cfg.n_kv_heads,
head_dim: cfg.dim / cfg.n_heads,
cache: cache.clone(),
})
}
}
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
struct Mlp {
c_fc1: Linear,
c_fc2: Linear,
c_proj: Linear,
}
impl Mlp {
fn new(c_fc1: Linear, c_fc2: Linear, c_proj: Linear) -> Self {
Self {
c_fc1,
c_fc2,
c_proj,
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let h_size = cfg.dim;
let i_size = cfg.hidden_dim;
let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?;
let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?;
Ok(Self::new(c_fc1, c_fc2, c_proj))
}
}
struct Block {
rms_1: RmsNorm,
attn: CausalSelfAttention,
rms_2: RmsNorm,
mlp: Mlp,
}
impl Block {
fn new(rms_1: RmsNorm, attn: CausalSelfAttention, rms_2: RmsNorm, mlp: Mlp) -> Self {
Self {
rms_1,
attn,
rms_2,
mlp,
}
}
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let residual = x;
let x = self.rms_1.forward(x)?;
let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
let residual = &x;
let x = (self.mlp.forward(&self.rms_2.forward(&x)?)? + residual)?;
Ok(x)
}
fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
let input_layernorm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("input_layernorm"))?;
let post_attention_layernorm =
RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("post_attention_layernorm"))?;
Ok(Self::new(
input_layernorm,
attn,
post_attention_layernorm,
mlp,
))
}
}
pub struct QLlama {
wte: Embedding,
blocks: Vec<Block>,
ln_f: RmsNorm,
lm_head: Linear,
pub config: Config,
}
impl QLlama {
pub fn forward(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
let (_b_sz, _seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx)?;
}
let x = self.ln_f.forward(&x)?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: Config) -> Result<Self> {
let wte = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let ln_f = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks = (0..cfg.n_layers)
.map(|i| Block::load(vb.pp(format!("model.layers.{i}")), cache, &cfg))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
wte,
blocks,
ln_f,
lm_head,
config: cfg,
})
}
}
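A hedged usage sketch for `QLlama`; `gguf_bytes` and the shared `cache` (a `llama2_c::Cache` built outside this diff) are assumptions, while `from_gguf_buffer` mirrors the quantized `VarBuilder` call used by the BLIP example further down:

```rust
// Sketch only: gguf_bytes holds the quantized weights, token_ids is a
// (batch, seq) u32 tensor, and cache was created for the same config.
let vb = VarBuilder::from_gguf_buffer(&gguf_bytes)?;
let model = QLlama::load(vb, &cache, Config::tiny_110m())?;
let logits = model.forward(&token_ids, 0)?; // f32 logits; index_pos 0 for a fresh prompt
```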

View File

@ -7,7 +7,7 @@ use std::sync::Arc;
pub use crate::models::stable_lm::Config;
use crate::models::stable_lm::RotaryEmbedding;
#[derive(Debug)]
#[derive(Debug, Clone)]
#[allow(clippy::upper_case_acronyms)]
struct MLP {
gate_proj: Linear,
@ -43,7 +43,7 @@ impl Module for MLP {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct Attention {
q_proj: Linear,
k_proj: Linear,
@ -168,7 +168,7 @@ impl Attention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct DecoderLayer {
self_attn: Attention,
mlp: MLP,
@ -213,7 +213,7 @@ impl DecoderLayer {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Model {
embed_tokens: Embedding,
layers: Vec<DecoderLayer>,

View File

@ -93,7 +93,7 @@ impl Default for Config {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
@ -125,7 +125,7 @@ impl Module for T5LayerNorm {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseActDense {
wi: QMatMul,
wo: QMatMul,
@ -156,7 +156,7 @@ impl Module for T5DenseActDense {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
wi_0: QMatMul,
wi_1: QMatMul,
@ -191,7 +191,7 @@ impl Module for T5DenseGatedActDense {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerFF {
dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>,
@ -236,7 +236,7 @@ impl Module for T5LayerFF {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Attention {
q: QMatMul,
k: QMatMul,
@ -431,7 +431,7 @@ impl T5Attention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
@ -470,7 +470,7 @@ impl T5LayerSelfAttention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
cross_attention: T5Attention,
layer_norm: T5LayerNorm,
@ -512,7 +512,7 @@ impl T5LayerCrossAttention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
@ -583,7 +583,7 @@ impl T5Block {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
@ -633,7 +633,7 @@ impl T5Stack {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
encoder: T5Stack,
device: Device,
@ -666,7 +666,7 @@ impl T5EncoderModel {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,

View File

@ -1,3 +1,4 @@
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
@ -9,13 +10,11 @@ pub mod tiny_vit;
pub mod transformer;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
if bias {
crate::models::with_tracing::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
@ -85,16 +84,3 @@ impl Module for MlpBlock {
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}

View File

@ -102,6 +102,14 @@ impl Config {
}
}
pub fn ssd1b() -> Self {
Self::sdxl()
}
pub fn ssd1b2() -> Self {
Self::sdxl2()
}
// https://huggingface.co/warp-ai/wuerstchen/blob/main/text_encoder/config.json
pub fn wuerstchen() -> Self {
Self {

View File

@ -249,6 +249,71 @@ impl StableDiffusionConfig {
)
}
pub fn ssd1b(
sliced_attention_size: Option<usize>,
height: Option<usize>,
width: Option<usize>,
) -> Self {
let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig {
out_channels,
use_cross_attn,
attention_head_dim,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/unet/config.json
let unet = unet_2d::UNet2DConditionModelConfig {
blocks: vec![
bc(320, None, 5),
bc(640, Some(2), 10),
bc(1280, Some(10), 20),
],
center_input_sample: false,
cross_attention_dim: 2048,
downsample_padding: 1,
flip_sin_to_cos: true,
freq_shift: 0.,
layers_per_block: 2,
mid_block_scale_factor: 1.,
norm_eps: 1e-5,
norm_num_groups: 32,
sliced_attention_size,
use_linear_projection: true,
};
// https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKLConfig {
block_out_channels: vec![128, 256, 512, 512],
layers_per_block: 2,
latent_channels: 4,
norm_num_groups: 32,
};
let scheduler = ddim::DDIMSchedulerConfig::default();
let height = if let Some(height) = height {
assert_eq!(height % 8, 0, "height has to be divisible by 8");
height
} else {
1024
};
let width = if let Some(width) = width {
assert_eq!(width % 8, 0, "width has to be divisible by 8");
width
} else {
1024
};
Self {
width,
height,
clip: clip::Config::ssd1b(),
clip2: Some(clip::Config::ssd1b2()),
autoencoder,
scheduler,
unet,
}
}
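// A minimal sketch of using the new ssd1b constructor (assuming the config
// fields stay public):
//     let cfg = StableDiffusionConfig::ssd1b(None, None, None);
//     assert_eq!((cfg.width, cfg.height), (1024, 1024)); // SDXL-style defaults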
pub fn build_vae<P: AsRef<std::path::Path>>(
&self,
vae_weights: P,

View File

@ -118,7 +118,7 @@ impl Config {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerNorm {
weight: Tensor,
variance_epsilon: f64,
@ -150,7 +150,7 @@ impl Module for T5LayerNorm {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseActDense {
wi: Linear,
wo: Linear,
@ -181,7 +181,7 @@ impl Module for T5DenseActDense {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5DenseGatedActDense {
wi_0: Linear,
wi_1: Linear,
@ -216,7 +216,7 @@ impl Module for T5DenseGatedActDense {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerFF {
dense_act: Option<T5DenseActDense>,
gated_dense_act: Option<T5DenseGatedActDense>,
@ -261,7 +261,7 @@ impl Module for T5LayerFF {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Attention {
q: Linear,
k: Linear,
@ -456,7 +456,7 @@ impl T5Attention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerSelfAttention {
self_attention: T5Attention,
layer_norm: T5LayerNorm,
@ -495,7 +495,7 @@ impl T5LayerSelfAttention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5LayerCrossAttention {
cross_attention: T5Attention,
layer_norm: T5LayerNorm,
@ -537,7 +537,7 @@ impl T5LayerCrossAttention {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Block {
self_attn: T5LayerSelfAttention,
cross_attn: Option<T5LayerCrossAttention>,
@ -608,7 +608,7 @@ impl T5Block {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
struct T5Stack {
block: Vec<T5Block>,
shared: Arc<Embedding>,
@ -658,7 +658,7 @@ impl T5Stack {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5EncoderModel {
encoder: T5Stack,
device: Device,
@ -691,7 +691,7 @@ impl T5EncoderModel {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct T5ForConditionalGeneration {
encoder: T5Stack,
decoder: T5Stack,

View File

@ -0,0 +1,257 @@
//! VGG-16 model implementation.
//!
//! See Very Deep Convolutional Networks for Large-Scale Image Recognition
//! <https://arxiv.org/abs/1409.1556>
use candle::{ModuleT, Result, Tensor};
use candle_nn::{FuncT, VarBuilder};
// Enum representing the different VGG models
pub enum Models {
Vgg13,
Vgg16,
Vgg19,
}
// Struct representing a VGG model
#[derive(Debug)]
pub struct Vgg<'a> {
blocks: Vec<FuncT<'a>>,
}
// Struct representing the configuration for the pre-logit layer
struct PreLogitConfig {
in_dim: (usize, usize, usize, usize),
target_in: usize,
target_out: usize,
}
// Implementation of the VGG model
impl<'a> Vgg<'a> {
// Function to create a new VGG model
pub fn new(vb: VarBuilder<'a>, model: Models) -> Result<Self> {
let blocks = match model {
Models::Vgg13 => vgg13_blocks(vb)?,
Models::Vgg16 => vgg16_blocks(vb)?,
Models::Vgg19 => vgg19_blocks(vb)?,
};
Ok(Self { blocks })
}
}
// Implementation of the forward pass for the VGG model
impl ModuleT for Vgg<'_> {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
let mut xs = xs.unsqueeze(0)?;
for block in self.blocks.iter() {
xs = xs.apply_t(block, train)?;
}
Ok(xs)
}
}
// Function to create a conv2d block
// The block is composed of a series of conv2d layers (two to four, depending on the variant) followed by a max pool layer
fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result<FuncT<'static>> {
let layers = convs
.iter()
.map(|&(in_c, out_c, name)| {
candle_nn::conv2d(
in_c,
out_c,
3,
candle_nn::Conv2dConfig {
stride: 1,
padding: 1,
..Default::default()
},
vb.pp(name),
)
})
.collect::<Result<Vec<_>>>()?;
Ok(FuncT::new(move |xs, _train| {
let mut xs = xs.clone();
for layer in layers.iter() {
xs = xs.apply(layer)?.relu()?
}
xs = xs.max_pool2d_with_stride(2, 2)?;
Ok(xs)
}))
}
// Function to create a fully connected layer
// The head chains dropout + linear + ReLU three times, with the last linear projecting to num_classes
fn fully_connected(
num_classes: usize,
pre_logit_1: PreLogitConfig,
pre_logit_2: PreLogitConfig,
vb: VarBuilder,
) -> Result<FuncT> {
let lin = get_weights_and_biases(
&vb.pp("pre_logits.fc1"),
pre_logit_1.in_dim,
pre_logit_1.target_in,
pre_logit_1.target_out,
)?;
let lin2 = get_weights_and_biases(
&vb.pp("pre_logits.fc2"),
pre_logit_2.in_dim,
pre_logit_2.target_in,
pre_logit_2.target_out,
)?;
let dropout1 = candle_nn::Dropout::new(0.5);
let dropout2 = candle_nn::Dropout::new(0.5);
let dropout3 = candle_nn::Dropout::new(0.5);
// Build the final projection once instead of re-creating it on every forward pass.
let lin3 = candle_nn::linear(4096, num_classes, vb.pp("head.fc"))?;
Ok(FuncT::new(move |xs, train| {
let xs = xs.reshape((1, pre_logit_1.target_out))?;
let xs = xs.apply_t(&dropout1, train)?.apply(&lin)?.relu()?;
let xs = xs.apply_t(&dropout2, train)?.apply(&lin2)?.relu()?;
let xs = xs.apply_t(&dropout3, train)?.apply(&lin3)?.relu()?;
Ok(xs)
}))
}
// Function to get the weights and biases for a layer
// This is required because the weights and biases are stored in a different format than our linear layer expects
fn get_weights_and_biases(
vs: &VarBuilder,
in_dim: (usize, usize, usize, usize),
target_in: usize,
target_out: usize,
) -> Result<candle_nn::Linear> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get_with_hints(in_dim, "weight", init_ws)?;
let ws = ws.reshape((target_in, target_out))?;
let bound = 1. / (target_out as f64).sqrt();
let init_bs = candle_nn::Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get_with_hints(target_in, "bias", init_bs)?;
Ok(candle_nn::Linear::new(ws, Some(bs)))
}
fn vgg13_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
let num_classes = 1000;
let blocks = vec![
conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
conv2d_block(&[(128, 256, "features.10"), (256, 256, "features.12")], &vb)?,
conv2d_block(&[(256, 512, "features.15"), (512, 512, "features.17")], &vb)?,
conv2d_block(&[(512, 512, "features.20"), (512, 512, "features.22")], &vb)?,
fully_connected(
num_classes,
PreLogitConfig {
in_dim: (4096, 512, 7, 7),
target_in: 4096,
target_out: 512 * 7 * 7,
},
PreLogitConfig {
in_dim: (4096, 4096, 1, 1),
target_in: 4096,
target_out: 4096,
},
vb.clone(),
)?,
];
Ok(blocks)
}
fn vgg16_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
let num_classes = 1000;
let blocks = vec![
conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
conv2d_block(
&[
(128, 256, "features.10"),
(256, 256, "features.12"),
(256, 256, "features.14"),
],
&vb,
)?,
conv2d_block(
&[
(256, 512, "features.17"),
(512, 512, "features.19"),
(512, 512, "features.21"),
],
&vb,
)?,
conv2d_block(
&[
(512, 512, "features.24"),
(512, 512, "features.26"),
(512, 512, "features.28"),
],
&vb,
)?,
fully_connected(
num_classes,
PreLogitConfig {
in_dim: (4096, 512, 7, 7),
target_in: 4096,
target_out: 512 * 7 * 7,
},
PreLogitConfig {
in_dim: (4096, 4096, 1, 1),
target_in: 4096,
target_out: 4096,
},
vb.clone(),
)?,
];
Ok(blocks)
}
fn vgg19_blocks(vb: VarBuilder) -> Result<Vec<FuncT>> {
let num_classes = 1000;
let blocks = vec![
conv2d_block(&[(3, 64, "features.0"), (64, 64, "features.2")], &vb)?,
conv2d_block(&[(64, 128, "features.5"), (128, 128, "features.7")], &vb)?,
conv2d_block(
&[
(128, 256, "features.10"),
(256, 256, "features.12"),
(256, 256, "features.14"),
(256, 256, "features.16"),
],
&vb,
)?,
conv2d_block(
&[
(256, 512, "features.19"),
(512, 512, "features.21"),
(512, 512, "features.23"),
(512, 512, "features.25"),
],
&vb,
)?,
conv2d_block(
&[
(512, 512, "features.28"),
(512, 512, "features.30"),
(512, 512, "features.32"),
(512, 512, "features.34"),
],
&vb,
)?,
fully_connected(
num_classes,
PreLogitConfig {
in_dim: (4096, 512, 7, 7),
target_in: 4096,
target_out: 512 * 7 * 7,
},
PreLogitConfig {
in_dim: (4096, 4096, 1, 1),
target_in: 4096,
target_out: 4096,
},
vb.clone(),
)?,
];
Ok(blocks)
}
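A hedged sketch of running the new VGG-16 in evaluation mode; `vb` is assumed to hold weights under the `features.*`, `pre_logits.*` and `head.fc` names used above, and `image` to be a `(3, 224, 224)` f32 tensor:

```rust
use candle::ModuleT;

let vgg = Vgg::new(vb, Models::Vgg16)?;
// forward_t adds the batch dimension itself (note the unsqueeze(0) above);
// train = false turns the three dropout layers into no-ops.
let logits = vgg.forward_t(&image, false)?; // (1, 1000) class scores
```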

View File

@ -1,4 +1,5 @@
use super::Config;
use crate::models::with_tracing::{linear, linear_no_bias, Linear};
use candle::{Device, IndexOp, Result, Tensor, D};
use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
//
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn conv1d(
in_channels: usize,
@ -53,6 +27,7 @@ fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
#[derive(Debug, Clone)]
struct MultiHeadAttention {
query: Linear,
key: Linear,
@ -162,6 +137,7 @@ impl MultiHeadAttention {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
@ -241,6 +217,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
#[derive(Debug, Clone)]
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
@ -316,6 +293,7 @@ impl AudioEncoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
#[derive(Debug, Clone)]
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
@ -380,6 +358,7 @@ impl TextDecoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
#[derive(Debug, Clone)]
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,

View File

@ -19,6 +19,7 @@ fn conv1d(
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62
#[derive(Debug, Clone)]
struct MultiHeadAttention {
query: Linear,
key: Linear,
@ -128,6 +129,7 @@ impl MultiHeadAttention {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111
#[derive(Debug, Clone)]
struct ResidualAttentionBlock {
attn: MultiHeadAttention,
attn_ln: LayerNorm,
@ -206,6 +208,7 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
#[derive(Debug, Clone)]
pub struct AudioEncoder {
conv1: Conv1d,
conv2: Conv1d,
@ -281,6 +284,7 @@ impl AudioEncoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176
#[derive(Debug, Clone)]
pub struct TextDecoder {
token_embedding: Embedding,
positional_embedding: Tensor,
@ -347,6 +351,7 @@ impl TextDecoder {
}
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221
#[derive(Debug, Clone)]
pub struct Whisper {
pub encoder: AudioEncoder,
pub decoder: TextDecoder,

View File

@ -14,6 +14,13 @@ impl Embedding {
Ok(Self { inner, span })
}
pub fn from_weights(weights: Tensor) -> Result<Self> {
let (_in_size, out_size) = weights.dims2()?;
let inner = candle_nn::Embedding::new(weights, out_size);
let span = tracing::span!(tracing::Level::TRACE, "embedding");
Ok(Self { inner, span })
}
pub fn embeddings(&self) -> &Tensor {
self.inner.embeddings()
}
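This is what lets a model tie its input embeddings to the output projection, as the marian `MTModel` above does; a minimal sketch assuming `weights` is a `(vocab_size, d_model)` tensor loaded elsewhere:

```rust
let emb = Embedding::from_weights(weights)?;
// Reuse the same matrix as a bias-free output projection (weight tying).
let lm_head = Linear::from_weights(emb.embeddings().clone(), None);
```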

View File

@ -77,6 +77,16 @@ impl VarBuilder {
}
}
pub fn get_no_shape(&self, name: &str) -> Result<Arc<QTensor>> {
let path = self.path(name);
match self.data.get(&path) {
None => {
candle::bail!("cannot find tensor {name}")
}
Some(qtensor) => Ok(qtensor.clone()),
}
}
pub fn device(&self) -> &Device {
&self.device
}

View File

@ -0,0 +1,31 @@
[package]
name = "candle-wasm-example-blip"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
candle = { path = "../../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../../candle-nn", version = "0.3.0" }
candle-transformers = { path = "../../candle-transformers", version = "0.3.0" }
tokenizers = { workspace = true, features = ["unstable_wasm"] }
num-traits = { workspace = true }
# App crates.
anyhow = { workspace = true }
byteorder = { workspace = true }
getrandom = { version = "0.2", features = ["js"] }
image = { workspace = true }
log = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
# Wasm specific crates.
console_error_panic_hook = "0.1.7"
wasm-bindgen = "0.2.87"
js-sys = "0.3.64"

View File

@ -0,0 +1,23 @@
## Running [BLIP Image Captioning](https://huggingface.co/Salesforce/blip-image-captioning-large) Example
### Vanilla JS and WebWorkers
To build and test the UI made in Vanilla JS and WebWorkers, first we need to build the WASM library:
```bash
sh build-lib.sh
```
This will bundle the library under `./build` and we can import it inside our WebWorker like a normal JS module:
```js
import init, { Model } from "./build/m.js";
```
The full example can be found under `./index.html`. All needed assets are fetched from the web, so no need to download anything.
Finally, you can preview the example by running a local HTTP server. For example:
```bash
python -m http.server
```
Then open `http://localhost:8000/index.html` in your browser.

View File

@ -0,0 +1,77 @@
import init, { Model } from "./build/m.js";
async function fetchArrayBuffer(url, cacheFile = true) {
if (!cacheFile) return new Uint8Array(await (await fetch(url)).arrayBuffer());
const cacheName = "blip-candle-cache";
const cache = await caches.open(cacheName);
const cachedResponse = await cache.match(url);
if (cachedResponse) {
const data = await cachedResponse.arrayBuffer();
return new Uint8Array(data);
}
const res = await fetch(url, { cache: "force-cache" });
cache.put(url, res.clone());
return new Uint8Array(await res.arrayBuffer());
}
class Blip {
static instance = {};
static async getInstance(
weightsURL,
tokenizerURL,
configURL,
modelID,
quantized
) {
if (!this.instance[modelID]) {
await init();
self.postMessage({ status: "loading", message: "Loading Model" });
const [weightsArrayU8, tokenizerArrayU8, configArrayU8] =
await Promise.all([
fetchArrayBuffer(weightsURL),
fetchArrayBuffer(tokenizerURL),
fetchArrayBuffer(configURL),
]);
this.instance[modelID] = new Model(
weightsArrayU8,
tokenizerArrayU8,
configArrayU8,
quantized
);
} else {
self.postMessage({ status: "ready", message: "Model Already Loaded" });
}
return this.instance[modelID];
}
}
self.addEventListener("message", async (event) => {
const { weightsURL, tokenizerURL, configURL, modelID, imageURL, quantized } =
event.data;
try {
self.postMessage({ status: "status", message: "Loading Blip Model..." });
const model = await Blip.getInstance(
weightsURL,
tokenizerURL,
configURL,
modelID,
quantized
);
self.postMessage({
status: "status",
message: "Running Blip Inference...",
});
const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
const output = model.generate_caption_from_image(imageArrayU8);
self.postMessage({
status: "complete",
message: "complete",
output: output,
});
} catch (e) {
self.postMessage({ error: e });
}
});

View File

@ -0,0 +1,2 @@
cargo build --target wasm32-unknown-unknown --release
wasm-bindgen ../../target/wasm32-unknown-unknown/release/m.wasm --out-dir build --target web

View File

@ -0,0 +1,393 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<style>
@import url("https://fonts.googleapis.com/css2?family=Source+Code+Pro:wght@200;300;400&family=Source+Sans+3:wght@100;200;300;400;500;600;700;800;900&display=swap");
html,
body {
font-family: "Source Sans 3", sans-serif;
}
</style>
<title>Candle Blip Image Captioning Demo</title>
<script src="https://cdn.tailwindcss.com"></script>
<script type="module" src="./code.js"></script>
<script type="module">
const MODELS = {
blip_image_quantized_q4k: {
base_url: "https://huggingface.co/lmz/candle-blip/resolve/main/",
model: "blip-image-captioning-large-q4k.gguf",
config: "config.json",
tokenizer: "tokenizer.json",
quantized: true,
size: "271 MB",
},
blip_image_quantized_q80: {
base_url: "https://huggingface.co/lmz/candle-blip/resolve/main/",
model: "blip-image-captioning-large-q80.gguf",
config: "config.json",
tokenizer: "tokenizer.json",
quantized: true,
size: "505 MB",
},
blip_image_large: {
base_url:
"https://huggingface.co/Salesforce/blip-image-captioning-large/resolve/refs%2Fpr%2F18/",
model: "model.safetensors",
config: "config.json",
tokenizer: "tokenizer.json",
quantized: false,
size: "1.88 GB",
},
};
const blipWorker = new Worker("./blipWorker.js", {
type: "module",
});
const outputStatusEl = document.querySelector("#output-status");
const outputCaptionEl = document.querySelector("#output-caption");
const modelSelectEl = document.querySelector("#model");
const clearBtn = document.querySelector("#clear-btn");
const fileUpload = document.querySelector("#file-upload");
const dropArea = document.querySelector("#drop-area");
const imagesExamples = document.querySelector("#image-select");
const canvas = document.querySelector("#canvas");
const ctxCanvas = canvas.getContext("2d");
let isCaptioning = false;
let currentImageURL = null;
clearBtn.addEventListener("click", () => {
clearImageCanvas();
});
modelSelectEl.addEventListener("change", () => {
if (currentImageURL) {
runInference(currentImageURL);
}
});
// add event listener to file input
fileUpload.addEventListener("input", async (e) => {
const target = e.target;
if (target.files.length > 0) {
const href = URL.createObjectURL(target.files[0]);
clearImageCanvas();
await drawImageCanvas(href);
runInference(href);
}
});
// add event listener to drop-area
dropArea.addEventListener("dragenter", (e) => {
e.preventDefault();
dropArea.classList.add("border-blue-700");
});
dropArea.addEventListener("dragleave", (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
});
dropArea.addEventListener("dragover", (e) => {
e.preventDefault();
});
dropArea.addEventListener("drop", async (e) => {
e.preventDefault();
dropArea.classList.remove("border-blue-700");
const url = e.dataTransfer.getData("text/uri-list");
const files = e.dataTransfer.files;
if (files.length > 0) {
const href = URL.createObjectURL(files[0]);
clearImageCanvas();
await drawImageCanvas(href);
runInference(href);
} else if (url) {
clearImageCanvas();
await drawImageCanvas(url);
runInference(url);
}
});
imagesExamples.addEventListener("click", async (e) => {
if (isCaptioning) {
return;
}
const target = e.target;
if (target.nodeName === "IMG") {
const href = target.src;
clearImageCanvas();
await drawImageCanvas(href);
runInference(href);
}
});
function clearImageCanvas() {
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
isCaptioning = false;
clearBtn.disabled = true;
canvas.parentElement.style.height = "auto";
outputStatusEl.hidden = false;
outputCaptionEl.hidden = true;
outputStatusEl.innerText = "Please select an image";
currentImageURL = null;
}
async function drawImageCanvas(imgURL) {
if (!imgURL) {
throw new Error("No image URL provided");
}
return new Promise((resolve, reject) => {
ctxCanvas.clearRect(0, 0, canvas.width, canvas.height);
const img = new Image();
img.crossOrigin = "anonymous";
img.onload = () => {
canvas.width = img.width;
canvas.height = img.height;
ctxCanvas.drawImage(img, 0, 0);
canvas.parentElement.style.height = canvas.offsetHeight + "px";
clearBtn.disabled = false;
resolve(img);
};
img.src = imgURL;
currentImageURL = imgURL;
});
}
document.addEventListener("DOMContentLoaded", () => {
for (const [id, model] of Object.entries(MODELS)) {
const option = document.createElement("option");
option.value = id;
option.innerText = `${id} (${model.size})`;
modelSelectEl.appendChild(option);
}
});
async function getImageCaption(
worker,
weightsURL,
tokenizerURL,
configURL,
modelID,
imageURL,
quantized,
updateStatus = null
) {
return new Promise((resolve, reject) => {
worker.postMessage({
weightsURL,
tokenizerURL,
configURL,
modelID,
imageURL,
quantized,
});
function messageHandler(event) {
if ("error" in event.data) {
worker.removeEventListener("message", messageHandler);
reject(new Error(event.data.error));
}
if (event.data.status === "complete") {
worker.removeEventListener("message", messageHandler);
resolve(event.data);
}
if (updateStatus) updateStatus(event.data);
}
worker.addEventListener("message", messageHandler);
});
}
function updateStatus(data) {
if (data.status === "status") {
outputStatusEl.innerText = data.message;
}
}
async function runInference(imageURL) {
if (isCaptioning || !imageURL) {
alert("Please select an image first");
return;
}
outputStatusEl.hidden = false;
outputCaptionEl.hidden = true;
clearBtn.disabled = true;
modelSelectEl.disabled = true;
isCaptioning = true;
const selectedModel = modelSelectEl.value;
const model = MODELS[selectedModel];
const weightsURL = `${model.base_url}${model.model}`;
const tokenizerURL = `${model.base_url}${model.tokenizer}`;
const configURL = `${model.base_url}${model.config}`;
const quantized = model.quantized;
try {
const time = performance.now();
const caption = await getImageCaption(
blipWorker,
weightsURL,
tokenizerURL,
configURL,
selectedModel,
imageURL,
quantized,
updateStatus
);
outputStatusEl.hidden = true;
outputCaptionEl.hidden = false;
const totalTime = ((performance.now() - time)/1000).toFixed(2);
outputCaptionEl.innerHTML = `${
caption.output
}<br/><span class="text-xs">Inference time: ${totalTime} s</span>`;
} catch (err) {
console.error(err);
outputStatusEl.hidden = false;
outputCaptionEl.hidden = true;
outputStatusEl.innerText = err.message;
}
clearBtn.disabled = false;
modelSelectEl.disabled = false;
isCaptioning = false;
}
</script>
</head>
<body class="container max-w-4xl mx-auto p-4">
<main class="grid grid-cols-1 gap-5 relative">
<span class="absolute text-5xl -ml-[1em]"> 🕯️ </span>
<div>
<h1 class="text-5xl font-bold">Candle BLIP Image Captioning</h1>
<h2 class="text-2xl font-bold">Rust/WASM Demo</h2>
<p class="max-w-lg">
<a
href="https://huggingface.co/Salesforce/blip-image-captioning-large"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>BLIP Image Captioning
</a>
running in the browser using
<a
href="https://github.com/huggingface/candle/"
target="_blank"
class="underline hover:text-blue-500 hover:no-underline"
>Candle</a
>, a minimalist ML framework for Rust.
</p>
<p class="text-xs max-w-lg py-2">
<b>Note:</b>
Image captioning with the smallest model takes roughly 50 seconds; the exact
time will vary with your machine and the chosen model size.
</p>
</div>
<div>
<label for="model" class="font-medium block">Models Options: </label>
<select
id="model"
class="border-2 border-gray-500 rounded-md font-light interactive disabled:cursor-not-allowed w-full max-w-max"
></select>
</div>
<!-- drag and drop area -->
<div class="grid gap-4 sm:grid-cols-2 py-4">
<div class="relative max-w-lg">
<div
class="absolute w-full bottom-full flex justify-between items-center"
>
<div class="flex gap-2 w-full">
<button
id="clear-btn"
disabled
title="Clear Image"
class="ml-auto text-xs bg-white rounded-md disabled:opacity-50 flex gap-1 items-center"
>
<svg
class=""
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 13 12"
height="1em"
>
<path
d="M1.6.7 12 11.1M12 .7 1.6 11.1"
stroke="#2E3036"
stroke-width="2"
/>
</svg>
</button>
</div>
</div>
<div
id="drop-area"
class="flex flex-col items-center justify-center border-2 border-gray-300 border-dashed rounded-xl relative aspect-video w-full overflow-hidden"
>
<div
class="flex flex-col items-center justify-center space-y-1 text-center"
>
<svg
width="25"
height="25"
viewBox="0 0 25 25"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<path
d="M3.5 24.3a3 3 0 0 1-1.9-.8c-.5-.5-.8-1.2-.8-1.9V2.9c0-.7.3-1.3.8-1.9.6-.5 1.2-.7 2-.7h18.6c.7 0 1.3.2 1.9.7.5.6.7 1.2.7 2v18.6c0 .7-.2 1.4-.7 1.9a3 3 0 0 1-2 .8H3.6Zm0-2.7h18.7V2.9H3.5v18.7Zm2.7-2.7h13.3c.3 0 .5 0 .6-.3v-.7l-3.7-5a.6.6 0 0 0-.6-.2c-.2 0-.4 0-.5.3l-3.5 4.6-2.4-3.3a.6.6 0 0 0-.6-.3c-.2 0-.4.1-.5.3l-2.7 3.6c-.1.2-.2.4 0 .7.1.2.3.3.6.3Z"
fill="#000"
/>
</svg>
<div class="flex text-sm text-gray-600">
<label
for="file-upload"
class="relative cursor-pointer bg-white rounded-md font-medium text-blue-950 hover:text-blue-700"
>
<span>Drag and drop your image here</span>
<span class="block text-xs">or</span>
<span class="block text-xs">Click to upload</span>
</label>
</div>
<input
id="file-upload"
name="file-upload"
type="file"
class="sr-only"
/>
</div>
<canvas
id="canvas"
class="absolute pointer-events-none w-full"
></canvas>
</div>
</div>
<div class="">
<div
class="h-full bg-slate-100 text-gray-500 p-4 rounded-md flex flex-col gap-2"
>
<p
id="output-caption"
class="m-auto text-xl text-center p-2"
hidden
></p>
<span id="output-status" class="m-auto font-light">
Please select an image
</span>
</div>
</div>
</div>
<div>
<div
class="flex gap-3 items-center overflow-x-scroll"
id="image-select"
>
<h3 class="font-medium">Examples:</h3>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/sf.jpg"
class="cursor-pointer w-24 h-24 object-cover"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/bike.jpeg"
class="cursor-pointer w-24 h-24 object-cover"
/>
<img
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/candle/examples/000000000077.jpg"
class="cursor-pointer w-24 h-24 object-cover"
/>
</div>
</div>
</main>
</body>
</html>

View File

@ -0,0 +1,148 @@
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;
use candle_wasm_example_blip::console_log;
use candle_wasm_example_blip::token_output_stream::TokenOutputStream;
use js_sys::Date;
use tokenizers::Tokenizer;
use wasm_bindgen::prelude::*;
enum SelectedModel {
M(blip::BlipForConditionalGeneration),
Q(quantized_blip::BlipForConditionalGeneration),
}
impl SelectedModel {
fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor, JsError> {
match self {
Self::M(m) => m
.text_decoder()
.forward(xs, img_xs)
.map_err(|e| JsError::new(&e.to_string())),
Self::Q(m) => m
.text_decoder()
.forward(xs, img_xs)
.map_err(|e| JsError::new(&e.to_string())),
}
}
fn reset_kv_cache(&mut self) {
match self {
Self::M(m) => m.reset_kv_cache(),
Self::Q(m) => m.reset_kv_cache(),
}
}
}
#[wasm_bindgen]
pub struct Model {
model: SelectedModel,
tokenizer: TokenOutputStream,
}
const SEP_TOKEN_ID: u32 = 102;
#[wasm_bindgen]
impl Model {
#[wasm_bindgen(constructor)]
pub fn load(
weights: Vec<u8>,
tokenizer: Vec<u8>,
config: Vec<u8>,
quantized: bool,
) -> Result<Model, JsError> {
console_error_panic_hook::set_once();
console_log!("loading model");
let tokenizer =
Tokenizer::from_bytes(&tokenizer).map_err(|m| JsError::new(&m.to_string()))?;
let tokenizer = TokenOutputStream::new(tokenizer);
let config: blip::Config = serde_json::from_slice(&config)?;
let device = Device::Cpu;
let start = Date::now();
let model: SelectedModel = if quantized {
let vb = quantized_blip::VarBuilder::from_gguf_buffer(&weights)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
SelectedModel::Q(model)
} else {
let vb = VarBuilder::from_buffered_safetensors(weights, DType::F32, &device)?;
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
SelectedModel::M(model)
};
console_log!("model loaded in {:?}s", (Date::now() - start) / 1000.);
Ok(Self { model, tokenizer })
}
#[wasm_bindgen]
pub fn generate_caption_from_image(&mut self, image: Vec<u8>) -> Result<String, JsError> {
self.model.reset_kv_cache();
let device = Device::Cpu;
console_log!("loading image as tensor");
let start = Date::now();
let image: Tensor = self.load_image(image)?.to_device(&device)?;
console_log!("image loaded in {:?}s", (Date::now() - start) / 1000.);
let start = Date::now();
let image_embeds: Tensor = match &mut self.model {
SelectedModel::M(m) => image.unsqueeze(0)?.apply(m.vision_model())?,
SelectedModel::Q(m) => image.unsqueeze(0)?.apply(m.vision_model())?,
};
console_log!("image embedded in {:?}s", (Date::now() - start) / 1000.);
let mut logits_processor = LogitsProcessor::new(299792458, None, None);
let mut token_ids = vec![30522u32];
let mut text = String::new();
let start = Date::now();
for index in 0..1000 {
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = self.model.text_decoder_forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
break;
}
token_ids.push(token);
if let Some(t) = self.tokenizer.next_token(token)? {
text.push_str(&t);
}
}
if let Some(rest) = self
.tokenizer
.decode_rest()
.map_err(|m| JsError::new(&m.to_string()))?
{
text.push_str(&rest);
}
console_log!("caption generated in {:?}s", (Date::now() - start) / 1000.);
Ok(text)
}
}
impl Model {
fn load_image(&self, image: Vec<u8>) -> Result<Tensor, JsError> {
let device = &Device::Cpu;
let img = image::io::Reader::new(std::io::Cursor::new(image))
.with_guessed_format()?
.decode()
.map_err(|e| JsError::new(&e.to_string()))?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (384, 384, 3), device)?.permute((2, 0, 1))?;
let mean =
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?;
let std =
Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
.map_err(|e| JsError::new(&e.to_string()))
}
}
fn main() {
console_error_panic_hook::set_once();
}

View File

@ -0,0 +1,17 @@
use wasm_bindgen::prelude::*;
pub mod token_output_stream;
#[wasm_bindgen]
extern "C" {
// Use `js_namespace` here to bind `console.log(..)` instead of just
// `log(..)`
#[wasm_bindgen(js_namespace = console)]
pub fn log(s: &str);
}
#[macro_export]
macro_rules! console_log {
// Note that this is using the `log` function imported above during
// `bare_bones`
($($t:tt)*) => ($crate::log(&format_args!($($t)*).to_string()))
}

View File

@ -0,0 +1,86 @@
use candle::Result;
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
prev_index: usize,
current_index: usize,
}
impl TokenOutputStream {
pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
Self {
tokenizer,
tokens: Vec::new(),
prev_index: 0,
current_index: 0,
}
}
pub fn into_inner(self) -> tokenizers::Tokenizer {
self.tokenizer
}
fn decode(&self, tokens: &[u32]) -> Result<String> {
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => candle::bail!("cannot decode: {err}"),
}
}
// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_ascii() {
let text = text.split_at(prev_text.len());
self.prev_index = self.current_index;
self.current_index = self.tokens.len();
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_rest(&self) -> Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
} else {
let tokens = &self.tokens[self.prev_index..self.current_index];
self.decode(tokens)?
};
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() {
let text = text.split_at(prev_text.len());
Ok(Some(text.1.to_string()))
} else {
Ok(None)
}
}
pub fn decode_all(&self) -> Result<String> {
self.decode(&self.tokens)
}
pub fn get_token(&self, token_s: &str) -> Option<u32> {
self.tokenizer.get_vocab(true).get(token_s).copied()
}
pub fn tokenizer(&self) -> &tokenizers::Tokenizer {
&self.tokenizer
}
pub fn clear(&mut self) {
self.tokens.clear();
self.prev_index = 0;
self.current_index = 0;
}
}
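A minimal sketch of the intended streaming use; `tokenizer` is a `tokenizers::Tokenizer` loaded elsewhere and `sampled_tokens` stands in for ids coming out of the sampler:

```rust
let mut stream = TokenOutputStream::new(tokenizer);
for &t in sampled_tokens.iter() {
    // next_token only yields text once the buffered pieces decode cleanly.
    if let Some(piece) = stream.next_token(t)? {
        print!("{piece}");
    }
}
// Flush whatever is still buffered at the end of generation.
if let Some(rest) = stream.decode_rest()? {
    print!("{rest}");
}
```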

View File

@ -1,5 +1,7 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder};
use candle_nn::{
embedding, linear_no_bias as linear, rms_norm, Embedding, Linear, Module, RmsNorm, VarBuilder,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
@ -57,20 +59,6 @@ impl Cache {
}
}
fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)?
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
Ok(Linear::new(weight, None))
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.dim), "weight")?;
Ok(Embedding::new(embeddings, cfg.dim))
}
struct CausalSelfAttention {
q_proj: Linear,
k_proj: Linear,
@ -198,7 +186,7 @@ impl Mlp {
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
let x = (candle_nn::ops::silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
self.c_proj.forward(&x)
}
@ -283,7 +271,7 @@ impl Llama {
}
pub fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
let wte = embedding(cfg, vb.pp("model.embed_tokens"))?;
let wte = embedding(cfg.vocab_size, cfg.dim, vb.pp("model.embed_tokens"))?;
let lm_head = linear(cfg.dim, cfg.vocab_size, vb.pp("lm_head"))?;
let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("model.norm"))?;
let blocks: Vec<_> = (0..cfg.n_layers)